From 8cedcdbd3dd343dc0b79b6304f78a72c7f527b83 Mon Sep 17 00:00:00 2001 From: Trevor Basinger Date: Wed, 18 Mar 2026 16:07:09 +0000 Subject: [PATCH] feat(cli): polish UX and speed up local workflows --- docs/developer/profiling.md | 72 ++ roar/application/label_rendering.py | 37 + roar/application/labels.py | 32 +- roar/application/publish/collection.py | 30 +- roar/application/publish/job_preparation.py | 10 +- roar/application/publish/lineage.py | 142 ++++ .../application/publish/register_execution.py | 120 ++- .../publish/register_preparation.py | 11 +- .../publish/register_preview_jobs.py | 112 +++ roar/application/publish/service.py | 506 ++++++++++-- roar/application/query/__init__.py | 117 ++- roar/application/query/label.py | 3 +- roar/application/query/log.py | 13 +- roar/application/query/requests.py | 5 + roar/application/query/show.py | 126 ++- roar/application/query/status.py | 55 +- roar/cli/__init__.py | 110 ++- roar/cli/commands/auth.py | 27 +- roar/cli/commands/dag.py | 3 +- roar/cli/commands/init.py | 5 +- roar/cli/commands/label.py | 14 +- roar/cli/commands/lineage.py | 3 +- roar/cli/commands/log.py | 10 +- roar/cli/commands/put.py | 54 +- roar/cli/commands/register.py | 35 +- roar/cli/commands/show.py | 77 +- roar/cli/commands/status.py | 8 +- roar/core/bootstrap.py | 18 +- roar/core/models/__init__.py | 203 ++--- roar/db/__init__.py | 98 +-- roar/db/context.py | 233 +++--- roar/db/query_context.py | 729 ++++++++++++++++++ roar/db/repositories/session.py | 39 +- roar/db/step_priority.py | 91 +++ roar/execution/provenance/data_loader.py | 28 +- roar/integrations/config/__init__.py | 127 ++- roar/integrations/config/access.py | 30 +- roar/integrations/config/loader.py | 15 +- roar/integrations/config/raw.py | 196 +++++ roar/integrations/glaas/__init__.py | 79 +- roar/integrations/glaas/client.py | 5 +- .../glaas/registration/__init__.py | 43 +- roar/presenters/__init__.py | 39 +- roar/presenters/run_report.py | 10 +- 
roar/presenters/show_renderer.py | 2 +- scripts/profile_roar.py | 641 +++++++++++++++ tests/application/publish/test_collection.py | 49 ++ tests/application/publish/test_lineage.py | 50 ++ .../publish/test_register_preparation.py | 9 +- tests/application/publish/test_service.py | 120 ++- tests/application/query/test_log.py | 18 +- tests/application/query/test_show.py | 133 +++- tests/application/query/test_status.py | 22 +- tests/integration/test_put_cli_integration.py | 8 +- .../integration/test_register_dry_run_cli.py | 8 +- tests/integrations/config/test_access.py | 5 +- tests/integrations/glaas/test_client.py | 8 + tests/unit/put/test_put_cli.py | 24 +- tests/unit/test_auth_cli.py | 45 ++ tests/unit/test_bootstrap_config_path.py | 131 ++++ tests/unit/test_cli_init.py | 43 ++ tests/unit/test_cli_query_errors.py | 67 ++ tests/unit/test_cli_registry.py | 29 + tests/unit/test_core_models_lazy_imports.py | 80 ++ tests/unit/test_db_lazy_imports.py | 111 +++ tests/unit/test_presenters_lazy_imports.py | 76 ++ tests/unit/test_profile_roar.py | 33 + .../unit/test_publish_service_lazy_imports.py | 65 ++ tests/unit/test_put_cli.py | 103 +++ tests/unit/test_query_db_context.py | 252 ++++++ tests/unit/test_query_lazy_imports.py | 81 ++ tests/unit/test_raw_config.py | 61 ++ tests/unit/test_register_cli.py | 92 ++- .../test_register_execution_lazy_imports.py | 58 ++ tests/unit/test_register_preview_jobs.py | 86 +++ .../test_register_preview_lazy_imports.py | 95 +++ tests/unit/test_run_report.py | 84 ++ tests/unit/test_tracer_data_loader.py | 6 +- 78 files changed, 5460 insertions(+), 955 deletions(-) create mode 100644 docs/developer/profiling.md create mode 100644 roar/application/label_rendering.py create mode 100644 roar/application/publish/register_preview_jobs.py create mode 100644 roar/db/query_context.py create mode 100644 roar/db/step_priority.py create mode 100644 roar/integrations/config/raw.py create mode 100644 scripts/profile_roar.py create mode 100644 
tests/unit/test_auth_cli.py create mode 100644 tests/unit/test_bootstrap_config_path.py create mode 100644 tests/unit/test_cli_init.py create mode 100644 tests/unit/test_cli_query_errors.py create mode 100644 tests/unit/test_core_models_lazy_imports.py create mode 100644 tests/unit/test_db_lazy_imports.py create mode 100644 tests/unit/test_presenters_lazy_imports.py create mode 100644 tests/unit/test_profile_roar.py create mode 100644 tests/unit/test_publish_service_lazy_imports.py create mode 100644 tests/unit/test_put_cli.py create mode 100644 tests/unit/test_query_db_context.py create mode 100644 tests/unit/test_query_lazy_imports.py create mode 100644 tests/unit/test_raw_config.py create mode 100644 tests/unit/test_register_execution_lazy_imports.py create mode 100644 tests/unit/test_register_preview_jobs.py create mode 100644 tests/unit/test_register_preview_lazy_imports.py create mode 100644 tests/unit/test_run_report.py diff --git a/docs/developer/profiling.md b/docs/developer/profiling.md new file mode 100644 index 00000000..7bec308d --- /dev/null +++ b/docs/developer/profiling.md @@ -0,0 +1,72 @@ +# Profiling Roar + +`roar` already has targeted benchmarks under `tests/benchmarks/`. The profiling harness in +[`scripts/profile_roar.py`](../../scripts/profile_roar.py) adds a +repeatable way to capture wall-time summaries and Python hot spots for representative local +workflows. 
+ +## What It Profiles + +- top-level CLI startup: `roar --help` +- a simple local `roar run` +- active-session query commands: `status` and `show --session` +- local publish flows without remote side effects: `register --dry-run` and `put --dry-run` +- Python startup overhead for `ROAR_WRAP=1`, with and without `ROAR_LOG_FILE` + +Each CLI scenario records: + +- repeated wall-time samples +- one `cProfile` run +- captured stdout/stderr +- top cumulative and internal Python hot spots + +The startup scenario records: + +- baseline vs wrapped wall time +- import-time breakdown from `python -X importtime -c pass` + +## Run It + +From the repo root: + +```bash +uv run --extra dev python scripts/profile_roar.py +``` + +Useful options: + +```bash +uv run --extra dev python scripts/profile_roar.py --iterations 5 --top 20 +uv run --extra dev python scripts/profile_roar.py --scenario cli_run_simple --scenario startup_wrap +``` + +## Output Files + +The harness writes: + +- JSON summary: `tests/benchmarks/results/profile_suite_latest.json` +- Markdown summary: `tests/benchmarks/results/profile_suite_latest.md` +- raw profile artifacts: `tests/benchmarks/results/profiles/<timestamp>/` (one directory per harness run) +- a copied latest artifact set: `tests/benchmarks/results/profiles/latest/` + +The raw artifact directory contains: + +- `*.prof` `cProfile` files +- `*.stdout.txt` and `*.stderr.txt` for profiled CLI runs +- `startup_wrap.importtime.txt` for import-time output + +## How To Read It + +- Start with the wall-time means to find the slowest end-user workflows. +- For a slow CLI scenario, inspect `top_cumulative` first; it shows what dominates total time. +- Inspect `top_internal` when cumulative time is dominated by wrappers and you need the leaf work. 
+- For `startup_wrap`, compare: + - import overhead: `ROAR_WRAP=1` minus baseline + - atexit overhead: `ROAR_WRAP=1 + LOG_FILE` minus `ROAR_WRAP=1` + +## Current Focus + +The existing performance guardrail in +[`tests/execution/runtime/test_sitecustomize_perf.py`](../../tests/execution/runtime/test_sitecustomize_perf.py) +is close to the local threshold. The profiling harness is intended to make that startup/runtime +path measurable enough to optimize, not just to rerun the guardrail test. diff --git a/roar/application/label_rendering.py b/roar/application/label_rendering.py new file mode 100644 index 00000000..5df3ef79 --- /dev/null +++ b/roar/application/label_rendering.py @@ -0,0 +1,37 @@ +"""Pure helpers for rendering label metadata without DB-side imports.""" + +from __future__ import annotations + +import json +from typing import Any + + +def flatten_label_metadata(metadata: dict[str, Any]) -> list[tuple[str, str]]: + """Flatten metadata into sorted ``(key, display_value)`` pairs.""" + flat: list[tuple[str, str]] = [] + + def _walk(prefix: str, value: Any) -> None: + if isinstance(value, dict): + for key in sorted(value.keys()): + next_prefix = f"{prefix}.{key}" if prefix else key + _walk(next_prefix, value[key]) + return + flat.append((prefix, _display_scalar(value))) + + _walk("", metadata) + return flat + + +def render_label_lines(metadata: dict[str, Any], indent: str = "") -> list[str]: + """Render a metadata document as sorted ``key=value`` lines.""" + return [f"{indent}{key}={value}" for key, value in flatten_label_metadata(metadata)] + + +def _display_scalar(value: Any) -> str: + if isinstance(value, bool): + return "true" if value else "false" + if value is None: + return "null" + if isinstance(value, (int, float, str)): + return str(value) + return json.dumps(value, sort_keys=True) diff --git a/roar/application/labels.py b/roar/application/labels.py index 1bec06d3..0ae6809a 100644 --- a/roar/application/labels.py +++ 
b/roar/application/labels.py @@ -15,6 +15,7 @@ from ..db.context import DatabaseContext from ..execution.recording.dataset_metadata import AUTO_DATASET_LABEL_KEYS +from .label_rendering import flatten_label_metadata RESERVED_LABEL_KEYS = set(AUTO_DATASET_LABEL_KEYS) @@ -58,27 +59,6 @@ def parse_label_pairs(pairs: tuple[str, ...]) -> dict[str, Any]: return metadata -def flatten_label_metadata(metadata: dict[str, Any]) -> list[tuple[str, str]]: - """Flatten metadata into sorted ``(key, display_value)`` pairs.""" - flat: list[tuple[str, str]] = [] - - def _walk(prefix: str, value: Any) -> None: - if isinstance(value, dict): - for key in sorted(value.keys()): - next_prefix = f"{prefix}.{key}" if prefix else key - _walk(next_prefix, value[key]) - return - flat.append((prefix, _display_scalar(value))) - - _walk("", metadata) - return flat - - -def render_label_lines(metadata: dict[str, Any], indent: str = "") -> list[str]: - """Render a metadata document as sorted ``key=value`` lines.""" - return [f"{indent}{key}={value}" for key, value in flatten_label_metadata(metadata)] - - class LabelService: """High-level local label workflow service.""" @@ -353,16 +333,6 @@ def _parse_scalar(raw: str) -> Any: return stripped -def _display_scalar(value: Any) -> str: - if isinstance(value, bool): - return "true" if value else "false" - if value is None: - return "null" - if isinstance(value, (int, float, str)): - return str(value) - return json.dumps(value, sort_keys=True) - - def _deep_merge(current: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]: merged = json.loads(json.dumps(current)) for key, value in patch.items(): diff --git a/roar/application/publish/collection.py b/roar/application/publish/collection.py index fae97850..9d595329 100644 --- a/roar/application/publish/collection.py +++ b/roar/application/publish/collection.py @@ -11,6 +11,7 @@ from ...core.interfaces.logger import ILogger from ...db.context import create_database_context from ...db.hashing.backend 
import compute_hashes_batch +from ...db.query_context import create_query_database_context from .lineage import LineageCollector from .session import PublishSessionService from .targets import ( @@ -39,6 +40,7 @@ def collect_register_lineage( lineage_collector: LineageCollector, session_service: PublishSessionService, logger: ILogger, + dry_run: bool = False, ) -> tuple[CollectedRegisterLineage | None, str | None]: """Collect local lineage for a resolved register target.""" if target.kind == "step_reference": @@ -46,6 +48,7 @@ def collect_register_lineage( step_reference=target.value, roar_dir=roar_dir, lineage_collector=lineage_collector, + dry_run=dry_run, ) if target.kind == "job_uid": return _collect_job_lineage( @@ -82,18 +85,35 @@ def _collect_step_lineage( step_reference: str, roar_dir: Path, lineage_collector: LineageCollector, + dry_run: bool, ) -> tuple[CollectedRegisterLineage | None, str | None]: parsed = parse_register_step_reference(step_reference) if parsed is None: return None, f"Invalid DAG reference: {step_reference}" step_number, is_build = parsed - with create_database_context(roar_dir) as db_ctx: - session = db_ctx.sessions.get_active() - if not session: - return None, "No active session. Run 'roar run' to create a session first." + if dry_run: + with create_query_database_context(roar_dir) as db_ctx: + session = db_ctx.sessions.get_active() + if not session: + return None, "No active session. Run 'roar run' to create a session first." + session_id = int(session["id"]) + + lineage = lineage_collector.collect_step_read_only( + session_id=session_id, + step_number=step_number, + roar_dir=roar_dir, + job_type="build" if is_build else None, + ) + else: + with create_database_context(roar_dir) as db_ctx: + session = db_ctx.sessions.get_active() + if not session: + return None, "No active session. Run 'roar run' to create a session first." 
+ session_id = int(session["id"]) + lineage = lineage_collector.collect_step( - session_id=int(session["id"]), + session_id=session_id, step_number=step_number, roar_dir=roar_dir, job_type="build" if is_build else None, diff --git a/roar/application/publish/job_preparation.py b/roar/application/publish/job_preparation.py index ccabf282..cba8b681 100644 --- a/roar/application/publish/job_preparation.py +++ b/roar/application/publish/job_preparation.py @@ -4,7 +4,6 @@ from typing import Any -from ...application.publish.registration import normalize_registration_hashes from ...execution.framework.registry import ( is_execution_noise_job, is_execution_submit_job, @@ -13,6 +12,15 @@ from ...integrations.glaas.registration import _artifact_ref +def normalize_registration_hashes(*args: Any, **kwargs: Any) -> Any: + """Load heavy registration hashing helpers only when needed.""" + from ...application.publish.registration import ( + normalize_registration_hashes as _normalize_registration_hashes, + ) + + return _normalize_registration_hashes(*args, **kwargs) + + def normalize_jobs_for_registration(jobs: list[dict[str, Any]]) -> list[dict[str, Any]]: """Drop known noise jobs and repair unresolved local parent references.""" normalized = [dict(job) for job in jobs if not is_execution_noise_job(job)] diff --git a/roar/application/publish/lineage.py b/roar/application/publish/lineage.py index d6015430..fa38ed28 100644 --- a/roar/application/publish/lineage.py +++ b/roar/application/publish/lineage.py @@ -4,11 +4,14 @@ This module is the canonical home for publish-time lineage traversal. 
""" +from __future__ import annotations + from pathlib import Path from ...core.digests import extract_primary_digest from ...core.interfaces.lineage import LineageData from ...db.context import create_database_context +from ...db.query_context import create_query_database_context from ...execution.framework.registry import is_execution_task_job @@ -176,6 +179,70 @@ def collect_step( pipeline=session, ) + def collect_step_read_only( + self, + session_id: int, + step_number: int, + roar_dir: Path, + job_type: str | None = None, + ) -> LineageData: + """Collect step lineage through the sqlite query context for preview flows.""" + with create_query_database_context(roar_dir) as ctx_db: + session = ctx_db.sessions.get(session_id) + if not session: + return LineageData() + + step_jobs = self._get_step_jobs(ctx_db, session_id, step_number, job_type=job_type) + if not step_jobs: + return LineageData(pipeline=session) + + hydrated_step_jobs = [self._hydrate_job(ctx_db, job) for job in step_jobs] + target_hashes = sorted( + { + digest + for job in hydrated_step_jobs + for digest in job.get("_output_hashes", []) + if digest + } + ) + if not target_hashes: + target_hashes = sorted( + { + digest + for job in hydrated_step_jobs + for digest in job.get("_input_hashes", []) + if digest + } + ) + + if target_hashes: + lineage_jobs = self._get_lineage_jobs_read_only(ctx_db, target_hashes) + if session: + lineage_jobs = self._add_build_jobs( + ctx_db, session, lineage_jobs, set(target_hashes) + ) + lineage_jobs = self._add_parent_jobs(ctx_db, lineage_jobs) + lineage_jobs = self._add_parent_linked_execution_tasks(ctx_db, lineage_jobs) + else: + lineage_jobs = [] + + seen_ids = {job["id"] for job in lineage_jobs} + for job in hydrated_step_jobs: + if job["id"] not in seen_ids: + lineage_jobs.append(job) + seen_ids.add(job["id"]) + + lineage_jobs.sort(key=lambda job: job["timestamp"]) + all_hashes = self._collect_all_hashes(lineage_jobs) + artifacts = self._get_artifact_info(ctx_db, 
all_hashes) + + return LineageData( + jobs=lineage_jobs, + artifacts=artifacts, + artifact_hashes=all_hashes, + pipeline=session, + ) + def collect_session( self, session_id: int, @@ -311,6 +378,81 @@ def _add_build_jobs( # Combine build jobs with lineage jobs, avoiding duplicates return build_job_list + [j for j in lineage_jobs if j["id"] not in build_job_ids] + def _get_lineage_jobs_read_only( + self, + ctx_db, + artifact_ids: list[str], + max_depth: int = 10, + ) -> list[dict]: + """Reconstruct lineage jobs without loading the SQLAlchemy service stack.""" + resolved_ids = [] + for artifact_id in artifact_ids: + artifact = ctx_db.artifacts.get(artifact_id) + if artifact: + resolved_ids.append(artifact_id) + continue + artifact = ctx_db.artifacts.get_by_hash(artifact_id) + if artifact: + resolved_ids.append(artifact["id"]) + + visited_jobs: set[int] = set() + visited_artifacts: set[str] = set() + jobs: list[dict] = [] + + def trace_upstream(artifact_id: str, current_depth: int) -> None: + if current_depth > max_depth or artifact_id in visited_artifacts: + return + visited_artifacts.add(artifact_id) + + artifact_jobs = ctx_db.artifacts.get_jobs(artifact_id) + produced_by = artifact_jobs.get("produced_by", []) + producer = produced_by[0] if produced_by else None + + if producer and producer["id"] not in visited_jobs: + visited_jobs.add(producer["id"]) + job_dict = dict(producer) + + inputs = ctx_db.jobs.get_inputs(producer["id"]) + job_dict["_input_artifact_ids"] = [inp["artifact_id"] for inp in inputs] + job_dict["_input_hashes"] = [ + h for h in (_extract_primary_digest(inp) for inp in inputs) if h + ] + job_dict["_inputs"] = [ + { + "hash": h, + "path": inp.get("path") or inp.get("first_seen_path", ""), + "byte_ranges": inp.get("byte_ranges"), + } + for inp in inputs + if (h := _extract_primary_digest(inp)) + ] + + for inp in inputs: + trace_upstream(inp["artifact_id"], current_depth + 1) + + outputs = ctx_db.jobs.get_outputs(producer["id"]) + 
job_dict["_output_artifact_ids"] = [out["artifact_id"] for out in outputs] + job_dict["_output_hashes"] = [ + h for h in (_extract_primary_digest(out) for out in outputs) if h + ] + job_dict["_outputs"] = [ + { + "hash": h, + "path": out.get("path") or out.get("first_seen_path", ""), + "byte_ranges": out.get("byte_ranges"), + } + for out in outputs + if (h := _extract_primary_digest(out)) + ] + + jobs.append(job_dict) + + for artifact_id in resolved_ids: + trace_upstream(artifact_id, 0) + + jobs.sort(key=lambda job: job["timestamp"]) + return jobs + def _add_parent_linked_execution_tasks(self, ctx_db, lineage_jobs: list[dict]) -> list[dict]: """Include distributed child jobs reachable via parent_job_uid edges.""" if not lineage_jobs: diff --git a/roar/application/publish/register_execution.py b/roar/application/publish/register_execution.py index 62bf1315..f18ed9dc 100644 --- a/roar/application/publish/register_execution.py +++ b/roar/application/publish/register_execution.py @@ -4,37 +4,98 @@ Owns the registration mechanics after local lineage has already been collected. 
""" +from __future__ import annotations + from collections.abc import Callable from dataclasses import dataclass, field from functools import cached_property from pathlib import Path -from typing import Any - -from ...application.publish.composite_builder import CompositeArtifactBuilder -from ...application.publish.register_preparation import PreparedRegisterExecution -from ...application.publish.registration import ( - normalize_registration_hashes, - prepare_batch_registration_artifacts, - register_publish_lineage, -) -from ...core.interfaces.lineage import LineageData -from ...core.interfaces.logger import ILogger +from typing import TYPE_CHECKING, Any + from ...core.logging import get_logger -from ...db.context import create_database_context from ...filters.omit import OmitFilter from ...integrations.config import config_get -from ...integrations.glaas import GlaasClient -from ...integrations.glaas.registration import RegistrationCoordinator -from .blake3_upgrade import upgrade_s3_etags_to_blake3 from .job_preparation import ( estimate_links, normalize_jobs_for_registration, order_jobs_for_registration, - refresh_job_artifact_references, ) -from .lineage_composites import has_lineage_composites, preregister_lineage_composites_with_glaas from .secrets import detect_lineage_secrets, filter_lineage_secrets +if TYPE_CHECKING: + from ...application.publish.composite_builder import CompositeArtifactBuilder + from ...application.publish.register_preparation import PreparedRegisterExecution + from ...core.interfaces.lineage import LineageData + from ...core.interfaces.logger import ILogger + from ...integrations.glaas import GlaasClient + from ...integrations.glaas.registration import RegistrationCoordinator + + +def create_database_context(roar_dir: Path) -> Any: + """Load SQLAlchemy DB context only for non-dry-run registration paths.""" + from ...db.context import create_database_context as _create_database_context + + return _create_database_context(roar_dir) + + 
+def upgrade_s3_etags_to_blake3(*args: Any, **kwargs: Any) -> None: + """Load S3 hash upgrade support only when requested.""" + from .blake3_upgrade import upgrade_s3_etags_to_blake3 as _upgrade_s3_etags_to_blake3 + + _upgrade_s3_etags_to_blake3(*args, **kwargs) + + +def has_lineage_composites(*args: Any, **kwargs: Any) -> bool: + """Load composite detection only for non-dry-run registration paths.""" + from .lineage_composites import has_lineage_composites as _has_lineage_composites + + return _has_lineage_composites(*args, **kwargs) + + +def preregister_lineage_composites_with_glaas(*args: Any, **kwargs: Any) -> Any: + """Load composite preregistration only when needed.""" + from .lineage_composites import ( + preregister_lineage_composites_with_glaas as _preregister_lineage_composites_with_glaas, + ) + + return _preregister_lineage_composites_with_glaas(*args, **kwargs) + + +def refresh_job_artifact_references(*args: Any, **kwargs: Any) -> None: + """Load heavy job/artifact reconciliation only for real registration.""" + from .job_preparation import ( + refresh_job_artifact_references as _refresh_job_artifact_references, + ) + + _refresh_job_artifact_references(*args, **kwargs) + + +def prepare_batch_registration_artifacts(*args: Any, **kwargs: Any) -> Any: + """Load batch registration helpers only for real registration.""" + from ...application.publish.registration import ( + prepare_batch_registration_artifacts as _prepare_batch_registration_artifacts, + ) + + return _prepare_batch_registration_artifacts(*args, **kwargs) + + +def register_publish_lineage(*args: Any, **kwargs: Any) -> Any: + """Load batch registration only for real registration.""" + from ...application.publish.registration import ( + register_publish_lineage as _register_publish_lineage, + ) + + return _register_publish_lineage(*args, **kwargs) + + +def normalize_registration_hashes(*args: Any, **kwargs: Any) -> Any: + """Load hash normalization only when extracting registration payloads.""" + 
from ...application.publish.registration import ( + normalize_registration_hashes as _normalize_registration_hashes, + ) + + return _normalize_registration_hashes(*args, **kwargs) + @dataclass class RegisterResult: @@ -89,7 +150,7 @@ def __init__( """ self._glaas_client = glaas_client self._coordinator = coordinator - self._composite_builder = composite_builder or CompositeArtifactBuilder() + self._composite_builder = composite_builder self._omit_filter = omit_filter self._logger = logger or get_logger() @@ -106,12 +167,29 @@ def omit_filter(self) -> OmitFilter | None: @cached_property def glaas_client(self) -> GlaasClient: """Get or create GLaaS client.""" - return self._glaas_client or GlaasClient() + if self._glaas_client is None: + from ...integrations.glaas import GlaasClient + + self._glaas_client = GlaasClient() + return self._glaas_client @cached_property def coordinator(self) -> RegistrationCoordinator: """Get or create registration coordinator.""" - return self._coordinator or RegistrationCoordinator() + if self._coordinator is None: + from ...integrations.glaas.registration import RegistrationCoordinator + + self._coordinator = RegistrationCoordinator() + return self._coordinator + + @property + def composite_builder(self) -> CompositeArtifactBuilder: + """Get or create composite builder only for real registration paths.""" + if self._composite_builder is None: + from ...application.publish.composite_builder import CompositeArtifactBuilder + + self._composite_builder = CompositeArtifactBuilder() + return self._composite_builder def register_prepared_lineage( self, @@ -204,7 +282,7 @@ def register_prepared_lineage( lineage_artifacts=lineage.artifacts, session_hash=session_hash, registration_errors=pre_registration_errors, - composite_builder=self._composite_builder, + composite_builder=self.composite_builder, logger=self._logger, ) except Exception as e: diff --git a/roar/application/publish/register_preparation.py 
b/roar/application/publish/register_preparation.py index c6e61cc4..990d3776 100644 --- a/roar/application/publish/register_preparation.py +++ b/roar/application/publish/register_preparation.py @@ -7,7 +7,6 @@ from ...core.interfaces.logger import ILogger from ...core.interfaces.registration import GitContext -from ...integrations.config import config_get from ..git import build_roar_git_tag_name, ensure_clean_git_repo, resolve_roar_git_context from .runtime import PublishRuntime from .session import prepare_publish_session @@ -41,9 +40,13 @@ def prepare_register_execution( git_tag_name: str | None = None git_tag_repo_root: Path | None = None - tagging_enabled = config_get("registration.tagging.enabled") - if tagging_enabled is None: - tagging_enabled = True + tagging_enabled = True + if not dry_run: + from ...integrations.config import config_get + + tagging_enabled = config_get("registration.tagging.enabled") + if tagging_enabled is None: + tagging_enabled = True if not dry_run and tagging_enabled and git_context.commit: git_state = ensure_clean_git_repo( diff --git a/roar/application/publish/register_preview_jobs.py b/roar/application/publish/register_preview_jobs.py new file mode 100644 index 00000000..1e1df38a --- /dev/null +++ b/roar/application/publish/register_preview_jobs.py @@ -0,0 +1,112 @@ +"""Lightweight job normalization helpers for local register preview flows.""" + +from __future__ import annotations + +from typing import Any + +from ...db.step_priority import is_host_or_submit_job, is_noise_job, is_task_job + + +def normalize_jobs_for_registration(jobs: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Drop known noise jobs and repair unresolved local parent references.""" + normalized = [dict(job) for job in jobs if not is_noise_job(job)] + known_job_uids = { + str(job["job_uid"]) for job in normalized if isinstance(job.get("job_uid"), str) + } + root_candidates = [job for job in normalized if _is_local_parent_candidate(job)] + if not 
root_candidates: + root_candidates = [job for job in normalized if not is_task_job(job)] + + for job in normalized: + parent_uid = str(job.get("parent_job_uid") or "").strip() + if not parent_uid or parent_uid in known_job_uids: + continue + + inferred_parent_uid = _infer_local_parent_uid(job, root_candidates) + if inferred_parent_uid: + job["parent_job_uid"] = inferred_parent_uid + else: + job["parent_job_uid"] = None + + return normalized + + +def order_jobs_for_registration(jobs: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Order jobs so parents are registered before their children.""" + jobs_by_uid = {str(job["job_uid"]): job for job in jobs if isinstance(job.get("job_uid"), str)} + ordered: list[dict[str, Any]] = [] + seen: set[str] = set() + + def visit(job: dict[str, Any]) -> None: + parent_uid = job.get("parent_job_uid") + if isinstance(parent_uid, str) and parent_uid: + parent = jobs_by_uid.get(parent_uid) + if parent is not None: + visit(parent) + + visit_key = str(job.get("job_uid") or f"id:{job.get('id')}") + if visit_key in seen: + return + seen.add(visit_key) + ordered.append(job) + + for job in sorted( + jobs, + key=lambda item: ( + int(item.get("step_number") or 0), + float(item.get("timestamp") or 0.0), + int(item.get("id") or 0), + ), + ): + visit(job) + + return ordered + + +def estimate_links(jobs: list[dict[str, Any]]) -> int: + """Estimate number of artifact links represented by the lineage jobs.""" + links = 0 + for job in jobs: + links += len(job.get("_inputs", [])) + links += len(job.get("_outputs", [])) + return links + + +def _infer_local_parent_uid( + job: dict[str, Any], + candidates: list[dict[str, Any]], +) -> str | None: + job_step = int(job.get("step_number") or 0) + job_timestamp = float(job.get("timestamp") or 0.0) + + eligible = [ + candidate + for candidate in candidates + if ( + int(candidate.get("step_number") or 0) < job_step + or ( + int(candidate.get("step_number") or 0) == job_step + and 
float(candidate.get("timestamp") or 0.0) <= job_timestamp + ) + ) + ] + if not eligible: + return None + + preferred = max(eligible, key=_parent_candidate_sort_key) + inferred_uid = preferred.get("job_uid") + return str(inferred_uid) if inferred_uid else None + + +def _is_local_parent_candidate(job: dict[str, Any]) -> bool: + job_type = str(job.get("job_type", "") or "") + return not is_task_job(job) and not is_noise_job(job) and job_type != "build" + + +def _parent_candidate_sort_key(job: dict[str, Any]) -> tuple[int, int, float, int]: + return ( + 1 if is_host_or_submit_job(job) else 0, + int(job.get("step_number") or 0), + float(job.get("timestamp") or 0.0), + int(job.get("id") or 0), + ) diff --git a/roar/application/publish/service.py b/roar/application/publish/service.py index 5f1fd307..300b6ab8 100644 --- a/roar/application/publish/service.py +++ b/roar/application/publish/service.py @@ -2,42 +2,341 @@ from __future__ import annotations -from ...core.bootstrap import bootstrap +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + from ...core.logging import get_logger -from ...db.context import create_database_context -from ...integrations.glaas import get_glaas_url -from ...integrations.storage import resolve_publish_storage_backend -from ..git import finalize_put_git, finalize_register_git, prepare_put_git -from .collection import collect_register_lineage -from .put_execution import PutService -from .put_preparation import prepare_put_execution -from .register_execution import RegisterService -from .register_preparation import prepare_register_execution from .requests import ( PutRequest, RegisterLineageRequest, ) from .results import ( + PutCompositeRegistration, + PutDryRunItem, PutResponse, + PutUploadedFile, RegisterLineageResponse, ) -from .runtime import build_publish_runtime -from .targets import resolve_register_lineage_target + +if TYPE_CHECKING: + from ...db.query_context 
import QueryDatabaseContext + + +def bootstrap(roar_dir: Path) -> None: + """Load bootstrap dependencies only when the put workflow runs.""" + from ...core.bootstrap import bootstrap as _bootstrap + + _bootstrap(roar_dir) + + +def create_database_context(roar_dir: Path) -> Any: + """Load database context factory lazily for publish workflows.""" + from ...db.context import create_database_context as _create_database_context + + return _create_database_context(roar_dir) + + +def create_query_database_context(roar_dir: Path) -> Any: + """Load the lightweight query DB context only for read-only publish flows.""" + from ...db.query_context import create_query_database_context as _create_query_database_context + + return _create_query_database_context(roar_dir) + + +def get_glaas_url() -> str | None: + """Load GLaaS config lookup lazily.""" + from ...integrations.glaas import get_glaas_url as _get_glaas_url + + return _get_glaas_url() + + +def resolve_publish_storage_backend(destination: str) -> Any: + """Resolve storage backends only for non-dry-run put operations.""" + from ...integrations.storage import ( + resolve_publish_storage_backend as _resolve_publish_storage_backend, + ) + + return _resolve_publish_storage_backend(destination) + + +def prepare_put_git(*args: Any, **kwargs: Any) -> Any: + """Load git helpers only for put workflows.""" + from ..git import prepare_put_git as _prepare_put_git + + return _prepare_put_git(*args, **kwargs) + + +def finalize_put_git(*args: Any, **kwargs: Any) -> Any: + """Load git helpers only for put workflows.""" + from ..git import finalize_put_git as _finalize_put_git + + return _finalize_put_git(*args, **kwargs) + + +def finalize_register_git(*args: Any, **kwargs: Any) -> Any: + """Load git helpers only for register workflows.""" + from ..git import finalize_register_git as _finalize_register_git + + return _finalize_register_git(*args, **kwargs) + + +def collect_register_lineage(*args: Any, **kwargs: Any) -> Any: + """Load 
lineage collection only when register runs.""" + from .collection import collect_register_lineage as _collect_register_lineage + + return _collect_register_lineage(*args, **kwargs) + + +def prepare_put_execution(*args: Any, **kwargs: Any) -> Any: + """Load put preparation only for non-dry-run put operations.""" + from .put_preparation import prepare_put_execution as _prepare_put_execution + + return _prepare_put_execution(*args, **kwargs) + + +def prepare_register_execution(*args: Any, **kwargs: Any) -> Any: + """Load register preparation only when register runs.""" + from .register_preparation import ( + prepare_register_execution as _prepare_register_execution, + ) + + return _prepare_register_execution(*args, **kwargs) + + +def PutService(*args: Any, **kwargs: Any) -> Any: + """Construct the put service only when the put execution path is used.""" + from .put_execution import PutService as _PutService + + return _PutService(*args, **kwargs) + + +def RegisterService(*args: Any, **kwargs: Any) -> Any: + """Construct the register service only when the register path is used.""" + from .register_execution import RegisterService as _RegisterService + + return _RegisterService(*args, **kwargs) + + +def build_publish_runtime(*, glaas_url: str | None = None) -> Any: + """Load publish runtime assembly only when publish workflows execute.""" + from .runtime import build_publish_runtime as _build_publish_runtime + + return _build_publish_runtime(glaas_url=glaas_url) + + +def resolve_register_lineage_target(*args: Any, **kwargs: Any) -> Any: + """Load register target resolution only when register runs.""" + from .targets import ( + resolve_register_lineage_target as _resolve_register_lineage_target, + ) + + return _resolve_register_lineage_target(*args, **kwargs) + + +@dataclass(frozen=True) +class _PutPlanResult: + """Lightweight local plan result for `roar put --dry-run`.""" + + success: bool + dry_run: bool + would_upload: list[PutDryRunItem] = field(default_factory=list) 
+ uploaded_files: list[PutUploadedFile] = field(default_factory=list) + composites_registered: list[PutCompositeRegistration] = field(default_factory=list) + job_id: int | None = None + job_uid: str | None = None + session_hash: str | None = None + session_url: str | None = None + error: str | None = None + + +@dataclass(frozen=True) +class _RegisterPreviewRuntime: + """Minimal runtime surface for local `roar register --dry-run` flows.""" + + glaas_client: Any + session_service: Any + lineage_collector: Any + + +@dataclass(frozen=True) +class _PreparedRegisterPreviewExecution: + """Local preview-only register preparation result.""" + + git_context: Any + session_id: int | None + session_hash: str + session_url: str | None + git_tag_name: str | None = None + git_tag_repo_root: Path | None = None + + +def build_register_preview_runtime() -> Any: + """Build only the dependencies needed for local register preview flows.""" + from ...integrations.glaas.client import GlaasClient + from ...integrations.glaas.registration.session import SessionRegistrationService + from .lineage import LineageCollector + + glaas_client = GlaasClient("") + return _RegisterPreviewRuntime( + glaas_client=glaas_client, + session_service=SessionRegistrationService(glaas_client), + lineage_collector=LineageCollector(), + ) + + +def prepare_register_preview_execution( + *, + runtime: Any, + roar_dir: Path, + cwd: Path, + session_id: int | None, + session_hash_override: str | None, + logger: Any, +) -> Any: + """Prepare local register preview state without importing full git workflow helpers.""" + git_context = _resolve_register_preview_git_context(path=cwd, logger=logger) + if session_hash_override: + session_hash = session_hash_override + else: + if session_id is None: + raise ValueError("Cannot compute a session hash without a local session id.") + session_hash = runtime.session_service.compute_session_hash( + roar_dir=str(roar_dir), + session_id=session_id, + ) + + logger.debug("Session hash: 
%s", session_hash[:12]) + return _PreparedRegisterPreviewExecution( + git_context=git_context, + session_id=session_id, + session_hash=session_hash, + session_url=None, + ) + + +def preview_register_lineage( + *, + lineage: Any, + artifact_hash: str, + prepared: Any, + cwd: Path, + skip_confirmation: bool, + confirm_callback: Any, +) -> RegisterLineageResponse: + """Build a local register preview result without importing real registration machinery.""" + from ...filters.omit import OmitFilter + from ...integrations.config.raw import get_raw_registration_omit_config + from .register_preview_jobs import ( + estimate_links, + normalize_jobs_for_registration, + order_jobs_for_registration, + ) + from .secrets import detect_lineage_secrets, filter_lineage_secrets + + omit_filter = None + omit_config = get_raw_registration_omit_config(start_dir=str(cwd)) + if omit_config.get("enabled", True): + omit_filter = OmitFilter(omit_config) + + detected_secrets: list[str] = [] + if omit_filter is not None: + detected_secrets = detect_lineage_secrets( + lineage=lineage, + git_context=prepared.git_context, + omit_filter=omit_filter, + ) + if detected_secrets and not skip_confirmation: + if confirm_callback is None: + return RegisterLineageResponse( + success=False, + session_hash=prepared.session_hash, + artifact_hash=artifact_hash, + error="Secrets detected in data. 
Use --yes to proceed with redacted data.", + secrets_detected=detected_secrets, + aborted_by_user=True, + ) + + if not confirm_callback(detected_secrets): + return RegisterLineageResponse( + success=False, + session_hash=prepared.session_hash, + artifact_hash=artifact_hash, + error="Registration aborted by user.", + secrets_detected=detected_secrets, + aborted_by_user=True, + ) + + if detected_secrets or (omit_filter is not None and omit_filter.enabled): + lineage = filter_lineage_secrets( + lineage=lineage, + omit_filter=omit_filter, + ) + + registration_jobs = order_jobs_for_registration(normalize_jobs_for_registration(lineage.jobs)) + return RegisterLineageResponse( + success=True, + session_hash=prepared.session_hash, + artifact_hash=artifact_hash, + jobs_registered=len(registration_jobs), + artifacts_registered=len(lineage.artifacts), + links_created=estimate_links(registration_jobs), + secrets_detected=detected_secrets, + secrets_redacted=bool(detected_secrets), + ) + + +def _resolve_register_preview_git_context(*, path: Path, logger: Any) -> Any: + """Resolve git context for local preview flows without importing git provider models.""" + from ...core.interfaces.registration import GitContext + + repo_root = _run_git(path, "rev-parse", "--show-toplevel") + if not repo_root: + logger.debug("No git repository found for register preview at %s", path) + return GitContext(repo=None, commit=None, branch=None) + + repo_root_path = Path(repo_root) + commit = _run_git(repo_root_path, "rev-parse", "HEAD") + branch = _run_git(repo_root_path, "rev-parse", "--abbrev-ref", "HEAD") + remote = _run_git(repo_root_path, "remote", "get-url", "origin") + repo = remote or repo_root_path.resolve().as_uri() + if remote is None: + logger.debug( + "No git remote configured for %s; using local repository URI %s", + repo_root_path, + repo, + ) + + return GitContext(repo=repo, commit=commit, branch=branch) + + +def _run_git(path: Path, *args: str) -> str | None: + """Run a git command 
and return stripped stdout on success.""" + try: + return subprocess.check_output( + ["git", *args], + cwd=path, + stderr=subprocess.DEVNULL, + text=True, + ).strip() + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + return None def register_lineage_target(request: RegisterLineageRequest) -> RegisterLineageResponse: """Run the `roar register` application workflow.""" logger = get_logger() - runtime = build_publish_runtime(glaas_url=get_glaas_url()) - service = RegisterService( - glaas_client=runtime.glaas_client, - coordinator=runtime.registration_coordinator, - ) resolved_target = resolve_register_lineage_target( request.target, cwd=request.cwd, roar_dir=request.roar_dir, ) + runtime = ( + build_register_preview_runtime() + if request.dry_run + else build_publish_runtime(glaas_url=get_glaas_url()) + ) collected_lineage, error = collect_register_lineage( target=resolved_target, roar_dir=request.roar_dir, @@ -45,20 +344,31 @@ def register_lineage_target(request: RegisterLineageRequest) -> RegisterLineageR lineage_collector=runtime.lineage_collector, session_service=runtime.session_service, logger=logger, + dry_run=request.dry_run, ) if collected_lineage is None: return RegisterLineageResponse(success=False, error=error) try: - prepared = prepare_register_execution( - runtime=runtime, - roar_dir=request.roar_dir, - cwd=request.cwd, - session_id=collected_lineage.session_id, - dry_run=request.dry_run, - session_hash_override=collected_lineage.session_hash_override, - logger=logger, - ) + if request.dry_run: + prepared = prepare_register_preview_execution( + runtime=runtime, + roar_dir=request.roar_dir, + cwd=request.cwd, + session_id=collected_lineage.session_id, + session_hash_override=collected_lineage.session_hash_override, + logger=logger, + ) + else: + prepared = prepare_register_execution( + runtime=runtime, + roar_dir=request.roar_dir, + cwd=request.cwd, + session_id=collected_lineage.session_id, + dry_run=False, + 
session_hash_override=collected_lineage.session_hash_override, + logger=logger, + ) except ValueError as exc: return RegisterLineageResponse( success=False, @@ -66,6 +376,20 @@ def register_lineage_target(request: RegisterLineageRequest) -> RegisterLineageR error=str(exc), ) + if request.dry_run: + return preview_register_lineage( + lineage=collected_lineage.lineage, + artifact_hash=collected_lineage.artifact_hash, + prepared=prepared, + cwd=request.cwd, + skip_confirmation=request.skip_confirmation, + confirm_callback=request.confirm_callback, + ) + + service = RegisterService( + glaas_client=runtime.glaas_client, + coordinator=runtime.registration_coordinator, + ) result = service.register_prepared_lineage( lineage=collected_lineage.lineage, roar_dir=request.roar_dir, @@ -101,59 +425,78 @@ def register_lineage_target(request: RegisterLineageRequest) -> RegisterLineageR def put_artifacts(request: PutRequest) -> PutResponse: """Run the `roar put` application workflow.""" - bootstrap(request.roar_dir) + if not request.dry_run: + bootstrap(request.roar_dir) logger = get_logger() - backend = resolve_publish_storage_backend(request.destination) - repo_root = request.repo_root or request.cwd - git_state = prepare_put_git( - repo_root=repo_root, - dry_run=request.dry_run, - no_tag=request.no_tag, - logger=logger, - ) - warnings = list(git_state.warnings) - - with create_database_context(request.roar_dir) as db_ctx: - runtime = build_publish_runtime(glaas_url=get_glaas_url()) - service = PutService( - db_context=db_ctx, - backend=backend, - destination=request.destination, - repo_root=repo_root, - roar_dir=request.roar_dir, - lineage_collector=runtime.lineage_collector, - registration_coordinator=runtime.registration_coordinator, - ) - - prepared = prepare_put_execution( - db_ctx=db_ctx, - runtime=runtime, - roar_dir=request.roar_dir, + if request.dry_run: + git_state = None + git_commit = None + expected_tag = None + warnings: list[str] = [] + else: + git_state = 
prepare_put_git( repo_root=repo_root, - sources=request.sources, - destination=request.destination, - git_commit=git_state.git_commit, + dry_run=False, + no_tag=request.no_tag, logger=logger, ) + git_commit = git_state.git_commit + expected_tag = git_state.expected_tag + warnings = list(git_state.warnings) - result = service.put_prepared( - prepared=prepared, - sources=request.sources, - message=request.message, - dry_run=request.dry_run, - git_commit=git_state.git_commit, - git_tag=git_state.expected_tag, - ) + if request.dry_run: + with create_query_database_context(request.roar_dir) as db_ctx: + result = _plan_put_dry_run( + db_ctx=db_ctx, + repo_root=repo_root, + sources=request.sources, + ) + else: + with create_database_context(request.roar_dir) as db_ctx: + backend = resolve_publish_storage_backend(request.destination) + runtime = build_publish_runtime(glaas_url=get_glaas_url()) + service = PutService( + db_context=db_ctx, + backend=backend, + destination=request.destination, + repo_root=repo_root, + roar_dir=request.roar_dir, + lineage_collector=runtime.lineage_collector, + registration_coordinator=runtime.registration_coordinator, + ) + + prepared = prepare_put_execution( + db_ctx=db_ctx, + runtime=runtime, + roar_dir=request.roar_dir, + repo_root=repo_root, + sources=request.sources, + destination=request.destination, + git_commit=git_commit, + logger=logger, + ) + result = service.put_prepared( + prepared=prepared, + sources=request.sources, + message=request.message, + dry_run=request.dry_run, + git_commit=git_commit, + git_tag=expected_tag, + ) + + if request.dry_run: + created_git_tag = None + else: created_git_tag, git_tag_warnings = finalize_put_git( result_success=result.success, result_dry_run=result.dry_run, no_tag=request.no_tag, - git_commit=git_state.git_commit, - expected_tag=git_state.expected_tag, - git_state=git_state.git_state, + git_commit=git_commit, + expected_tag=expected_tag, + git_state=git_state.git_state if git_state is not None 
else None, repo_root=repo_root, logger=logger, ) @@ -174,3 +517,32 @@ def put_artifacts(request: PutRequest) -> PutResponse: warnings=warnings, error=result.error, ) + + +def _plan_put_dry_run( + *, + db_ctx: QueryDatabaseContext, + repo_root: Path, + sources: list[str], +) -> _PutPlanResult: + """Resolve the local source plan for `roar put --dry-run`.""" + from .source_resolution import SourceResolver + + active_session = db_ctx.sessions.get_active() + if active_session is None: + raise ValueError("No active session") + + resolver = SourceResolver( + repo_root=repo_root, + session_repo=db_ctx.sessions, + job_repo=db_ctx.jobs, + ) + resolved_sources = resolver.resolve(sources) + would_upload = [ + PutDryRunItem(path=str(source.path), exists=source.exists) for source in resolved_sources + ] + return _PutPlanResult( + success=True, + dry_run=True, + would_upload=would_upload, + ) diff --git a/roar/application/query/__init__.py b/roar/application/query/__init__.py index a4c5a2f2..50f847ab 100644 --- a/roar/application/query/__init__.py +++ b/roar/application/query/__init__.py @@ -1,67 +1,54 @@ -"""Application entrypoints for local query and label workflows.""" +"""Lazy exports for local query and label workflows.""" -from .dag import render_dag -from .label import ( - build_copy_labels_summary, - build_label_history_summary, - build_set_labels_summary, - build_show_labels_summary, - copy_labels, - label_history, - set_labels, - show_labels, -) -from .lineage import render_lineage -from .log import render_log -from .requests import ( - DagQueryRequest, - LabelCopyRequest, - LabelHistoryRequest, - LabelSetRequest, - LabelShowRequest, - LineageQueryRequest, - LogQueryRequest, - ShowQueryRequest, - StatusQueryRequest, -) -from .results import ( - LabelCurrentSummary, - LabelHistorySummary, - LineageSummary, - LogSummary, - ShowSummary, - StatusSummary, -) -from .show import render_show -from .status import render_status +from __future__ import annotations -__all__ = [ - 
"DagQueryRequest", - "LabelCopyRequest", - "LabelCurrentSummary", - "LabelHistoryRequest", - "LabelHistorySummary", - "LabelSetRequest", - "LabelShowRequest", - "LineageQueryRequest", - "LineageSummary", - "LogQueryRequest", - "LogSummary", - "ShowQueryRequest", - "ShowSummary", - "StatusQueryRequest", - "StatusSummary", - "build_copy_labels_summary", - "build_label_history_summary", - "build_set_labels_summary", - "build_show_labels_summary", - "copy_labels", - "label_history", - "render_dag", - "render_lineage", - "render_log", - "render_show", - "render_status", - "set_labels", - "show_labels", -] +from importlib import import_module +from typing import Any + +_EXPORTS = { + "DagQueryRequest": ".requests", + "LabelCopyRequest": ".requests", + "LabelHistoryRequest": ".requests", + "LabelSetRequest": ".requests", + "LabelShowRequest": ".requests", + "LineageQueryRequest": ".requests", + "LogQueryRequest": ".requests", + "ShowQueryRequest": ".requests", + "StatusQueryRequest": ".requests", + "LabelCurrentSummary": ".results", + "LabelHistorySummary": ".results", + "LineageSummary": ".results", + "LogSummary": ".results", + "ShowSummary": ".results", + "StatusSummary": ".results", + "build_copy_labels_summary": ".label", + "build_label_history_summary": ".label", + "build_set_labels_summary": ".label", + "build_show_labels_summary": ".label", + "copy_labels": ".label", + "label_history": ".label", + "render_dag": ".dag", + "render_lineage": ".lineage", + "render_log": ".log", + "render_show": ".show", + "render_status": ".status", + "set_labels": ".label", + "show_labels": ".label", +} + +__all__ = sorted(_EXPORTS) + + +def __getattr__(name: str) -> Any: + module_name = _EXPORTS.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module = import_module(module_name, __name__) + value = getattr(module, name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return 
sorted(set(globals()) | set(__all__)) diff --git a/roar/application/query/label.py b/roar/application/query/label.py index a7caf980..78fa13d4 100644 --- a/roar/application/query/label.py +++ b/roar/application/query/label.py @@ -3,7 +3,8 @@ from __future__ import annotations from ...db.context import create_database_context -from ..labels import LabelService, flatten_label_metadata, parse_label_pairs +from ..label_rendering import flatten_label_metadata +from ..labels import LabelService, parse_label_pairs from .requests import ( LabelCopyRequest, LabelHistoryRequest, diff --git a/roar/application/query/log.py b/roar/application/query/log.py index 4cf7cd6e..d31f1a57 100644 --- a/roar/application/query/log.py +++ b/roar/application/query/log.py @@ -2,11 +2,13 @@ from __future__ import annotations -from ...db.context import create_database_context +from ...db.query_context import create_query_database_context from ...presenters.formatting import format_duration, format_timestamp from .requests import LogQueryRequest from .results import LogJobSummary, LogSummary +_NO_ACTIVE_SESSION_MESSAGE = "No active session. Run 'roar run' to create a session first." + class LogQueryError(RuntimeError): """Raised when a log query cannot build a summary.""" @@ -14,10 +16,7 @@ class LogQueryError(RuntimeError): def render_log(request: LogQueryRequest) -> str: """Render recent job execution history.""" - try: - summary = build_log_summary(request) - except LogQueryError as exc: - return str(exc) + summary = build_log_summary(request) if not summary.jobs: return "No log entries found." 
@@ -56,10 +55,10 @@ def render_log(request: LogQueryRequest) -> str: def build_log_summary(request: LogQueryRequest) -> LogSummary: """Build a typed summary of recent job execution history.""" - with create_database_context(request.roar_dir) as db_ctx: + with create_query_database_context(request.roar_dir) as db_ctx: session = db_ctx.sessions.get_active() if not session: - raise LogQueryError("No active session.") + raise LogQueryError(_NO_ACTIVE_SESSION_MESSAGE) jobs = db_ctx.jobs.get_by_session(session["id"], limit=20) diff --git a/roar/application/query/requests.py b/roar/application/query/requests.py index c6683a9e..2789c71d 100644 --- a/roar/application/query/requests.py +++ b/roar/application/query/requests.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from pathlib import Path +from typing import Literal @dataclass(frozen=True) @@ -16,11 +17,15 @@ class DagQueryRequest: stale_only: bool +ShowQuerySelector = Literal["auto", "session", "path", "job", "artifact"] + + @dataclass(frozen=True) class ShowQueryRequest: roar_dir: Path cwd: Path ref: str | None + selector: ShowQuerySelector = "auto" @dataclass(frozen=True) diff --git a/roar/application/query/show.py b/roar/application/query/show.py index 85236067..d2d8a58b 100644 --- a/roar/application/query/show.py +++ b/roar/application/query/show.py @@ -7,9 +7,9 @@ from pathlib import Path from typing import Any, cast -from ...core.bootstrap import bootstrap from ...core.logging import get_logger -from ...db.context import create_database_context, optional_repo +from ...db.context import optional_repo +from ...db.query_context import create_query_database_context from ...presenters.show_renderer import ShowRenderer from .requests import ShowQueryRequest from .results import ( @@ -25,6 +25,8 @@ ShowSummary, ) +_NO_ACTIVE_SESSION_MESSAGE = "No active session. Run 'roar run' to create a session first." 
+ class ShowQueryError(RuntimeError): """Raised when a show query cannot build a summary.""" @@ -32,10 +34,7 @@ class ShowQueryError(RuntimeError): def render_show(request: ShowQueryRequest) -> str: """Render session, job, or artifact details.""" - try: - summary = build_show_summary(request) - except ShowQueryError as exc: - return str(exc) + summary = build_show_summary(request) renderer = ShowRenderer() if isinstance(summary, ShowSessionSummary): @@ -60,57 +59,56 @@ def render_show(request: ShowQueryRequest) -> str: def build_show_summary(request: ShowQueryRequest) -> ShowSummary: """Build a typed show summary for session, job, or artifact details.""" - bootstrap(request.roar_dir) logger = _logger() if logger: - logger.debug("show: entry with ref=%r", request.ref) + logger.debug("show: entry with ref=%r selector=%r", request.ref, request.selector) + + with create_query_database_context(request.roar_dir) as db_ctx: + if request.selector == "session": + return _build_active_session_summary(db_ctx) + + if request.selector == "job": + if request.ref is None: + raise ShowQueryError("Job reference is required.") + return _build_job_summary_for_ref(db_ctx, request.ref) + + if request.selector == "path": + if request.ref is None: + raise ShowQueryError("Artifact path is required.") + return _build_artifact_summary_for_path(db_ctx, request.cwd, request.ref) + + if request.selector == "artifact": + if request.ref is None: + raise ShowQueryError("Artifact hash is required.") + return _build_artifact_summary_for_hash(db_ctx, request.ref) - with create_database_context(request.roar_dir) as db_ctx: if request.ref is None: - session = db_ctx.sessions.get_active() - if not session: - raise ShowQueryError("No active session.") - return _build_session_summary(db_ctx, session) + return _build_active_session_summary(db_ctx) ref_type = _classify_ref(request.ref, request.cwd) if logger: logger.debug("show: ref_type=%r for ref=%r", ref_type, request.ref) if ref_type == "job_step": - 
session = db_ctx.sessions.get_active() - if not session: - raise ShowQueryError("No active session.") - job = _resolve_job_ref(db_ctx, int(session["id"]), request.ref) - if not job: - raise ShowQueryError(f"Job not found: {request.ref}") - return _build_job_summary(db_ctx, job) + return _build_job_summary_for_ref(db_ctx, request.ref) if ref_type == "file_path": - path_obj = Path(os.path.expanduser(request.ref)) - if not path_obj.is_absolute(): - path_obj = request.cwd / path_obj - resolved_path = os.path.normpath(str(path_obj.absolute())) - artifact = db_ctx.artifacts.get_by_path(resolved_path) - if not artifact: - raise ShowQueryError(f"No artifact found for path: {request.ref}") - return _build_artifact_summary(db_ctx, artifact) + return _build_artifact_summary_for_path(db_ctx, request.cwd, request.ref) if ref_type == "job_uid": - job = db_ctx.jobs.get_by_uid(request.ref) - if not job: - raise ShowQueryError(f"Job not found: {request.ref}") - return _build_job_summary(db_ctx, job) + return _build_job_summary_for_ref(db_ctx, request.ref) if ref_type == "artifact_hash": job = db_ctx.jobs.get_by_uid(request.ref) if job: return _build_job_summary(db_ctx, job) - artifact = db_ctx.artifacts.get_by_hash(request.ref) - if artifact: - return _build_artifact_summary(db_ctx, artifact) - raise ShowQueryError(f"Not found: {request.ref}") + return _build_artifact_summary_for_hash(db_ctx, request.ref, missing_prefix="Not found") + + artifact = _lookup_artifact_by_path(db_ctx, request.cwd, request.ref) + if artifact: + return _build_artifact_summary(db_ctx, artifact) - raise ShowQueryError(f"Unknown reference format: {request.ref}") + raise ShowQueryError(f"No artifact found for path: {request.ref}") def _logger(): @@ -139,14 +137,60 @@ def _classify_ref(ref: str, cwd: Path) -> str: return "job_step" if "/" in ref or ref.startswith(("./", "../", "~")): return "file_path" - if (cwd / ref).exists(): - return "file_path" - is_hex = all(char in "0123456789abcdefABCDEF" for char in 
ref) + is_hex = bool(ref) and all(char in "0123456789abcdefABCDEF" for char in ref) if is_hex and len(ref) <= 8: return "job_uid" if is_hex and len(ref) > 8: return "artifact_hash" - return "unknown" + return "path_candidate" + + +def _lookup_artifact_by_path(db_ctx, cwd: Path, ref: str) -> dict[str, Any] | None: + path_obj = Path(os.path.expanduser(ref)) + if not path_obj.is_absolute(): + path_obj = cwd / path_obj + resolved_path = os.path.normpath(str(path_obj.absolute())) + return db_ctx.artifacts.get_by_path(resolved_path) + + +def _build_active_session_summary(db_ctx) -> ShowSessionSummary: + session = db_ctx.sessions.get_active() + if not session: + raise ShowQueryError(_NO_ACTIVE_SESSION_MESSAGE) + return _build_session_summary(db_ctx, session) + + +def _build_job_summary_for_ref(db_ctx, ref: str) -> ShowJobSummary: + if ref.startswith("@"): + session = db_ctx.sessions.get_active() + if not session: + raise ShowQueryError(_NO_ACTIVE_SESSION_MESSAGE) + job = _resolve_job_ref(db_ctx, int(session["id"]), ref) + else: + job = db_ctx.jobs.get_by_uid(ref) + + if not job: + raise ShowQueryError(f"Job not found: {ref}") + return _build_job_summary(db_ctx, job) + + +def _build_artifact_summary_for_path(db_ctx, cwd: Path, ref: str) -> ShowArtifactSummary: + artifact = _lookup_artifact_by_path(db_ctx, cwd, ref) + if not artifact: + raise ShowQueryError(f"No artifact found for path: {ref}") + return _build_artifact_summary(db_ctx, artifact) + + +def _build_artifact_summary_for_hash( + db_ctx, + ref: str, + *, + missing_prefix: str = "Artifact not found", +) -> ShowArtifactSummary: + artifact = db_ctx.artifacts.get_by_hash(ref) + if not artifact: + raise ShowQueryError(f"{missing_prefix}: {ref}") + return _build_artifact_summary(db_ctx, artifact) def _resolve_job_ref(db_ctx, session_id: int, job_ref: str) -> dict | None: diff --git a/roar/application/query/status.py b/roar/application/query/status.py index cd139d8f..16c0d14e 100644 --- a/roar/application/query/status.py 
+++ b/roar/application/query/status.py @@ -4,18 +4,21 @@ from pathlib import Path -from ...core.bootstrap import bootstrap -from ...db.context import create_database_context +from ...db.query_context import create_query_database_context from ...presenters.formatting import format_size from .requests import StatusQueryRequest from .results import StatusArtifactSummary, StatusSummary +_NO_ACTIVE_SESSION_MESSAGE = "No active session. Run 'roar run' to create a session first." + + +class StatusQueryError(RuntimeError): + """Raised when a status query cannot build a summary.""" + def render_status(request: StatusQueryRequest) -> str: """Render a summary of the active session.""" summary = build_status_summary(request) - if summary is None: - return "No active session." lines = [ "DAG:", f" Build steps: {summary.build_steps}", @@ -47,14 +50,12 @@ def render_status(request: StatusQueryRequest) -> str: return "\n".join(lines) -def build_status_summary(request: StatusQueryRequest) -> StatusSummary | None: +def build_status_summary(request: StatusQueryRequest) -> StatusSummary: """Build a typed summary of the active session status.""" - bootstrap(request.roar_dir) - - with create_database_context(request.roar_dir) as db_ctx: + with create_query_database_context(request.roar_dir) as db_ctx: session = db_ctx.sessions.get_active() if not session: - return None + raise StatusQueryError(_NO_ACTIVE_SESSION_MESSAGE) jobs = db_ctx.jobs.get_by_session(session["id"], limit=10000) @@ -67,21 +68,29 @@ def build_status_summary(request: StatusQueryRequest) -> StatusSummary | None: else: run_steps.add(step) - seen_artifact_ids: set[int] = set() artifacts: list[StatusArtifactSummary] = [] - for job in jobs: - for output in db_ctx.jobs.get_outputs(job["id"]): - artifact_id = output["artifact_id"] - if artifact_id not in seen_artifact_ids: - seen_artifact_ids.add(artifact_id) - artifacts.append( - StatusArtifactSummary( - artifact_hash=str(output["artifact_hash"] or ""), - 
size_bytes=int(output["size"] or 0), - path=str(output["path"]), - present=Path(output["path"]).exists(), - ) - ) + distinct_outputs = getattr(db_ctx.jobs, "get_distinct_outputs_by_session", None) + if callable(distinct_outputs): + outputs = distinct_outputs(session["id"]) + else: + seen_artifact_ids: set[int | str] = set() + outputs = [] + for job in jobs: + for output in db_ctx.jobs.get_outputs(job["id"]): + artifact_id = output["artifact_id"] + if artifact_id not in seen_artifact_ids: + seen_artifact_ids.add(artifact_id) + outputs.append(output) + + for output in outputs: + artifacts.append( + StatusArtifactSummary( + artifact_hash=str(output["artifact_hash"] or ""), + size_bytes=int(output["size"] or 0), + path=str(output["path"]), + present=Path(output["path"]).exists(), + ) + ) return StatusSummary( build_steps=len(build_steps), diff --git a/roar/cli/__init__.py b/roar/cli/__init__.py index b9d5cba9..e4c6e748 100644 --- a/roar/cli/__init__.py +++ b/roar/cli/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations +from collections.abc import Iterable from importlib import import_module import click @@ -28,28 +29,35 @@ # Lazy command registry: maps command name to (module_path, command_name, short_help) # Short help is stored here to avoid importing commands just for --help LAZY_COMMANDS: dict[str, tuple[str, str, str]] = { - "auth": ("roar.cli.commands.auth", "auth", "Manage authentication with GLaaS"), - "build": ("roar.cli.commands.build", "build", "Run a build step before the main pipeline"), + "auth": ("roar.cli.commands.auth", "auth", "Manage GLaaS auth and SSH keys"), + "build": ("roar.cli.commands.build", "build", "Track a build step before the main pipeline"), "config": ("roar.cli.commands.config", "config", "View or set configuration"), - "dag": ("roar.cli.commands.dag", "dag", "Show the execution DAG"), - "env": ("roar.cli.commands.env", "env", "Show environment information"), - "get": ("roar.cli.commands.get", "get", "Download artifacts 
from cloud storage"), - "init": ("roar.cli.commands.init", "init", "Initialize roar in current directory"), + "dag": ("roar.cli.commands.dag", "dag", "Inspect the local execution DAG"), + "env": ("roar.cli.commands.env", "env", "Manage persistent environment variables"), + "get": ("roar.cli.commands.get", "get", "Download published artifacts"), + "init": ("roar.cli.commands.init", "init", "Set up roar in a project"), "label": ("roar.cli.commands.label", "label", "Manage local labels"), - "lineage": ("roar.cli.commands.lineage", "lineage", "Show lineage for an artifact"), - "log": ("roar.cli.commands.log", "log", "Show execution log"), - "pop": ("roar.cli.commands.pop", "pop", "Pop the last step from the session"), + "lineage": ("roar.cli.commands.lineage", "lineage", "Inspect lineage for a tracked artifact"), + "log": ("roar.cli.commands.log", "log", "List jobs in the active session"), + "pop": ("roar.cli.commands.pop", "pop", "Remove the last local step"), "proxy": ("roar.cli.commands.proxy", "proxy", "Manage S3 proxy for lineage tracking"), - "put": ("roar.cli.commands.put", "put", "Publish artifacts to cloud storage"), - "register": ("roar.cli.commands.register", "register", "Register artifacts or jobs"), - "reproduce": ("roar.cli.commands.reproduce", "reproduce", "Reproduce an artifact"), + "put": ("roar.cli.commands.put", "put", "Publish artifacts and register lineage"), + "register": ("roar.cli.commands.register", "register", "Register local lineage with GLaaS"), + "reproduce": ("roar.cli.commands.reproduce", "reproduce", "Generate a reproduction plan"), "reset": ("roar.cli.commands.reset", "reset", "Reset roar state"), - "run": ("roar.cli.commands.run", "run", "Run a command with provenance tracking"), - "show": ("roar.cli.commands.show", "show", "Show details of a job or artifact"), - "status": ("roar.cli.commands.status", "status", "Show current session status"), - "tracer": ("roar.cli.commands.tracer", "tracer", "Manage tracer backend defaults"), + "run": 
("roar.cli.commands.run", "run", "Track a command with provenance"), + "show": ("roar.cli.commands.show", "show", "Inspect a session, job, or artifact"), + "status": ("roar.cli.commands.status", "status", "Show the active session summary"), + "tracer": ("roar.cli.commands.tracer", "tracer", "Configure tracer backend defaults"), } +HELP_GROUPS: tuple[tuple[str, tuple[str, ...]], ...] = ( + ("Start Here", ("init", "run", "build", "dag")), + ("Inspect Local Lineage", ("status", "log", "show", "lineage", "pop", "reproduce")), + ("Share and Publish", ("put", "register", "get", "label")), + ("Setup and Admin", ("auth", "config", "env", "tracer", "proxy", "reset")), +) + class LazyCommand(click.Command): """A placeholder command that loads the real implementation on invoke.""" @@ -63,7 +71,16 @@ def __init__(self, name: str, module_path: str, attr_name: str, short_help: str) def _load(self) -> click.Command: """Load the real command implementation.""" if self._real_command is None: - module = import_module(self._module_path) + try: + module = import_module(self._module_path) + except ModuleNotFoundError as exc: + missing = exc.name or "unknown" + raise click.ClickException( + f"Failed to load '{self.name}' because import '{missing}' is unavailable. " + "Reinstall roar-cli or run it from a fully provisioned environment." 
+ ) from exc + except ImportError as exc: + raise click.ClickException(f"Failed to load '{self.name}': {exc}") from exc self._real_command = getattr(module, self._attr_name) return self._real_command @@ -114,6 +131,50 @@ def list_commands(self, ctx: click.Context) -> list[str]: """List all available commands.""" return sorted(super().list_commands(ctx)) + def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """Render top-level commands grouped by workflow.""" + rendered_commands: set[str] = set() + + for section_name, command_names in HELP_GROUPS: + rows = self._get_command_rows(ctx, formatter, command_names) + if not rows: + continue + rendered_commands.update(name for name, _ in rows) + with formatter.section(section_name): + formatter.write_dl(rows) + + remaining = [name for name in self.list_commands(ctx) if name not in rendered_commands] + rows = self._get_command_rows(ctx, formatter, remaining) + if rows: + with formatter.section("Other Commands"): + formatter.write_dl(rows) + + def _get_command_rows( + self, + ctx: click.Context, + formatter: click.HelpFormatter, + command_names: Iterable[str], + ) -> list[tuple[str, str]]: + """Build help rows for a command section.""" + commands: list[tuple[str, click.Command]] = [] + max_name_len = 0 + + for command_name in command_names: + command = self.get_command(ctx, command_name) + if command is None or command.hidden: + continue + commands.append((command_name, command)) + max_name_len = max(max_name_len, len(command_name)) + + if not commands: + return [] + + help_limit = formatter.width - 6 - max_name_len + return [ + (command_name, command.get_short_help_str(help_limit)) + for command_name, command in commands + ] + @click.command(cls=LazyGroup, lazy_commands=LAZY_COMMANDS, invoke_without_command=True) @click.version_option(version=__version__, prog_name="roar") @@ -130,13 +191,20 @@ def cli(ctx: click.Context) -> None: roar run Run a command with provenance tracking \b - 
Information: - roar reproduce Reproduce an artifact + Inspect Local Lineage: + roar status Show the active session summary + roar show @1 Inspect a specific step, job, or artifact + roar dag See the full local execution graph + + \b + Share and Publish: + roar put ... Upload artifacts and register lineage + roar register Register local lineage with GLaaS \b - Configuration: + Setup and Configuration: roar config View or set configuration - roar auth Manage authentication with https://glaas.ai + roar auth key Show the SSH key used for GLaaS signup """ ctx.ensure_object(dict) diff --git a/roar/cli/commands/auth.py b/roar/cli/commands/auth.py index 5467b3b9..7bf5bc60 100644 --- a/roar/cli/commands/auth.py +++ b/roar/cli/commands/auth.py @@ -74,14 +74,14 @@ def auth(ctx: click.Context) -> None: """Manage authentication with https://glaas.ai \b - To register with GLaaS: - 1. Run 'roar auth register' to display your public key + To set up GLaaS auth: + 1. Run 'roar auth key' to display your public key 2. Sign up for GLaaS at https://glaas.ai where you can paste your public key 3. Once added, run 'roar auth test' to verify \b Examples: - roar auth register # Show your SSH key + roar auth key # Show your SSH key roar auth test # Test connection roar auth status # Show auth status """ @@ -89,9 +89,8 @@ def auth(ctx: click.Context) -> None: click.echo(ctx.get_help()) -@auth.command("register") -def auth_register() -> None: - """Show SSH public key for registration.""" +def _show_auth_key() -> None: + """Render the SSH public key guidance used by auth key and its legacy alias.""" key_info = _find_ssh_pubkey() if not key_info: @@ -99,7 +98,7 @@ def auth_register() -> None: "No SSH public key found.\n\n" "Generate one with:\n" " ssh-keygen -t ed25519\n\n" - "Then run 'roar auth register' again." + "Then run 'roar auth key' again." 
) key_type, pubkey, path = key_info @@ -113,6 +112,18 @@ def auth_register() -> None: click.echo("Copy and paste this key when you sign up at https://glaas.ai") +@auth.command("key") +def auth_key() -> None: + """Show SSH public key for GLaaS signup.""" + _show_auth_key() + + +@auth.command("register", hidden=True) +def auth_register() -> None: + """Backward-compatible alias for 'roar auth key'.""" + _show_auth_key() + + @auth.command("test") def auth_test() -> None: """Test connection to GLaaS server.""" @@ -150,7 +161,7 @@ def auth_test() -> None: key_info = glaas_find_ssh_pubkey() if not key_info: - raise click.ClickException("No SSH key found. Run 'roar auth register' first.") + raise click.ClickException("No SSH key found. Run 'roar auth key' first.") _, pubkey, key_path = key_info fingerprint = compute_pubkey_fingerprint(pubkey) diff --git a/roar/cli/commands/dag.py b/roar/cli/commands/dag.py index 3e54dde0..afdb6efd 100644 --- a/roar/cli/commands/dag.py +++ b/roar/cli/commands/dag.py @@ -6,7 +6,8 @@ import click -from ...application.query import DagQueryRequest, render_dag +from ...application.query.dag import render_dag +from ...application.query.requests import DagQueryRequest from ..context import RoarContext from ..decorators import require_init diff --git a/roar/cli/commands/init.py b/roar/cli/commands/init.py index 5045af67..56dd1ef1 100644 --- a/roar/cli/commands/init.py +++ b/roar/cli/commands/init.py @@ -227,6 +227,7 @@ def init(ctx: RoarContext, yes: bool, no: bool, init_path: Path | None) -> None: roar init --path /some/dir # Initialize in a specific directory """ cwd = init_path if init_path is not None else ctx.cwd + target_repo_root = RoarContext._get_repo_root(cwd) # Check if .roar already exists roar_dir = cwd / ".roar" @@ -248,12 +249,12 @@ def init(ctx: RoarContext, yes: bool, no: bool, init_path: Path | None) -> None: click.echo(f"Created {roar_dir / 'config.toml'}") # Check if we're in a git repo - if ctx.repo_root is None: + if 
target_repo_root is None: click.echo("Not in a git repository. Done.") return # Check if .gitignore exists - gitignore_path = ctx.repo_root / ".gitignore" + gitignore_path = target_repo_root / ".gitignore" if not gitignore_path.exists(): click.echo("No .gitignore found. Done.") return diff --git a/roar/cli/commands/label.py b/roar/cli/commands/label.py index fb314214..a8865050 100644 --- a/roar/cli/commands/label.py +++ b/roar/cli/commands/label.py @@ -12,18 +12,20 @@ import click -from ...application.query import ( - LabelCopyRequest, - LabelHistoryRequest, - LabelSetRequest, - LabelShowRequest, +from ...application.query.label import ( copy_labels, set_labels, show_labels, ) -from ...application.query import ( +from ...application.query.label import ( label_history as render_label_history, ) +from ...application.query.requests import ( + LabelCopyRequest, + LabelHistoryRequest, + LabelSetRequest, + LabelShowRequest, +) from ..context import RoarContext from ..decorators import require_init diff --git a/roar/cli/commands/lineage.py b/roar/cli/commands/lineage.py index 68fa63b8..e3d40e95 100644 --- a/roar/cli/commands/lineage.py +++ b/roar/cli/commands/lineage.py @@ -2,7 +2,8 @@ import click -from ...application.query import LineageQueryRequest, render_lineage +from ...application.query.lineage import render_lineage +from ...application.query.requests import LineageQueryRequest from ..context import RoarContext from ..decorators import require_init diff --git a/roar/cli/commands/log.py b/roar/cli/commands/log.py index ddd0c082..b92acec0 100644 --- a/roar/cli/commands/log.py +++ b/roar/cli/commands/log.py @@ -6,7 +6,8 @@ import click -from ...application.query import LogQueryRequest, render_log +from ...application.query.log import LogQueryError, render_log +from ...application.query.requests import LogQueryRequest from ..context import RoarContext from ..decorators import require_init @@ -25,4 +26,9 @@ def log(ctx: RoarContext) -> None: roar log # Show recent job 
history """ - click.echo(render_log(LogQueryRequest(roar_dir=ctx.roar_dir, use_color=sys.stdout.isatty()))) + try: + click.echo( + render_log(LogQueryRequest(roar_dir=ctx.roar_dir, use_color=sys.stdout.isatty())) + ) + except LogQueryError as exc: + raise click.ClickException(str(exc)) from exc diff --git a/roar/cli/commands/put.py b/roar/cli/commands/put.py index 70b1aa44..572c3c43 100644 --- a/roar/cli/commands/put.py +++ b/roar/cli/commands/put.py @@ -13,11 +13,24 @@ from ...application.publish.requests import PutRequest from ...application.publish.service import put_artifacts -from ...integrations.config import config_get from ..context import RoarContext from ..decorators import require_init +def _preview_hash(value: str | None) -> str | None: + """Shorten long hashes for CLI summaries.""" + if not value: + return None + return f"{value[:12]}..." if len(value) > 12 else value + + +def _resolve_glaas_web_url() -> str: + """Load the GLaaS web URL lazily for success output.""" + from ...integrations.config import config_get + + return config_get("glaas.web_url") or "https://glaas.ai" + + @click.command("put") @click.argument("args", nargs=-1, required=True) @click.option( @@ -110,26 +123,38 @@ def put( # Handle dry run output if response.dry_run: - click.echo("Dry run - would upload:") + click.echo( + f"Dry run: would upload {len(response.would_upload)} file(s) to {response.destination}" + ) for dry_run_item in response.would_upload: click.echo(f" {dry_run_item.path}") - click.echo(f"\nTotal: {len(response.would_upload)} file(s)") return # Check for registration errors if not response.success: click.echo(f"Published {len(response.uploaded_files)} file(s) to {response.destination}") + if response.job_uid: + click.echo(f"Local details: roar show --job {response.job_uid}") click.echo("\nWarning: Registration completed with errors:", err=True) if response.error: for error in response.error.split("; "): click.echo(f" - {error}", err=True) raise 
click.ClickException("Registration completed with errors") - if response.git_tag: - click.echo(f"Created git tag: {response.git_tag}") - # Success output click.echo(f"Published {len(response.uploaded_files)} file(s) to {response.destination}") + session_preview = _preview_hash(response.session_hash) + if session_preview: + click.echo(f"Session: {session_preview}") + if response.job_id is not None: + click.echo(f"Job step: @{response.job_id}") + if response.job_uid: + click.echo(f"Job UID: {response.job_uid}") + if response.git_tag: + click.echo(f"Git tag: {response.git_tag}") + if response.uploaded_files: + click.echo("") + click.echo("Uploaded files:") for uploaded_file in response.uploaded_files: click.echo(f" {uploaded_file.local_path} -> {uploaded_file.remote_url}") if response.composites_registered: @@ -157,12 +182,15 @@ def put( f"Warning: local composite metadata was not persisted for {root_path}{detail}", err=True, ) - click.echo(f"\nJob created: step {response.job_id}") - if response.git_tag: - click.echo(f"Git tag: {response.git_tag}") - # Show GLaaS registration info - web_url = config_get("glaas.web_url") or "https://glaas.ai" + + web_url = _resolve_glaas_web_url() session_hash = response.session_hash or "" session_url = response.session_url or (f"{web_url}/dag/{session_hash}" if session_hash else "") - click.echo("\nRegistered with GLaaS:") - click.echo(f" View: {session_url}") + if session_url: + click.echo("\nGLaaS:") + click.echo(f" Session: {session_url}") + + click.echo("\nNext:") + if response.job_uid: + click.echo(f" roar show --job {response.job_uid}") + click.echo(" roar show --session") diff --git a/roar/cli/commands/register.py b/roar/cli/commands/register.py index 61236970..d0bee8d5 100644 --- a/roar/cli/commands/register.py +++ b/roar/cli/commands/register.py @@ -10,11 +10,22 @@ from ...application.publish.requests import RegisterLineageRequest from ...application.publish.service import register_lineage_target -from ...integrations.config 
import config_get from ..context import RoarContext from ..decorators import require_init +def _preview_hash(value: str) -> str: + """Shorten hashes in command summaries.""" + return f"{value[:12]}..." if len(value) > 12 else value + + +def _resolve_glaas_web_url(*, start_dir: str | None = None) -> str: + """Load the GLaaS web URL with the lightweight preview config path.""" + from ...integrations.config.raw import get_raw_glaas_web_url + + return get_raw_glaas_web_url(start_dir=start_dir) or "https://glaas.ai" + + def _confirm_secrets(detected_secrets: list[str]) -> bool: """Prompt user to confirm registration with secrets.""" click.echo("") @@ -99,25 +110,26 @@ def register(ctx: RoarContext, target: str, dry_run: bool, yes: bool, as_blake3: raise SystemExit(1) raise click.ClickException(response.error or "Registration failed") - web_url = config_get("glaas.web_url") or "https://glaas.ai" + web_url = _resolve_glaas_web_url(start_dir=str(ctx.cwd)) + session_preview = _preview_hash(response.session_hash) if response.session_hash else "" # Format output if dry_run: - click.echo("Dry run - would register:") - click.echo(f" Session: {response.session_hash[:12]}...") + click.echo(f"Dry run: would register lineage for: {target}") + click.echo(f" Session: {session_preview}") click.echo(f" Jobs: {response.jobs_registered}") click.echo(f" Artifacts: {response.artifacts_registered}") click.echo(f" Links: {response.links_created}") if response.secrets_detected: click.echo(f" Secrets to redact: {len(response.secrets_detected)} types") click.echo("") - click.echo("View on GLaaS:") + click.echo("GLaaS:") click.echo(f" Session: {web_url}/dag/{response.session_hash}") if response.artifact_hash: click.echo(f" Artifact: {web_url}/artifact/{response.artifact_hash}") else: click.echo(f"Registered lineage for: {target}") - click.echo(f" Session: {response.session_hash[:12]}...") + click.echo(f" Session: {session_preview}") click.echo(f" Jobs: {response.jobs_registered}") click.echo(f" 
Artifacts: {response.artifacts_registered}") click.echo(f" Links: {response.links_created}") @@ -131,13 +143,12 @@ def register(ctx: RoarContext, target: str, dry_run: bool, yes: bool, as_blake3: for error in response.error.split("; "): click.echo(f" - {error}", err=True) - if response.artifact_hash: - click.echo("") - click.echo("To reproduce this artifact:") - click.echo(f" roar reproduce {response.artifact_hash}") - click.echo("") - click.echo("View on GLaaS:") + click.echo("GLaaS:") click.echo(f" Session: {web_url}/dag/{response.session_hash}") if response.artifact_hash: click.echo(f" Artifact: {web_url}/artifact/{response.artifact_hash}") + click.echo("") + click.echo("Next:") + click.echo(f" roar show --artifact {response.artifact_hash}") + click.echo(f" roar reproduce {response.artifact_hash}") diff --git a/roar/cli/commands/show.py b/roar/cli/commands/show.py index f1043d13..535b7941 100644 --- a/roar/cli/commands/show.py +++ b/roar/cli/commands/show.py @@ -4,20 +4,38 @@ import click -from ...application.query import ShowQueryRequest, render_show +from ...application.query.requests import ShowQueryRequest, ShowQuerySelector +from ...application.query.show import ShowQueryError, render_show from ..context import RoarContext from ..decorators import require_init @click.command("show") +@click.option("--path", "path_ref", metavar="PATH", help="Show an artifact by path.") +@click.option( + "--job", + "job_ref", + metavar="REF", + help="Show a job by step ref or UID (for example, @1 or deadbeef).", +) +@click.option("--artifact", "artifact_ref", metavar="HASH", help="Show an artifact by hash.") +@click.option("--session", "show_session", is_flag=True, help="Show the active session.") @click.argument("ref", required=False) @click.pass_obj @require_init -def show(ctx: RoarContext, ref: str | None) -> None: +def show( + ctx: RoarContext, + path_ref: str | None, + job_ref: str | None, + artifact_ref: str | None, + show_session: bool, + ref: str | None, +) -> None: 
"""Show session, job, or artifact details. Without arguments, displays the active session and its jobs. With a reference, displays detailed information based on the reference type. + Explicit selectors avoid ambiguous auto-detection. \b REF can be: @@ -29,10 +47,63 @@ def show(ctx: RoarContext, ref: str | None) -> None: \b Examples: roar show # Show active session overview + roar show --session # Show active session overview explicitly roar show @1 # Show details for step 1 + roar show --job @B1 # Show build step details explicitly roar show @B1 # Show details for build step 1 + roar show --artifact deadbeef # Force artifact lookup for an ambiguous hash roar show a1b2c3d4 # Show job by UID roar show a1b2c3d4e5f67890... # Show artifact by hash + roar show --path deadbeef # Force path lookup for an ambiguous filename roar show ./output/model.pkl # Show artifact by path """ - click.echo(render_show(ShowQueryRequest(roar_dir=ctx.roar_dir, cwd=ctx.cwd, ref=ref))) + request = _build_show_request( + ctx=ctx, + ref=ref, + path_ref=path_ref, + job_ref=job_ref, + artifact_ref=artifact_ref, + show_session=show_session, + ) + try: + click.echo(render_show(request)) + except ShowQueryError as exc: + raise click.ClickException(str(exc)) from exc + + +def _build_show_request( + *, + ctx: RoarContext, + ref: str | None, + path_ref: str | None, + job_ref: str | None, + artifact_ref: str | None, + show_session: bool, +) -> ShowQueryRequest: + explicit_targets: list[tuple[str, str | None, ShowQuerySelector]] = [] + if path_ref is not None: + explicit_targets.append(("--path", path_ref, "path")) + if job_ref is not None: + explicit_targets.append(("--job", job_ref, "job")) + if artifact_ref is not None: + explicit_targets.append(("--artifact", artifact_ref, "artifact")) + if show_session: + explicit_targets.append(("--session", None, "session")) + + if len(explicit_targets) > 1: + raise click.UsageError("Specify only one of --path, --job, --artifact, or --session.") + + if 
explicit_targets: + if ref is not None: + raise click.UsageError( + "Positional REF cannot be combined with --path, --job, --artifact, or --session." + ) + _, explicit_ref, selector = explicit_targets[0] + return ShowQueryRequest( + roar_dir=ctx.roar_dir, + cwd=ctx.cwd, + ref=explicit_ref, + selector=selector, + ) + + return ShowQueryRequest(roar_dir=ctx.roar_dir, cwd=ctx.cwd, ref=ref) diff --git a/roar/cli/commands/status.py b/roar/cli/commands/status.py index 29f0f865..6a81f5bb 100644 --- a/roar/cli/commands/status.py +++ b/roar/cli/commands/status.py @@ -4,7 +4,8 @@ import click -from ...application.query import StatusQueryRequest, render_status +from ...application.query.requests import StatusQueryRequest +from ...application.query.status import StatusQueryError, render_status from ..context import RoarContext from ..decorators import require_init @@ -14,4 +15,7 @@ @require_init def status(ctx: RoarContext) -> None: """Show a summary of the active session.""" - click.echo(render_status(StatusQueryRequest(roar_dir=ctx.roar_dir))) + try: + click.echo(render_status(StatusQueryRequest(roar_dir=ctx.roar_dir))) + except StatusQueryError as exc: + raise click.ClickException(str(exc)) from exc diff --git a/roar/core/bootstrap.py b/roar/core/bootstrap.py index 95249571..fa5fcc77 100644 --- a/roar/core/bootstrap.py +++ b/roar/core/bootstrap.py @@ -31,7 +31,7 @@ def bootstrap(roar_dir: Path | None = None) -> None: if _initialized: return - _configure_core_logging() + _configure_core_logging(roar_dir) # Register built-in integrations that should not depend on plugin discovery. 
_register_builtin_integrations() @@ -43,13 +43,19 @@ def bootstrap(roar_dir: Path | None = None) -> None: return -def _configure_core_logging() -> None: +def _configure_core_logging(roar_dir: Path | None = None) -> None: """Configure the process-wide logger from local config.""" - from ..integrations.config import config_get + from ..integrations.config import load_settings - level = config_get("logging.level") or "warning" - console_enabled = config_get("logging.console") or False - file_enabled = config_get("logging.file") + start_dir: str | None = None + if roar_dir is not None: + search_root = roar_dir.parent if roar_dir.name == ".roar" else roar_dir + start_dir = str(search_root) + + settings = load_settings(start_dir=start_dir) + level = settings.logging.level or "warning" + console_enabled = bool(settings.logging.console) + file_enabled = settings.logging.file if file_enabled is None: file_enabled = True configure_logger( diff --git a/roar/core/models/__init__.py b/roar/core/models/__init__.py index c8fa3dff..4dc3c331 100644 --- a/roar/core/models/__init__.py +++ b/roar/core/models/__init__.py @@ -1,140 +1,81 @@ -""" -Pydantic models for roar. +"""Lazy exports for core Pydantic model types.""" -This package provides typed, validated models for all roar data structures. -All models use Pydantic v2 with strict validation. 
-""" +from __future__ import annotations -# Base models -# Core domain models -from .artifact import Artifact, ArtifactHash -from .base import ImmutableModel, RoarBaseModel +from importlib import import_module +from typing import Any -# DAG visualization models -from .dag import ( - DagArtifactInfo, - DagNodeInfo, - DagNodeMetrics, - DagNodeState, - DagVisualization, -) -from .dataset_identifier import DatasetIdentifier +_EXPORTS = { + "Artifact": ".artifact", + "ArtifactDagResponse": ".glaas", + "ArtifactHash": ".artifact", + "ArtifactHashRequest": ".glaas", + "ArtifactResponse": ".glaas", + "CheckTagRequest": ".glaas", + "CheckTagResponse": ".glaas", + "CompleteLiveJobRequest": ".glaas", + "ContainerInfo": ".provenance", + "CreateDagRequest": ".glaas", + "CreateLiveJobRequest": ".glaas", + "DagArtifactInfo": ".dag", + "DagNodeInfo": ".dag", + "DagNodeMetrics": ".dag", + "DagNodeState": ".dag", + "DagResponse": ".glaas", + "DagVisualization": ".dag", + "DatasetIdentifier": ".dataset_identifier", + "FileClassification": ".provenance", + "FilteredFiles": ".provenance", + "GitInfo": ".provenance", + "HardwareInfo": ".provenance", + "IOEntry": ".glaas", + "ImmutableModel": ".base", + "Job": ".job", + "JobInput": ".job", + "JobOutput": ".job", + "JobResponse": ".glaas", + "LineageArtifactInfo": ".lineage", + "LineageJobInfo": ".lineage", + "LineageResponse": ".glaas", + "LineageResult": ".lineage", + "LiveJobResponse": ".glaas", + "PackageInfo": ".provenance", + "ProvenanceContext": ".provenance", + "PythonInjectData": ".provenance", + "RecordTagRequest": ".glaas", + "RegisterArtifactRequest": ".glaas", + "RegisterArtifactsBatchRequest": ".glaas", + "RegisterJobRequest": ".glaas", + "RegisterJobsBatchRequest": ".glaas", + "RegisterSessionRequest": ".glaas", + "ResolvedStep": ".run", + "RoarBaseModel": ".base", + "RunArguments": ".run", + "RunContext": ".run", + "RunResult": ".run", + "RuntimeInfo": ".provenance", + "Session": ".session", + "SessionResponse": ".glaas", 
+ "TelemetryRunInfo": ".telemetry", + "TracerData": ".provenance", + "TracerResult": ".run", + "UpdateLiveJobRequest": ".glaas", + "VCSInfo": ".vcs", +} -# GLaaS API models -from .glaas import ( - ArtifactDagResponse, - ArtifactHashRequest, - ArtifactResponse, - CheckTagRequest, - CheckTagResponse, - CompleteLiveJobRequest, - CreateDagRequest, - CreateLiveJobRequest, - DagResponse, - IOEntry, - JobResponse, - LineageResponse, - LiveJobResponse, - RecordTagRequest, - RegisterArtifactRequest, - RegisterArtifactsBatchRequest, - RegisterJobRequest, - RegisterJobsBatchRequest, - RegisterSessionRequest, - SessionResponse, - UpdateLiveJobRequest, -) -from .job import Job, JobInput, JobOutput +__all__ = sorted(_EXPORTS) -# Lineage models -from .lineage import LineageArtifactInfo, LineageJobInfo, LineageResult -# Provenance models -from .provenance import ( - ContainerInfo, - FileClassification, - FilteredFiles, - GitInfo, - HardwareInfo, - PackageInfo, - ProvenanceContext, - PythonInjectData, - RuntimeInfo, - TracerData, -) +def __getattr__(name: str) -> Any: + module_name = _EXPORTS.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -# Run/execution models -from .run import ( - ResolvedStep, - RunArguments, - RunContext, - RunResult, - TracerResult, -) + module = import_module(module_name, __name__) + value = getattr(module, name) + globals()[name] = value + return value -# Session models -from .session import Session -# Telemetry models -from .telemetry import TelemetryRunInfo - -# VCS models -from .vcs import VCSInfo - -__all__ = [ - "Artifact", - "ArtifactDagResponse", - "ArtifactHash", - "ArtifactHashRequest", - "ArtifactResponse", - "CheckTagRequest", - "CheckTagResponse", - "CompleteLiveJobRequest", - "ContainerInfo", - "CreateDagRequest", - "CreateLiveJobRequest", - "DagArtifactInfo", - "DagNodeInfo", - "DagNodeMetrics", - "DagNodeState", - "DagResponse", - "DagVisualization", - "DatasetIdentifier", - 
"FileClassification", - "FilteredFiles", - "GitInfo", - "HardwareInfo", - "IOEntry", - "ImmutableModel", - "Job", - "JobInput", - "JobOutput", - "JobResponse", - "LineageArtifactInfo", - "LineageJobInfo", - "LineageResponse", - "LineageResult", - "LiveJobResponse", - "PackageInfo", - "ProvenanceContext", - "PythonInjectData", - "RecordTagRequest", - "RegisterArtifactRequest", - "RegisterArtifactsBatchRequest", - "RegisterJobRequest", - "RegisterJobsBatchRequest", - "RegisterSessionRequest", - "ResolvedStep", - "RoarBaseModel", - "RunArguments", - "RunContext", - "RunResult", - "RuntimeInfo", - "Session", - "SessionResponse", - "TelemetryRunInfo", - "TracerData", - "TracerResult", - "UpdateLiveJobRequest", - "VCSInfo", -] +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/roar/db/__init__.py b/roar/db/__init__.py index acd58174..2793bdac 100644 --- a/roar/db/__init__.py +++ b/roar/db/__init__.py @@ -1,65 +1,47 @@ -""" -Roar database layer. +"""Lazy exports for the roar database layer.""" -This package provides the database layer for roar lineage tracking, -following SOLID principles: +from __future__ import annotations -- Models: SQLAlchemy ORM models for all entities -- Repositories: Focused data access interfaces -- Services: Business logic orchestration -- Hashing: Strategy pattern for hash algorithms -- Engine: SQLAlchemy engine and session configuration +from importlib import import_module +from typing import Any -Usage: - Use the DatabaseContext for database access: +_EXPORTS = { + "Artifact": ".models", + "ArtifactHash": ".models", + "Base": ".models", + "Collection": ".models", + "CollectionMember": ".models", + "CompositeArtifactComponent": ".models", + "CompositeMembershipIndex": ".models", + "DatabaseContext": ".context", + "HashCache": ".models", + "Job": ".models", + "JobInput": ".models", + "JobOutput": ".models", + "Label": ".models", + "QueryDatabaseContext": ".query_context", + "SchemaVersion": ".models", + 
"Session": ".models", + "create_database_context": ".context", + "create_query_database_context": ".query_context", + "create_roar_engine": ".engine", + "create_session_factory": ".engine", + "init_database": ".engine", +} - from roar.db.context import create_database_context +__all__ = sorted(_EXPORTS) - with create_database_context(roar_dir) as ctx: - artifacts = ctx.artifacts.get_all() - jobs = ctx.jobs.get_recent() -""" -from .context import DatabaseContext, create_database_context -from .engine import create_roar_engine, create_session_factory, init_database -from .models import ( - Artifact, - ArtifactHash, - Base, - Collection, - CollectionMember, - CompositeArtifactComponent, - CompositeMembershipIndex, - HashCache, - Job, - JobInput, - JobOutput, - Label, - SchemaVersion, - Session, -) +def __getattr__(name: str) -> Any: + module_name = _EXPORTS.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -__all__ = [ - "Artifact", - "ArtifactHash", - # Models - "Base", - "Collection", - "CollectionMember", - "CompositeArtifactComponent", - "CompositeMembershipIndex", - # Context - "DatabaseContext", - "HashCache", - "Job", - "JobInput", - "JobOutput", - "Label", - "SchemaVersion", - "Session", - "create_database_context", - # Engine - "create_roar_engine", - "create_session_factory", - "init_database", -] + module = import_module(module_name, __name__) + value = getattr(module, name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/roar/db/context.py b/roar/db/context.py index 5ca1f74c..e3686cd9 100644 --- a/roar/db/context.py +++ b/roar/db/context.py @@ -1,38 +1,48 @@ -""" -Database context for roar. +"""Database context for roar.""" -Provides a lightweight context manager that exposes typed repository -and service properties, replacing the monolithic RoarDB facade. 
-""" +from __future__ import annotations +from functools import lru_cache from pathlib import Path - -from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session +from typing import TYPE_CHECKING, Any from ..core.exceptions import DatabaseConnectionError -from .engine import create_roar_engine, create_session_factory, init_database -from .repositories import ( - SQLAlchemyArtifactRepository, - SQLAlchemyCollectionRepository, - SQLAlchemyCompositeRepository, - SQLAlchemyHashCacheRepository, - SQLAlchemyJobRepository, - SQLAlchemyLabelRepository, - SQLAlchemySessionRepository, -) -from .schema import run_migrations -from .services import ( - DefaultHashingService, - DefaultLineageService, - DefaultSessionService, - JobRecordingService, -) - -try: - import sqlite3 as sqlite_module -except ImportError: - import pysqlite3 as sqlite_module # type: ignore[import-not-found, no-redef] + +if TYPE_CHECKING: + from sqlalchemy.engine import Engine as Engine + from sqlalchemy.orm import Session as Session + + from .repositories.artifact import ( + SQLAlchemyArtifactRepository as SQLAlchemyArtifactRepository, + ) + from .repositories.collection import ( + SQLAlchemyCollectionRepository as SQLAlchemyCollectionRepository, + ) + from .repositories.composite import ( + SQLAlchemyCompositeRepository as SQLAlchemyCompositeRepository, + ) + from .repositories.hash_cache import ( + SQLAlchemyHashCacheRepository as SQLAlchemyHashCacheRepository, + ) + from .repositories.job import SQLAlchemyJobRepository as SQLAlchemyJobRepository + from .repositories.label import SQLAlchemyLabelRepository as SQLAlchemyLabelRepository + from .repositories.session import ( + SQLAlchemySessionRepository as SQLAlchemySessionRepository, + ) + from .services.hashing import DefaultHashingService as DefaultHashingService + from .services.job_recording import JobRecordingService as JobRecordingService + from .services.lineage import DefaultLineageService as DefaultLineageService + from 
.services.session import DefaultSessionService as DefaultSessionService + + +@lru_cache(maxsize=1) +def _get_sqlite_module() -> Any: + try: + import sqlite3 as sqlite_module + except ImportError: + import pysqlite3 as sqlite_module # type: ignore[import-not-found, no-redef] + + return sqlite_module class DatabaseContext: @@ -60,7 +70,7 @@ def __init__(self, db_path: Path): self._engine: Engine | None = None self._session: Session | None = None - # Repositories (initialized on connect) + # Repositories (initialized on first access) self._hash_cache_repo: SQLAlchemyHashCacheRepository | None = None self._artifact_repo: SQLAlchemyArtifactRepository | None = None self._job_repo: SQLAlchemyJobRepository | None = None @@ -69,7 +79,7 @@ def __init__(self, db_path: Path): self._composite_repo: SQLAlchemyCompositeRepository | None = None self._label_repo: SQLAlchemyLabelRepository | None = None - # Services (initialized on connect) + # Services (initialized on first access) self._hashing_service: DefaultHashingService | None = None self._session_service: DefaultSessionService | None = None self._lineage_service: DefaultLineageService | None = None @@ -77,6 +87,11 @@ def __init__(self, db_path: Path): def connect(self) -> None: """Connect to the database and initialize schema if needed.""" + from .engine import create_roar_engine, create_session_factory, init_database + from .schema import run_migrations + + sqlite_module = _get_sqlite_module() + self._engine = create_roar_engine(self.db_path) init_database(self._engine) raw_conn = sqlite_module.connect(str(self.db_path)) @@ -86,32 +101,7 @@ def connect(self) -> None: raw_conn.commit() finally: raw_conn.close() - session_factory = create_session_factory(self._engine) - self._session = session_factory() - - # Initialize repositories - self._hash_cache_repo = SQLAlchemyHashCacheRepository(self._session) - self._artifact_repo = SQLAlchemyArtifactRepository(self._session) - self._job_repo = SQLAlchemyJobRepository(self._session, 
self._artifact_repo) - self._session_repo = SQLAlchemySessionRepository(self._session) - self._collection_repo = SQLAlchemyCollectionRepository(self._session) - self._composite_repo = SQLAlchemyCompositeRepository(self._session) - self._label_repo = SQLAlchemyLabelRepository(self._session) - - # Initialize services - self._hashing_service = DefaultHashingService(self._hash_cache_repo) - self._session_service = DefaultSessionService( - self._session_repo, self._job_repo, self._artifact_repo - ) - self._lineage_service = DefaultLineageService(self._artifact_repo, self._job_repo) - self._job_recording_service = JobRecordingService( - self._session, - self._job_repo, - self._artifact_repo, - self._session_repo, - self._hashing_service, - self._session_service, - ) + self._session = create_session_factory(self._engine)() def close(self) -> None: """Close database connection.""" @@ -122,12 +112,25 @@ def close(self) -> None: self._engine.dispose() self._engine = None + self._hash_cache_repo = None + self._artifact_repo = None + self._job_repo = None + self._session_repo = None + self._collection_repo = None + self._composite_repo = None + self._label_repo = None + + self._hashing_service = None + self._session_service = None + self._lineage_service = None + self._job_recording_service = None + def commit(self) -> None: """Commit the current transaction.""" if self._session: self._session.commit() - def __enter__(self) -> "DatabaseContext": + def __enter__(self) -> DatabaseContext: self.connect() return self @@ -139,6 +142,14 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: self._session.commit() self.close() + def _require_session(self) -> Session: + if self._session is None: + raise DatabaseConnectionError( + "DatabaseContext not connected. 
Use as context manager.", + db_path=str(self.db_path), + ) + return self._session + # ------------------------------------------------------------------------- # Repository properties # ------------------------------------------------------------------------- @@ -146,95 +157,74 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: @property def session(self) -> Session: """Get the underlying database session.""" - if self._session is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), - ) - return self._session + return self._require_session() @property def conn(self): - """Get raw database connection for direct SQL queries. - - Note: Prefer using repositories when possible. This is for - legacy compatibility with raw SQL usage. - """ - if self._session is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), - ) - return self._session.connection() + """Get raw database connection for direct SQL queries.""" + return self._require_session().connection() @property def hash_cache(self) -> SQLAlchemyHashCacheRepository: """Hash cache repository for file hash caching.""" if self._hash_cache_repo is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), - ) + from .repositories.hash_cache import SQLAlchemyHashCacheRepository + + self._hash_cache_repo = SQLAlchemyHashCacheRepository(self._require_session()) return self._hash_cache_repo @property def artifacts(self) -> SQLAlchemyArtifactRepository: """Artifact repository for content-addressed file storage.""" if self._artifact_repo is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. 
Use as context manager.", - db_path=str(self.db_path), - ) + from .repositories.artifact import SQLAlchemyArtifactRepository + + self._artifact_repo = SQLAlchemyArtifactRepository(self._require_session()) return self._artifact_repo @property def jobs(self) -> SQLAlchemyJobRepository: """Job repository for execution records.""" if self._job_repo is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), - ) + from .repositories.job import SQLAlchemyJobRepository + + self._job_repo = SQLAlchemyJobRepository(self._require_session(), self.artifacts) return self._job_repo @property def sessions(self) -> SQLAlchemySessionRepository: """Session repository for step sequences.""" if self._session_repo is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), - ) + from .repositories.session import SQLAlchemySessionRepository + + self._session_repo = SQLAlchemySessionRepository(self._require_session()) return self._session_repo @property def collections(self) -> SQLAlchemyCollectionRepository: """Collection repository for artifact groups.""" if self._collection_repo is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), - ) + from .repositories.collection import SQLAlchemyCollectionRepository + + self._collection_repo = SQLAlchemyCollectionRepository(self._require_session()) return self._collection_repo @property def composites(self) -> SQLAlchemyCompositeRepository: """Composite repository for component rows and membership metadata.""" if self._composite_repo is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. 
Use as context manager.", - db_path=str(self.db_path), - ) + from .repositories.composite import SQLAlchemyCompositeRepository + + self._composite_repo = SQLAlchemyCompositeRepository(self._require_session()) return self._composite_repo @property def labels(self) -> SQLAlchemyLabelRepository: """Label repository for versioned local label documents.""" if self._label_repo is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), - ) + from .repositories.label import SQLAlchemyLabelRepository + + self._label_repo = SQLAlchemyLabelRepository(self._require_session()) return self._label_repo # ------------------------------------------------------------------------- @@ -245,19 +235,21 @@ def labels(self) -> SQLAlchemyLabelRepository: def hashing(self) -> DefaultHashingService: """Hashing service for computing and caching file hashes.""" if self._hashing_service is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), - ) + from .services.hashing import DefaultHashingService + + self._hashing_service = DefaultHashingService(self.hash_cache) return self._hashing_service @property def session_service(self) -> DefaultSessionService: """Session service for session management operations.""" if self._session_service is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), + from .services.session import DefaultSessionService + + self._session_service = DefaultSessionService( + self.sessions, + self.jobs, + self.artifacts, ) return self._session_service @@ -265,19 +257,24 @@ def session_service(self) -> DefaultSessionService: def lineage(self) -> DefaultLineageService: """Lineage service for artifact lineage queries.""" if self._lineage_service is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. 
Use as context manager.", - db_path=str(self.db_path), - ) + from .services.lineage import DefaultLineageService + + self._lineage_service = DefaultLineageService(self.artifacts, self.jobs) return self._lineage_service @property def job_recording(self) -> JobRecordingService: """Job recording service for recording jobs with lineage.""" if self._job_recording_service is None: - raise DatabaseConnectionError( - "DatabaseContext not connected. Use as context manager.", - db_path=str(self.db_path), + from .services.job_recording import JobRecordingService + + self._job_recording_service = JobRecordingService( + self._require_session(), + self.jobs, + self.artifacts, + self.sessions, + self.hashing, + self.session_service, ) return self._job_recording_service diff --git a/roar/db/query_context.py b/roar/db/query_context.py new file mode 100644 index 00000000..a3d1faf6 --- /dev/null +++ b/roar/db/query_context.py @@ -0,0 +1,729 @@ +"""Lightweight sqlite-backed context for hot read-only query commands.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from ..core.exceptions import DatabaseConnectionError +from .context import _get_sqlite_module +from .step_priority import step_sort_key + + +class QueryDatabaseContext: + """Context manager providing lightweight read-only query access.""" + + def __init__(self, db_path: Path): + self.db_path = db_path + self._sqlite_module = _get_sqlite_module() + self._connection: Any | None = None + self._sessions_repo = _QuerySessionRepository(self) + self._jobs_repo = _QueryJobRepository(self) + self._artifacts_repo = _QueryArtifactRepository(self) + self._labels_repo = _QueryLabelRepository(self) + self._composites_repo = _QueryCompositeRepository(self) + + def connect(self) -> None: + """Open a direct sqlite connection when the database exists.""" + if not self.db_path.exists(): + return + + try: + self._connection = self._sqlite_module.connect(str(self.db_path)) + 
self._connection.row_factory = self._sqlite_module.Row + except Exception as exc: # pragma: no cover - defensive parity with DB context + raise DatabaseConnectionError(f"Failed to connect to database: {exc}") from exc + + def close(self) -> None: + """Close the sqlite connection if one was opened.""" + if self._connection is not None: + self._connection.close() + self._connection = None + + def __enter__(self) -> QueryDatabaseContext: + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + @property + def sessions(self) -> _QuerySessionRepository: + return self._sessions_repo + + @property + def jobs(self) -> _QueryJobRepository: + return self._jobs_repo + + @property + def artifacts(self) -> _QueryArtifactRepository: + return self._artifacts_repo + + @property + def labels(self) -> _QueryLabelRepository: + return self._labels_repo + + @property + def composites(self) -> _QueryCompositeRepository: + return self._composites_repo + + def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[Any]: + if self._connection is None: + return [] + + try: + cursor = self._connection.execute(sql, params) + return cursor.fetchall() + except self._sqlite_module.OperationalError as exc: + if "no such table" in str(exc).lower(): + return [] + raise + + def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> Any | None: + rows = self._fetchall(sql, params) + if not rows: + return None + return rows[0] + + +class _QueryRepository: + """Shared helpers for direct sqlite repositories.""" + + def __init__(self, db_ctx: QueryDatabaseContext): + self._db_ctx = db_ctx + + def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[Any]: + return self._db_ctx._fetchall(sql, params) + + def _fetchone(self, sql: str, params: tuple[Any, ...] 
= ()) -> Any | None: + return self._db_ctx._fetchone(sql, params) + + +class _QuerySessionRepository(_QueryRepository): + def get(self, session_id: int) -> dict[str, Any] | None: + row = self._fetchone( + """ + SELECT id, hash, created_at, source_artifact_hash, current_step, is_active, + git_repo, git_commit_start, git_commit_end, synced_at, metadata + FROM sessions + WHERE id = ? + LIMIT 1 + """, + (session_id,), + ) + return _session_row_to_dict(row) if row is not None else None + + def get_active(self) -> dict[str, Any] | None: + row = self._fetchone( + """ + SELECT id, hash, created_at, source_artifact_hash, current_step, is_active, + git_repo, git_commit_start, git_commit_end, synced_at, metadata + FROM sessions + WHERE is_active = 1 + LIMIT 1 + """ + ) + return _session_row_to_dict(row) if row is not None else None + + def get_steps(self, session_id: int) -> list[dict[str, Any]]: + rows = self._fetchall( + """ + SELECT * + FROM jobs + WHERE session_id = ? + ORDER BY step_number ASC, timestamp ASC + """, + (session_id,), + ) + return [_job_row_to_dict(row) for row in rows] + + def get_step_by_number( + self, session_id: int, step_number: int, job_type: str | None = None + ) -> dict[str, Any] | None: + if job_type == "build": + rows = self._fetchall( + """ + SELECT * + FROM jobs + WHERE session_id = ? AND step_number = ? AND job_type = 'build' + """, + (session_id, step_number), + ) + else: + rows = self._fetchall( + """ + SELECT * + FROM jobs + WHERE session_id = ? AND step_number = ? + AND (job_type IS NULL OR job_type != 'build') + """, + (session_id, step_number), + ) + + jobs = [_job_row_to_dict(row) for row in rows] + if not jobs: + return None + return max(jobs, key=step_sort_key) + + +class _QueryJobRepository(_QueryRepository): + def get_by_parent_uids( + self, parent_job_uids: list[str], job_type: str | None = None + ) -> list[dict[str, Any]]: + if not parent_job_uids: + return [] + + placeholders = ", ".join("?" 
for _ in parent_job_uids) + sql = f""" + SELECT * + FROM jobs + WHERE parent_job_uid IN ({placeholders}) + """ + params: list[Any] = [*parent_job_uids] + if job_type is not None: + sql += " AND job_type = ?" + params.append(job_type) + sql += " ORDER BY timestamp ASC" + + rows = self._fetchall(sql, tuple(params)) + return [_job_row_to_dict(row) for row in rows] + + def get_by_session(self, session_id: int, limit: int = 50) -> list[dict[str, Any]]: + rows = self._fetchall( + """ + SELECT * + FROM jobs + WHERE session_id = ? + ORDER BY timestamp DESC + LIMIT ? + """, + (session_id, limit), + ) + return [_job_row_to_dict(row) for row in rows] + + def get_latest_build_jobs(self, session_id: int) -> list[dict[str, Any]]: + rows = self._fetchall( + """ + SELECT j.* + FROM jobs AS j + JOIN ( + SELECT step_number, MAX(id) AS max_id + FROM jobs + WHERE session_id = ? AND job_type = 'build' + GROUP BY step_number + ) AS latest_build_ids ON j.id = latest_build_ids.max_id + ORDER BY j.step_number ASC + """, + (session_id,), + ) + return [_job_row_to_dict(row) for row in rows] + + def get_by_uid(self, job_uid: str) -> dict[str, Any] | None: + row = self._fetchone("SELECT * FROM jobs WHERE job_uid = ? LIMIT 1", (job_uid,)) + if row is not None: + return _job_row_to_dict(row) + + if len(job_uid) < 4: + return None + + rows = self._fetchall( + "SELECT * FROM jobs WHERE job_uid LIKE ? 
LIMIT 2", + (f"{job_uid}%",), + ) + if len(rows) != 1: + return None + return _job_row_to_dict(rows[0]) + + def get_inputs(self, job_id: int) -> list[dict[str, Any]]: + return self._get_job_artifacts("job_inputs", job_id) + + def get_outputs(self, job_id: int) -> list[dict[str, Any]]: + return self._get_job_artifacts("job_outputs", job_id) + + def get_distinct_outputs_by_session(self, session_id: int) -> list[dict[str, Any]]: + rows = self._fetchall( + """ + SELECT j.id AS job_id, + j.timestamp AS job_timestamp, + jo.path, + jo.artifact_id, + a.size, + a.first_seen_path, + a.kind, + a.component_count, + ah.algorithm, + ah.digest + FROM jobs AS j + JOIN job_outputs AS jo ON jo.job_id = j.id + JOIN artifacts AS a ON a.id = jo.artifact_id + LEFT JOIN artifact_hashes AS ah ON ah.artifact_id = a.id + WHERE j.session_id = ? + ORDER BY j.timestamp DESC, j.id DESC, jo.path ASC, ah.algorithm ASC, ah.digest ASC + """, + (session_id,), + ) + + results: list[dict[str, Any]] = [] + seen_artifact_ids: set[str] = set() + current_by_artifact: dict[str, dict[str, Any]] = {} + + for row in rows: + artifact_id = str(row["artifact_id"]) + if artifact_id not in current_by_artifact and artifact_id not in seen_artifact_ids: + artifact = { + "artifact_id": artifact_id, + "artifact_hash": None, + "size": int(row["size"] or 0), + "path": str(row["path"] or row["first_seen_path"] or ""), + "kind": row["kind"], + "component_count": row["component_count"], + "hashes": [], + } + current_by_artifact[artifact_id] = artifact + results.append(artifact) + seen_artifact_ids.add(artifact_id) + + artifact_entry = current_by_artifact.get(artifact_id) + if artifact_entry is None: + continue + + algorithm = row["algorithm"] + digest = row["digest"] + if algorithm and digest: + artifact_entry["hashes"].append( + { + "algorithm": str(algorithm), + "digest": str(digest), + } + ) + if artifact_entry["artifact_hash"] is None: + artifact_entry["artifact_hash"] = str(digest) + + return results + + def 
_get_job_artifacts(self, table_name: str, job_id: int) -> list[dict[str, Any]]: + rows = self._fetchall( + f""" + SELECT io.path, + io.artifact_id, + io.byte_ranges, + a.size, + a.first_seen_path, + a.kind, + a.component_count, + ah.algorithm, + ah.digest + FROM {table_name} AS io + JOIN artifacts AS a ON a.id = io.artifact_id + LEFT JOIN artifact_hashes AS ah ON ah.artifact_id = a.id + WHERE io.job_id = ? + ORDER BY io.path ASC, io.artifact_id ASC, ah.algorithm ASC, ah.digest ASC + """, + (job_id,), + ) + + results: list[dict[str, Any]] = [] + current_key: tuple[Any, ...] | None = None + current: dict[str, Any] | None = None + + for row in rows: + key = ( + row["path"], + row["artifact_id"], + row["byte_ranges"], + row["size"], + row["first_seen_path"], + row["kind"], + row["component_count"], + ) + if key != current_key: + current = { + "path": str(row["path"] or row["first_seen_path"] or ""), + "artifact_id": str(row["artifact_id"]), + "size": int(row["size"] or 0), + "kind": row["kind"], + "component_count": row["component_count"], + "hashes": [], + "artifact_hash": None, + "first_seen_path": row["first_seen_path"], + "byte_ranges": json.loads(row["byte_ranges"]) if row["byte_ranges"] else None, + } + results.append(current) + current_key = key + + if current is None: + continue + + algorithm = row["algorithm"] + digest = row["digest"] + if algorithm and digest: + current["hashes"].append( + { + "algorithm": str(algorithm), + "digest": str(digest), + } + ) + if current["artifact_hash"] is None: + current["artifact_hash"] = str(digest) + + return results + + +class _QueryArtifactRepository(_QueryRepository): + def get(self, artifact_id: str) -> dict[str, Any] | None: + row = self._fetchone( + """ + SELECT id, size, first_seen_at, first_seen_path, source_type, source_url, + uploaded_to, synced_at, kind, component_count, metadata + FROM artifacts + WHERE id = ? 
+ LIMIT 1 + """, + (artifact_id,), + ) + if row is None: + return None + + artifact = _artifact_row_to_dict(row) + artifact["hashes"] = self.get_hashes(artifact_id) + return artifact + + def get_hashes(self, artifact_id: str) -> list[dict[str, Any]]: + rows = self._fetchall( + """ + SELECT algorithm, digest + FROM artifact_hashes + WHERE artifact_id = ? + ORDER BY algorithm ASC, digest ASC + """, + (artifact_id,), + ) + return [ + { + "algorithm": str(row["algorithm"]), + "digest": str(row["digest"]), + } + for row in rows + ] + + def get_by_hash(self, digest: str, algorithm: str | None = None) -> dict[str, Any] | None: + normalized_digest = digest.lower() + if algorithm is None: + rows = self._fetchall( + """ + SELECT DISTINCT artifact_id + FROM artifact_hashes + WHERE digest LIKE ? + LIMIT 2 + """, + (f"{normalized_digest}%",), + ) + else: + rows = self._fetchall( + """ + SELECT DISTINCT artifact_id + FROM artifact_hashes + WHERE algorithm = ? AND digest LIKE ? + LIMIT 2 + """, + (algorithm, f"{normalized_digest}%"), + ) + + if len(rows) != 1: + return None + + artifact = self.get(str(rows[0]["artifact_id"])) + if artifact is None: + return None + hashes = artifact.get("hashes", []) + artifact["hash"] = hashes[0]["digest"] if hashes else None + return artifact + + def get_by_path(self, path: str) -> dict[str, Any] | None: + output = self._fetchone( + """ + SELECT jo.artifact_id + FROM job_outputs AS jo + JOIN jobs AS j ON j.id = jo.job_id + WHERE jo.path = ? + ORDER BY j.timestamp DESC + LIMIT 1 + """, + (path,), + ) + if output is not None: + return self.get(str(output["artifact_id"])) + + input_row = self._fetchone( + """ + SELECT ji.artifact_id + FROM job_inputs AS ji + JOIN jobs AS j ON j.id = ji.job_id + WHERE ji.path = ? 
+ ORDER BY j.timestamp DESC + LIMIT 1 + """, + (path,), + ) + if input_row is not None: + return self.get(str(input_row["artifact_id"])) + + artifact = self._fetchone( + """ + SELECT id, size, first_seen_at, first_seen_path, source_type, source_url, + uploaded_to, synced_at, kind, component_count, metadata + FROM artifacts + WHERE first_seen_path = ? + LIMIT 1 + """, + (path,), + ) + if artifact is None: + return None + + result = _artifact_row_to_dict(artifact) + result["hashes"] = self.get_hashes(str(artifact["id"])) + return result + + def get_locations(self, artifact_id: str) -> list[dict[str, str]]: + paths = { + str(row["path"]) + for row in self._fetchall( + "SELECT DISTINCT path FROM job_outputs WHERE artifact_id = ?", + (artifact_id,), + ) + if row["path"] + } + paths.update( + { + str(row["path"]) + for row in self._fetchall( + "SELECT DISTINCT path FROM job_inputs WHERE artifact_id = ?", + (artifact_id,), + ) + if row["path"] + } + ) + + artifact = self._fetchone( + "SELECT first_seen_path FROM artifacts WHERE id = ?", (artifact_id,) + ) + if artifact is not None and artifact["first_seen_path"]: + paths.add(str(artifact["first_seen_path"])) + + return [{"path": path} for path in sorted(paths)] + + def get_jobs(self, artifact_id: str) -> dict[str, list[dict[str, Any]]]: + produced_by = [ + _job_row_to_dict(row) + for row in self._fetchall( + """ + SELECT j.* + FROM jobs AS j + JOIN job_outputs AS jo ON j.id = jo.job_id + WHERE jo.artifact_id = ? + ORDER BY j.timestamp DESC + """, + (artifact_id,), + ) + ] + consumed_by = [ + _job_row_to_dict(row) + for row in self._fetchall( + """ + SELECT j.* + FROM jobs AS j + JOIN job_inputs AS ji ON j.id = ji.job_id + WHERE ji.artifact_id = ? 
+ ORDER BY j.timestamp DESC + """, + (artifact_id,), + ) + ] + return { + "produced_by": produced_by, + "consumed_by": consumed_by, + } + + +class _QueryLabelRepository(_QueryRepository): + def get_current( + self, + entity_type: str, + *, + session_id: int | None = None, + job_id: int | None = None, + artifact_id: str | None = None, + ) -> dict[str, Any] | None: + if session_id is not None: + clauses = "session_id = ? AND job_id IS NULL AND artifact_id IS NULL" + params: tuple[Any, ...] = (entity_type, session_id) + elif job_id is not None: + clauses = "session_id IS NULL AND job_id = ? AND artifact_id IS NULL" + params = (entity_type, job_id) + elif artifact_id is not None: + clauses = "session_id IS NULL AND job_id IS NULL AND artifact_id = ?" + params = (entity_type, artifact_id) + else: + return None + + row = self._fetchone( + f""" + SELECT id, entity_type, session_id, job_id, artifact_id, version, metadata, + created_at, synced_at, synced_server_label_id + FROM labels + WHERE entity_type = ? AND {clauses} + ORDER BY version DESC + LIMIT 1 + """, + params, + ) + if row is None: + return None + + return { + "id": int(row["id"]), + "entity_type": str(row["entity_type"]), + "session_id": row["session_id"], + "job_id": row["job_id"], + "artifact_id": row["artifact_id"], + "version": int(row["version"]), + "metadata": json.loads(row["metadata"]), + "created_at": row["created_at"], + "synced_at": row["synced_at"], + "synced_server_label_id": row["synced_server_label_id"], + } + + +class _QueryCompositeRepository(_QueryRepository): + def get(self, artifact_id: str) -> dict[str, Any] | None: + row = self._fetchone( + "SELECT id, kind, component_count FROM artifacts WHERE id = ? 
LIMIT 1", + (artifact_id,), + ) + if row is None or row["kind"] != "composite": + return None + + return { + "artifact_id": str(row["id"]), + "kind": str(row["kind"]), + "component_count": row["component_count"], + "membership_index": self.get_membership_index(artifact_id), + } + + def get_components(self, artifact_id: str, limit: int = 1000) -> list[dict[str, Any]]: + rows = self._fetchall( + """ + SELECT id, composite_artifact_id, ordinal, relative_path, leaf_kind, + component_algorithm, component_digest, component_size, + component_type, artifact_id + FROM composite_artifact_components + WHERE composite_artifact_id = ? + ORDER BY ordinal ASC, id ASC + LIMIT ? + """, + (artifact_id, limit), + ) + return [ + { + "id": int(row["id"]), + "composite_artifact_id": str(row["composite_artifact_id"]), + "ordinal": row["ordinal"], + "relative_path": row["relative_path"], + "leaf_kind": row["leaf_kind"], + "component_algorithm": row["component_algorithm"], + "component_digest": row["component_digest"], + "component_size": row["component_size"], + "component_type": row["component_type"], + "artifact_id": row["artifact_id"], + } + for row in rows + ] + + def get_membership_index(self, artifact_id: str) -> dict[str, Any] | None: + row = self._fetchone( + """ + SELECT total_components, stored_components, bloom_filter_base64, bloom_bits, + bloom_hashes, bloom_version + FROM composite_membership_indexes + WHERE composite_artifact_id = ? 
+ LIMIT 1 + """, + (artifact_id,), + ) + if row is None: + return None + + return { + "total_components": int(row["total_components"]), + "stored_components": int(row["stored_components"]), + "bloom_filter_base64": row["bloom_filter_base64"], + "bloom_bits": row["bloom_bits"], + "bloom_hashes": row["bloom_hashes"], + "bloom_version": row["bloom_version"], + } + + +def _session_row_to_dict(row: Any) -> dict[str, Any]: + return { + "id": int(row["id"]), + "hash": row["hash"], + "created_at": row["created_at"], + "source_artifact_hash": row["source_artifact_hash"], + "current_step": row["current_step"], + "is_active": row["is_active"], + "git_repo": row["git_repo"], + "git_commit_start": row["git_commit_start"], + "git_commit_end": row["git_commit_end"], + "synced_at": row["synced_at"], + "metadata": row["metadata"], + } + + +def _job_row_to_dict(row: Any) -> dict[str, Any]: + return { + "id": int(row["id"]), + "job_uid": row["job_uid"], + "parent_job_uid": row["parent_job_uid"], + "timestamp": row["timestamp"], + "command": row["command"], + "script": row["script"], + "step_identity": row["step_identity"], + "session_id": row["session_id"], + "step_number": row["step_number"], + "step_name": row["step_name"], + "git_repo": row["git_repo"], + "git_commit": row["git_commit"], + "git_branch": row["git_branch"], + "duration_seconds": row["duration_seconds"], + "exit_code": row["exit_code"], + "synced_at": row["synced_at"], + "status": row["status"], + "execution_backend": row["execution_backend"], + "execution_role": row["execution_role"], + "job_type": row["job_type"], + "metadata": row["metadata"], + "telemetry": row["telemetry"], + } + + +def _artifact_row_to_dict(row: Any) -> dict[str, Any]: + return { + "id": str(row["id"]), + "size": int(row["size"] or 0), + "first_seen_at": row["first_seen_at"], + "first_seen_path": row["first_seen_path"], + "source_type": row["source_type"], + "source_url": row["source_url"], + "uploaded_to": row["uploaded_to"], + "synced_at": 
row["synced_at"], + "kind": row["kind"], + "component_count": row["component_count"], + "metadata": row["metadata"], + } + + +def create_query_database_context(roar_dir: Path) -> QueryDatabaseContext: + """Create a lightweight query context for the given .roar directory.""" + return QueryDatabaseContext(roar_dir / "roar.db") diff --git a/roar/db/repositories/session.py b/roar/db/repositories/session.py index c005634d..fb9bb689 100644 --- a/roar/db/repositories/session.py +++ b/roar/db/repositories/session.py @@ -15,12 +15,8 @@ from sqlalchemy.orm import Session as SASession from ...core.interfaces.repositories import SessionRepository -from ...execution.framework.registry import ( - is_execution_noise_job, - is_execution_phase_job, - is_execution_task_job, -) from ..models import Job, Session +from ..step_priority import step_sort_key class SQLAlchemySessionRepository(SessionRepository): @@ -327,7 +323,7 @@ def get_step_by_number( jobs = self._session.execute(query).scalars().all() if not jobs: return None - return self._job_to_dict(max(jobs, key=self._step_sort_key)) + return self._job_to_dict(max(jobs, key=step_sort_key)) def get_step_by_name(self, session_id: int, step_name: str) -> dict[str, Any] | None: """ @@ -648,37 +644,6 @@ def _job_to_dict(self, job: Job) -> dict[str, Any]: "telemetry": job.telemetry, } - @staticmethod - def _step_sort_key(job: Job) -> tuple[int, float, int]: - job_dict = { - "id": job.id, - "timestamp": job.timestamp, - "command": job.command, - "job_type": job.job_type, - "script": job.script, - "parent_job_uid": job.parent_job_uid, - "execution_backend": job.execution_backend, - "execution_role": job.execution_role, - } - if job.job_type in (None, "run") and not is_execution_task_job(job_dict): - if is_execution_noise_job(job_dict): - priority = 1 - else: - priority = 6 - elif is_execution_phase_job(job_dict): - priority = 5 - elif is_execution_task_job(job_dict): - priority = 4 - elif is_execution_noise_job(job_dict): - priority = 1 - 
else: - priority = 2 - return ( - priority, - float(job.timestamp or 0.0), - int(job.id or 0), - ) - # Backward compatibility alias SQLiteSessionRepository = SQLAlchemySessionRepository diff --git a/roar/db/step_priority.py b/roar/db/step_priority.py new file mode 100644 index 00000000..a376008c --- /dev/null +++ b/roar/db/step_priority.py @@ -0,0 +1,91 @@ +"""Local step-priority helpers for session step selection. + +These helpers intentionally avoid importing execution backend discovery. +The hot read/query path only needs stable ordering for known local records, +not dynamic backend registration. +""" + +from __future__ import annotations + +from typing import Any + +_NOISE_COMMANDS = { + "ray_task:unknown", + "ray_task:__init__", + "ray_task:shutdown", + "ray_task:s3_proxy", + "ray_task:s3_driver_proxy", + "ray_task:RoarNodeAgent.__init__", +} + + +def _get_job_value(job: Any, key: str) -> Any: + if job is None: + return None + if isinstance(job, dict): + return job.get(key) + return getattr(job, key, None) + + +def step_sort_key(job: Any) -> tuple[int, float, int]: + """Return a stable ordering key for choosing the visible step record.""" + return ( + step_priority(job), + float(_get_job_value(job, "timestamp") or 0.0), + int(_get_job_value(job, "id") or 0), + ) + + +def is_noise_job(job: Any) -> bool: + """Return True when a job is local execution noise.""" + execution_role = str(_get_job_value(job, "execution_role") or "").strip().lower() + if execution_role == "noise": + return True + + command = str(_get_job_value(job, "command") or "") + return command in _NOISE_COMMANDS + + +def is_task_job(job: Any) -> bool: + """Return True when a job is a worker/task record.""" + execution_role = str(_get_job_value(job, "execution_role") or "").strip().lower() + if execution_role == "task": + return True + if is_noise_job(job): + return False + + command = str(_get_job_value(job, "command") or "") + if command.startswith("ray_task:"): + return True + + job_type = 
str(_get_job_value(job, "job_type") or "").strip().lower() + return job_type == "ray_task" + + +def is_host_or_submit_job(job: Any) -> bool: + """Return True when a job should be treated as the local host/submit record.""" + execution_role = str(_get_job_value(job, "execution_role") or "").strip().lower() + if execution_role in {"host", "submit"}: + return True + + command = str(_get_job_value(job, "command") or "") + if command.startswith("ray job submit"): + return True + + job_type = str(_get_job_value(job, "job_type") or "").strip().lower() + return job_type in {"", "run"} + + +def step_priority(job: Any) -> int: + """Classify a job into the same priority bands used for step resolution.""" + execution_role = str(_get_job_value(job, "execution_role") or "").strip().lower() + if is_noise_job(job): + return 1 + if execution_role == "phase": + return 5 + if is_task_job(job): + return 4 + if is_host_or_submit_job(job): + return 6 + + return 2 diff --git a/roar/execution/provenance/data_loader.py b/roar/execution/provenance/data_loader.py index de701702..5fea3e24 100644 --- a/roar/execution/provenance/data_loader.py +++ b/roar/execution/provenance/data_loader.py @@ -45,16 +45,9 @@ def load_tracer_data(self, path: str) -> TracerData: """ self.logger.debug("Loading tracer data from: %s", path) with open(path, "rb") as f: - raw_data = f.read() + payload = f.read() - try: - data = msgpack.unpackb(raw_data, raw=False) - except msgpack.ExtraData: - stripped = raw_data.lstrip() - if not stripped.startswith((b"{", b"[")): - raise - self.logger.debug("Tracer report was JSON-encoded; falling back to JSON parsing") - data = json.loads(raw_data.decode("utf-8")) + data = self._parse_tracer_payload(payload, path) self.logger.debug("Tracer data parsed successfully: %d keys", len(data)) files = self._normalize_files(data) @@ -73,6 +66,23 @@ def load_tracer_data(self, path: str) -> TracerData: events_dropped=int(data.get("events_dropped", 0) or 0), ) + def _parse_tracer_payload(self, 
payload: bytes, path: str) -> dict: + """Parse the tracer report from MessagePack or a legacy JSON payload.""" + stripped = payload.lstrip() + if stripped.startswith((b"{", b"[")): + self.logger.warning( + "Tracer report at %s is JSON, not MessagePack; accepting legacy format", path + ) + data = json.loads(payload.decode("utf-8")) + if not isinstance(data, dict): + raise ValueError(f"Expected tracer report object in {path}") + return data + + data = msgpack.unpackb(payload, raw=False) + if not isinstance(data, dict): + raise ValueError(f"Expected tracer report object in {path}") + return data + def _normalize_files(self, data: dict) -> list[dict]: """Normalize tracer file records to a common shape.""" normalized: list[dict] = [] diff --git a/roar/integrations/config/__init__.py b/roar/integrations/config/__init__.py index 50cf615e..e72d1bab 100644 --- a/roar/integrations/config/__init__.py +++ b/roar/integrations/config/__init__.py @@ -1,70 +1,61 @@ -"""Canonical configuration integration surface.""" +"""Lazy exports for roar configuration helpers.""" -from .access import ( - CONFIGURABLE_KEYS, - CORE_CONFIGURABLE_KEYS, - VALID_HASH_ALGORITHMS, - _get_default_config, - config_get, - config_list, - config_set, - get_config_path_for_write, - get_configurable_keys, - get_hash_algorithms, - get_roar_dir, - load_config, - save_config, -) -from .loader import RoarSettings, find_config_file, find_roar_dir, load_settings -from .schema import ( - AnalyzersConfig, - CleanupConfig, - CompositesConfig, - ConfigBaseModel, - FiltersConfig, - GlaasConfig, - HashConfig, - LoggingConfig, - OutputConfig, - ProxyConfig, - RegisterConfig, - ReversibleConfig, - RoarConfig, - RunCompositeConfig, - TracerConfig, -) +from __future__ import annotations -__all__ = [ - "CONFIGURABLE_KEYS", - "CORE_CONFIGURABLE_KEYS", - "VALID_HASH_ALGORITHMS", - "AnalyzersConfig", - "CleanupConfig", - "CompositesConfig", - "ConfigBaseModel", - "FiltersConfig", - "GlaasConfig", - "HashConfig", - 
"LoggingConfig", - "OutputConfig", - "ProxyConfig", - "RegisterConfig", - "ReversibleConfig", - "RoarConfig", - "RoarSettings", - "RunCompositeConfig", - "TracerConfig", - "_get_default_config", - "config_get", - "config_list", - "config_set", - "find_config_file", - "find_roar_dir", - "get_config_path_for_write", - "get_configurable_keys", - "get_hash_algorithms", - "get_roar_dir", - "load_config", - "load_settings", - "save_config", -] +from importlib import import_module +from typing import Any + +_EXPORTS = { + "AnalyzersConfig": ".schema", + "CleanupConfig": ".schema", + "CONFIGURABLE_KEYS": ".access", + "CORE_CONFIGURABLE_KEYS": ".access", + "CompositesConfig": ".schema", + "ConfigBaseModel": ".schema", + "FiltersConfig": ".schema", + "GlaasConfig": ".schema", + "HashConfig": ".schema", + "LoggingConfig": ".schema", + "OutputConfig": ".schema", + "ProxyConfig": ".schema", + "RegisterConfig": ".schema", + "ReversibleConfig": ".schema", + "RoarConfig": ".schema", + "RoarSettings": ".loader", + "RunCompositeConfig": ".schema", + "TracerConfig": ".schema", + "VALID_HASH_ALGORITHMS": ".access", + "_get_default_config": ".access", + "config_get": ".access", + "config_list": ".access", + "config_set": ".access", + "find_config_file": ".loader", + "find_raw_config_file": ".raw", + "find_roar_dir": ".loader", + "get_config_path_for_write": ".access", + "get_configurable_keys": ".access", + "get_hash_algorithms": ".access", + "get_raw_glaas_web_url": ".raw", + "get_raw_registration_omit_config": ".raw", + "get_roar_dir": ".access", + "load_config": ".access", + "load_settings": ".loader", + "save_config": ".access", +} + +__all__ = sorted(_EXPORTS) + + +def __getattr__(name: str) -> Any: + module_name = _EXPORTS.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module = import_module(module_name, __name__) + value = getattr(module, name) + globals()[name] = value + return value + + +def __dir__() -> 
list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/roar/integrations/config/access.py b/roar/integrations/config/access.py index a65e09ad..0ea6ab2e 100644 --- a/roar/integrations/config/access.py +++ b/roar/integrations/config/access.py @@ -251,6 +251,29 @@ def load_config(config_path: Path | None = None, start_dir: str | None = None) - return settings.to_dict() +_MISSING = object() + + +def _coerce_settings_value(value: Any) -> Any: + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + return model_dump() + return value + + +def _get_nested_from_settings(settings: Any, key: str, default: Any = _MISSING) -> Any: + current: Any = settings + for part in key.split("."): + if hasattr(current, part): + current = getattr(current, part) + continue + if isinstance(current, dict) and part in current: + current = current[part] + continue + return default + return _coerce_settings_value(current) + + def get_roar_dir(start_dir: str | None = None) -> Path: """ Get the .roar directory path, creating it if needed. 
@@ -388,7 +411,12 @@ def save_config(config: dict, config_path: Path): def config_get(key: str, start_dir: str | None = None): """Get a config value.""" - config = load_config(start_dir=start_dir) + settings = load_settings(start_dir=start_dir) + value = _get_nested_from_settings(settings, key, _MISSING) + if value is not _MISSING: + return value + + config = settings.to_dict() return _get_nested(config, key) diff --git a/roar/integrations/config/loader.py b/roar/integrations/config/loader.py index e3c52292..bf713e44 100644 --- a/roar/integrations/config/loader.py +++ b/roar/integrations/config/loader.py @@ -207,6 +207,8 @@ class RoarSettings(BaseSettings): composites: CompositesConfig = CompositesConfig() env: dict[str, str] = {} _backend_configs: dict[str, dict[str, Any]] = PrivateAttr(default_factory=dict) + _backend_config_source: dict[str, Any] = PrivateAttr(default_factory=dict) + _backend_configs_loaded: bool = PrivateAttr(default=False) # Internal fields (not from config) _config_file: str | None = None @@ -285,6 +287,7 @@ def settings_customise_sources( def to_dict(self) -> dict[str, Any]: """Convert settings to dict format compatible with legacy config.""" + self._ensure_backend_configs() result: dict[str, Any] = { "output": self.output.model_dump(), "analyzers": self.analyzers.model_dump(), @@ -307,6 +310,13 @@ def to_dict(self) -> dict[str, Any]: result["_config_error"] = self._config_error return result + def _ensure_backend_configs(self) -> None: + if self._backend_configs_loaded: + return + + self._backend_configs = _resolve_backend_config_sections(self._backend_config_source) + self._backend_configs_loaded = True + # Module-level variables for passing to settings_customise_sources _current_config_path: Path | None = None @@ -339,7 +349,10 @@ def load_settings(config_path: Path | None = None, start_dir: str | None = None) settings._config_file = toml_data["_config_file"] if "_config_error" in toml_data: settings._config_error = 
toml_data["_config_error"] - settings._backend_configs = _resolve_backend_config_sections(toml_data) + settings._backend_config_source = { + key: value for key, value in toml_data.items() if not str(key).startswith("_") + } + settings._backend_configs_loaded = False return settings finally: diff --git a/roar/integrations/config/raw.py b/roar/integrations/config/raw.py new file mode 100644 index 00000000..92df75ae --- /dev/null +++ b/roar/integrations/config/raw.py @@ -0,0 +1,196 @@ +"""Lightweight config access for performance-sensitive preview paths.""" + +from __future__ import annotations + +import json +import os +import subprocess +from copy import deepcopy +from pathlib import Path +from typing import Any + +try: + import tomllib +except ImportError: + import tomli as tomllib + +_DEFAULT_GLAAS_WEB_URL = "https://glaas.ai" +_DEFAULT_REGISTRATION_OMIT = { + "enabled": True, + "secrets": {"values": []}, + "env_vars": { + "names": [ + "WANDB_API_KEY", + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "GITHUB_TOKEN", + "DATABASE_URL", + "AWS_SECRET_ACCESS_KEY", + ] + }, + "patterns": [], + "allowlist": {"patterns": []}, +} + + +def _infer_search_stop(start: Path) -> Path: + """Infer an upward-search boundary without importing the full settings stack.""" + try: + out = subprocess.check_output( + ["git", "rev-parse", "--show-toplevel"], + stderr=subprocess.DEVNULL, + cwd=start, + text=True, + ).strip() + if out: + return Path(out).resolve() + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + pass + return start.resolve() + + +def find_raw_config_file( + start_dir: str | os.PathLike[str] | None = None, + stop_dir: str | os.PathLike[str] | None = None, +) -> Path | None: + """Find a local roar config file without loading Pydantic settings.""" + start = Path(start_dir) if start_dir else Path.cwd() + stop_path = Path(stop_dir).resolve() if stop_dir else _infer_search_stop(start) + + for parent in [start, *list(start.parents)]: + config_path = parent / 
".roar" / "config.toml" + if config_path.exists(): + return config_path + + pyproject = parent / "pyproject.toml" + if pyproject.exists(): + try: + with open(pyproject, "rb") as handle: + data = tomllib.load(handle) + if "tool" in data and "roar" in data["tool"]: + return pyproject + except (tomllib.TOMLDecodeError, OSError): + pass + + if parent.resolve() == stop_path: + break + + return None + + +def get_raw_glaas_web_url( + start_dir: str | os.PathLike[str] | None = None, +) -> str | None: + """Resolve the GLaaS web URL for local preview output.""" + env_value = _normalize_url(os.environ.get("ROAR_GLAAS__WEB_URL")) + if env_value is not None: + return env_value + + raw_config = _load_raw_config(start_dir=start_dir) + glaas_config = raw_config.get("glaas") + if isinstance(glaas_config, dict): + return _normalize_url(glaas_config.get("web_url")) + return None + + +def get_raw_registration_omit_config( + start_dir: str | os.PathLike[str] | None = None, +) -> dict[str, Any]: + """Resolve preview-only omit config without importing the full config package.""" + resolved = deepcopy(_DEFAULT_REGISTRATION_OMIT) + raw_config = _load_raw_config(start_dir=start_dir) + registration = raw_config.get("registration") + if isinstance(registration, dict): + omit_config = registration.get("omit") + if isinstance(omit_config, dict): + _deep_update(resolved, omit_config) + + _apply_registration_omit_env_overrides(resolved) + return resolved + + +def _load_raw_config( + *, + start_dir: str | os.PathLike[str] | None = None, +) -> dict[str, Any]: + path = find_raw_config_file(start_dir=start_dir) + if path is None: + return {} + + try: + with open(path, "rb") as handle: + data = tomllib.load(handle) + except (tomllib.TOMLDecodeError, OSError): + return {} + + if path.name == "pyproject.toml": + tool_config = data.get("tool", {}) + if isinstance(tool_config, dict): + roar_config = tool_config.get("roar", {}) + return roar_config if isinstance(roar_config, dict) else {} + return {} + + 
return data if isinstance(data, dict) else {} + + +def _deep_update(base: dict[str, Any], override: dict[str, Any]) -> None: + for key, value in override.items(): + if isinstance(base.get(key), dict) and isinstance(value, dict): + _deep_update(base[key], value) + continue + base[key] = value + + +def _apply_registration_omit_env_overrides(config: dict[str, Any]) -> None: + enabled = os.environ.get("ROAR_REGISTRATION__OMIT__ENABLED") + if enabled is not None: + config["enabled"] = _parse_bool(enabled) + + secrets_values = os.environ.get("ROAR_REGISTRATION__OMIT__SECRETS__VALUES") + if secrets_values is not None: + config.setdefault("secrets", {})["values"] = _parse_list(secrets_values) + + env_var_names = os.environ.get("ROAR_REGISTRATION__OMIT__ENV_VARS__NAMES") + if env_var_names is not None: + config.setdefault("env_vars", {})["names"] = _parse_list(env_var_names) + + allowlist_patterns = os.environ.get("ROAR_REGISTRATION__OMIT__ALLOWLIST__PATTERNS") + if allowlist_patterns is not None: + config.setdefault("allowlist", {})["patterns"] = _parse_list(allowlist_patterns) + + custom_patterns = os.environ.get("ROAR_REGISTRATION__OMIT__PATTERNS") + if custom_patterns is not None: + parsed = _parse_json(custom_patterns) + config["patterns"] = parsed if isinstance(parsed, list) else [] + + +def _parse_bool(value: str) -> bool: + normalized = value.strip().lower() + return normalized in {"1", "true", "yes", "on"} + + +def _parse_list(value: str) -> list[str]: + parsed = _parse_json(value) + if isinstance(parsed, list): + return [str(item) for item in parsed] + + stripped = value.strip() + if not stripped: + return [] + return [item.strip() for item in stripped.split(",") if item.strip()] + + +def _parse_json(value: str) -> Any: + try: + return json.loads(value) + except json.JSONDecodeError: + return None + + +def _normalize_url(value: Any) -> str | None: + if not isinstance(value, str): + return None + stripped = value.strip() + if not stripped: + return None + return 
stripped.rstrip("/") diff --git a/roar/integrations/glaas/__init__.py b/roar/integrations/glaas/__init__.py index dca68ad7..b332e4aa 100644 --- a/roar/integrations/glaas/__init__.py +++ b/roar/integrations/glaas/__init__.py @@ -1,38 +1,41 @@ -"""GLaaS API support modules.""" - -from .auth import ( - compute_pubkey_fingerprint, - create_signature_payload, - find_ssh_private_key, - find_ssh_pubkey, - get_glaas_url, - make_auth_header, - sign_payload, -) -from .client import GlaasClient -from .fragment_streamer import GlaasFragmentStreamer -from .registration import ( - ArtifactRegistrationService, - JobRegistrationService, - RegistrationCoordinator, - SessionRegistrationService, -) -from .transport import parse_json_response, request_json - -__all__ = [ - "ArtifactRegistrationService", - "GlaasClient", - "GlaasFragmentStreamer", - "JobRegistrationService", - "RegistrationCoordinator", - "SessionRegistrationService", - "compute_pubkey_fingerprint", - "create_signature_payload", - "find_ssh_private_key", - "find_ssh_pubkey", - "get_glaas_url", - "make_auth_header", - "parse_json_response", - "request_json", - "sign_payload", -] +"""Lazy exports for GLaaS integration helpers.""" + +from __future__ import annotations + +from importlib import import_module +from typing import Any + +_EXPORTS = { + "ArtifactRegistrationService": ".registration", + "GlaasClient": ".client", + "GlaasFragmentStreamer": ".fragment_streamer", + "JobRegistrationService": ".registration", + "RegistrationCoordinator": ".registration", + "SessionRegistrationService": ".registration", + "compute_pubkey_fingerprint": ".auth", + "create_signature_payload": ".auth", + "find_ssh_private_key": ".auth", + "find_ssh_pubkey": ".auth", + "get_glaas_url": ".auth", + "make_auth_header": ".auth", + "parse_json_response": ".transport", + "request_json": ".transport", + "sign_payload": ".auth", +} + +__all__ = sorted(_EXPORTS) + + +def __getattr__(name: str) -> Any: + module_name = _EXPORTS.get(name) + if 
module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module = import_module(module_name, __name__) + value = getattr(module, name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/roar/integrations/glaas/client.py b/roar/integrations/glaas/client.py index 2866c763..f38543c7 100644 --- a/roar/integrations/glaas/client.py +++ b/roar/integrations/glaas/client.py @@ -63,9 +63,8 @@ class GlaasClient: """Client for interacting with GLaaS server.""" def __init__(self, base_url: str | None = None): - self.base_url = base_url or get_glaas_url() - if self.base_url: - self.base_url = self.base_url.rstrip("/") + resolved_base_url = get_glaas_url() if base_url is None else base_url + self.base_url = resolved_base_url.rstrip("/") if resolved_base_url else None def is_configured(self) -> bool: """Check if GLaaS is configured.""" diff --git a/roar/integrations/glaas/registration/__init__.py b/roar/integrations/glaas/registration/__init__.py index 4b1cd285..bcca60f3 100644 --- a/roar/integrations/glaas/registration/__init__.py +++ b/roar/integrations/glaas/registration/__init__.py @@ -1,13 +1,30 @@ -"""GLaaS registration protocol helpers.""" - -from .artifact import ArtifactRegistrationService -from .coordinator import RegistrationCoordinator -from .job import JobRegistrationService -from .session import SessionRegistrationService - -__all__ = [ - "ArtifactRegistrationService", - "JobRegistrationService", - "RegistrationCoordinator", - "SessionRegistrationService", -] +"""Lazy exports for GLaaS registration helpers.""" + +from __future__ import annotations + +from importlib import import_module +from typing import Any + +_EXPORTS = { + "ArtifactRegistrationService": ".artifact", + "JobRegistrationService": ".job", + "RegistrationCoordinator": ".coordinator", + "SessionRegistrationService": ".session", +} + +__all__ = sorted(_EXPORTS) + + +def 
__getattr__(name: str) -> Any: + module_name = _EXPORTS.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module = import_module(module_name, __name__) + value = getattr(module, name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/roar/presenters/__init__.py b/roar/presenters/__init__.py index 3cdb461f..f87ea687 100644 --- a/roar/presenters/__init__.py +++ b/roar/presenters/__init__.py @@ -1,14 +1,31 @@ -""" -Output presenters for roar CLI. +"""Lazy exports for presenter implementations.""" -Implements different output formats (console, JSON, etc.) -following the Strategy pattern. -""" +from __future__ import annotations -from .console import ConsolePresenter -from .dag_data_builder import DagDataBuilder -from .dag_renderer import DagRenderer -from .null import NullPresenter -from .show_renderer import ShowRenderer +from importlib import import_module +from typing import Any -__all__ = ["ConsolePresenter", "DagDataBuilder", "DagRenderer", "NullPresenter", "ShowRenderer"] +_EXPORTS = { + "ConsolePresenter": ".console", + "DagDataBuilder": ".dag_data_builder", + "DagRenderer": ".dag_renderer", + "NullPresenter": ".null", + "ShowRenderer": ".show_renderer", +} + +__all__ = sorted(_EXPORTS) + + +def __getattr__(name: str) -> Any: + module_name = _EXPORTS.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module = import_module(module_name, __name__) + value = getattr(module, name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/roar/presenters/run_report.py b/roar/presenters/run_report.py index d771499b..ba7c4dae 100644 --- a/roar/presenters/run_report.py +++ b/roar/presenters/run_report.py @@ -94,7 +94,15 @@ def show_report( if result.interrupted and result.outputs: 
self._out.print("") self._out.print("Note: Run was interrupted. Output files may be incomplete.") - self._out.print("Use 'roar clean' to remove written files if needed.") + self._out.print("Use 'roar pop' to remove this job and delete safe written files.") + + self._out.print("") + self._out.print("Next:") + self._out.print(f" roar show --job {result.job_uid}") + if result.interrupted and result.outputs: + self._out.print(" roar pop") + else: + self._out.print(" roar dag") def show_stale_warnings( self, diff --git a/roar/presenters/show_renderer.py b/roar/presenters/show_renderer.py index 3544c548..ed73bb3d 100644 --- a/roar/presenters/show_renderer.py +++ b/roar/presenters/show_renderer.py @@ -9,7 +9,7 @@ import json -from ..application.labels import render_label_lines +from ..application.label_rendering import render_label_lines from .formatting import format_duration, format_size, format_timestamp diff --git a/scripts/profile_roar.py b/scripts/profile_roar.py new file mode 100644 index 00000000..f65e1209 --- /dev/null +++ b/scripts/profile_roar.py @@ -0,0 +1,641 @@ +#!/usr/bin/env python3 +"""Reusable profiling harness for representative roar workflows.""" + +from __future__ import annotations + +import argparse +import json +import os +import pstats +import re +import shutil +import statistics +import subprocess +import sys +import tempfile +import time +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Literal + +REPO_ROOT = Path(__file__).resolve().parents[1] +RESULTS_DIR = REPO_ROOT / "tests" / "benchmarks" / "results" +PROFILE_ROOT = RESULTS_DIR / "profiles" +INJECT_DIR = REPO_ROOT / "roar" / "execution" / "runtime" / "inject" +DEFAULT_SCENARIOS = ( + "cli_help", + "cli_run_simple", + "cli_status_active", + "cli_show_session", + "cli_register_dry_run", + "cli_put_dry_run", + "startup_wrap", +) +IMPORTTIME_RE = re.compile( + r"^import 
time:\s+(?P<self_us>\d+)\s+\|\s+(?P<cum_us>\d+)\s+\|\s+(?P<module>.+?)\s*$" +) + + +@dataclass(frozen=True) +class Hotspot: + label: str + primitive_calls: int + total_calls: int + cumulative_ms: float + internal_ms: float + + +@dataclass(frozen=True) +class ImportHotspot: + module: str + self_ms: float + cumulative_ms: float + + +@dataclass(frozen=True) +class CommandProfileResult: + name: str + kind: Literal["cli"] + command: list[str] + iterations: int + mean_ms: float + median_ms: float + min_ms: float + max_ms: float + stdev_ms: float + profile_path: str + stdout_path: str + stderr_path: str + top_cumulative: list[Hotspot] + top_internal: list[Hotspot] + + +@dataclass(frozen=True) +class StartupProfileResult: + name: str + kind: Literal["startup"] + iterations: int + baseline_mean_ms: float + wrapped_mean_ms: float + wrapped_with_log_mean_ms: float + import_overhead_ms: float + atexit_overhead_ms: float + total_overhead_ms: float + importtime_stderr_path: str + top_imports: list[ImportHotspot] + + +def _now_iso_utc() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _git_commit() -> str: + proc = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=REPO_ROOT, + text=True, + capture_output=True, + check=False, + ) + if proc.returncode == 0: + return proc.stdout.strip() + return "unknown" + + +def _slug_timestamp() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + + +def _summary_stats(samples_ms: list[float]) -> dict[str, float]: + if not samples_ms: + raise ValueError("No timing samples were collected") + return { + "mean_ms": statistics.fmean(samples_ms), + "median_ms": statistics.median(samples_ms), + "min_ms": min(samples_ms), + "max_ms": max(samples_ms), + "stdev_ms": statistics.stdev(samples_ms) if len(samples_ms) > 1 else 0.0, + } + + +def _format_label(filename: str, line_no: int, func_name: str) -> str: + try: + path = Path(filename) + rel = path.relative_to(REPO_ROOT) + display = str(rel) + except ValueError: + display = filename + 
return f"{display}:{line_no}::{func_name}" + + +def _summarize_profile(profile_path: Path, limit: int) -> tuple[list[Hotspot], list[Hotspot]]: + stats = pstats.Stats(str(profile_path)) + entries: list[Hotspot] = [] + for func, stat in stats.stats.items(): + primitive_calls, total_calls, internal_s, cumulative_s, _callers = stat + filename, line_no, func_name = func + entries.append( + Hotspot( + label=_format_label(filename, line_no, func_name), + primitive_calls=int(primitive_calls), + total_calls=int(total_calls), + cumulative_ms=cumulative_s * 1000.0, + internal_ms=internal_s * 1000.0, + ) + ) + + top_cumulative = sorted( + entries, + key=lambda item: (item.cumulative_ms, item.internal_ms), + reverse=True, + )[:limit] + top_internal = sorted( + entries, + key=lambda item: (item.internal_ms, item.cumulative_ms), + reverse=True, + )[:limit] + return top_cumulative, top_internal + + +def _parse_importtime(stderr: str, limit: int) -> list[ImportHotspot]: + hotspots: list[ImportHotspot] = [] + for line in stderr.splitlines(): + match = IMPORTTIME_RE.match(line) + if not match: + continue + hotspots.append( + ImportHotspot( + module=match.group("module").strip(), + self_ms=int(match.group("self_us")) / 1000.0, + cumulative_ms=int(match.group("cum_us")) / 1000.0, + ) + ) + return sorted(hotspots, key=lambda item: item.cumulative_ms, reverse=True)[:limit] + + +def _repo_env(*, inject_dir: bool = False) -> dict[str, str]: + env = os.environ.copy() + pythonpath_entries: list[str] = [] + if inject_dir: + pythonpath_entries.append(str(INJECT_DIR)) + pythonpath_entries.append(str(REPO_ROOT)) + existing = env.get("PYTHONPATH") + if existing: + pythonpath_entries.append(existing) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_entries) + return env + + +def _run( + command: list[str], + *, + cwd: Path | None = None, + env: dict[str, str] | None = None, + timeout: int = 180, + capture_output: bool = True, +) -> subprocess.CompletedProcess[str]: + return subprocess.run( + 
command, + cwd=str(cwd) if cwd is not None else None, + env=env, + text=True, + capture_output=capture_output, + check=False, + timeout=timeout, + ) + + +def _run_checked( + command: list[str], + *, + cwd: Path | None = None, + env: dict[str, str] | None = None, + timeout: int = 180, +) -> subprocess.CompletedProcess[str]: + proc = _run(command, cwd=cwd, env=env, timeout=timeout, capture_output=True) + if proc.returncode != 0: + raise RuntimeError( + f"Command failed ({proc.returncode}): {' '.join(command)}\n" + f"stdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}" + ) + return proc + + +def _python_module_command(*args: str) -> list[str]: + return [sys.executable, "-m", "roar", *args] + + +def _git(repo: Path, *args: str) -> None: + _run_checked(["git", *args], cwd=repo, env=_repo_env()) + + +def _write_workspace_files(repo: Path) -> None: + (repo / "input.txt").write_text("hello roar\n", encoding="utf-8") + (repo / "transform.py").write_text( + """ +from pathlib import Path +import sys + +source = Path(sys.argv[1]) +destination = Path(sys.argv[2]) +destination.write_text(source.read_text(encoding="utf-8").upper(), encoding="utf-8") +""".strip() + + "\n", + encoding="utf-8", + ) + + +def _create_initialized_workspace(base_dir: Path) -> Path: + repo = base_dir / "repo" + repo.mkdir() + _git(repo, "init", "-b", "main") + _git(repo, "config", "user.name", "Roar Profiler") + _git(repo, "config", "user.email", "profiler@example.com") + _write_workspace_files(repo) + _git(repo, "add", ".") + _git(repo, "commit", "-m", "initial fixture") + _run_checked(_python_module_command("init"), cwd=repo, env=_repo_env()) + _git(repo, "add", ".") + _git(repo, "commit", "-m", "initialize roar") + return repo + + +def _seed_workspace(repo: Path) -> None: + _run_checked( + _python_module_command( + "run", + sys.executable, + "transform.py", + "input.txt", + "output.txt", + ), + cwd=repo, + env=_repo_env(), + timeout=240, + ) + + +def _prepare_workspace( + mode: Literal["none", 
"initialized", "seeded"], +) -> tuple[tempfile.TemporaryDirectory[str], Path | None]: + temp_dir = tempfile.TemporaryDirectory(prefix="roar-profile-") + repo: Path | None = None + if mode != "none": + repo = _create_initialized_workspace(Path(temp_dir.name)) + if mode == "seeded": + _seed_workspace(repo) + return temp_dir, repo + + +def _profile_command_scenario( + *, + name: str, + args: list[str], + workspace_mode: Literal["none", "initialized", "seeded"], + iterations: int, + profile_dir: Path, + top: int, +) -> CommandProfileResult: + samples_ms: list[float] = [] + profile_path = profile_dir / f"{name}.prof" + stdout_path = profile_dir / f"{name}.stdout.txt" + stderr_path = profile_dir / f"{name}.stderr.txt" + + for iteration in range(iterations): + temp_dir, repo = _prepare_workspace(workspace_mode) + try: + cwd = repo if repo is not None else REPO_ROOT + env = _repo_env() + command = [sys.executable, "-m", "roar", *args] + started = time.perf_counter() + proc = _run(command, cwd=cwd, env=env, timeout=240, capture_output=True) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + if proc.returncode != 0: + raise RuntimeError( + f"Scenario {name} failed on iteration {iteration + 1}: {proc.returncode}\n" + f"stdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}" + ) + samples_ms.append(elapsed_ms) + finally: + temp_dir.cleanup() + + temp_dir, repo = _prepare_workspace(workspace_mode) + try: + cwd = repo if repo is not None else REPO_ROOT + env = _repo_env() + profile_command = [ + sys.executable, + "-m", + "cProfile", + "-o", + str(profile_path), + "-m", + "roar", + *args, + ] + profiled = _run(profile_command, cwd=cwd, env=env, timeout=240, capture_output=True) + stdout_path.write_text(profiled.stdout, encoding="utf-8") + stderr_path.write_text(profiled.stderr, encoding="utf-8") + if profiled.returncode != 0: + raise RuntimeError( + f"Profiled scenario {name} failed: {profiled.returncode}\n" + f"stdout:\n{profiled.stdout}\n\nstderr:\n{profiled.stderr}" + ) + 
finally: + temp_dir.cleanup() + + top_cumulative, top_internal = _summarize_profile(profile_path, top) + summary = _summary_stats(samples_ms) + return CommandProfileResult( + name=name, + kind="cli", + command=_python_module_command(*args), + iterations=iterations, + profile_path=str(profile_path), + stdout_path=str(stdout_path), + stderr_path=str(stderr_path), + top_cumulative=top_cumulative, + top_internal=top_internal, + **summary, + ) + + +def _measure_startup_command(env: dict[str, str], iterations: int) -> list[float]: + samples_ms: list[float] = [] + for _ in range(iterations): + started = time.perf_counter() + proc = _run( + [sys.executable, "-c", "pass"], + cwd=REPO_ROOT, + env=env, + timeout=60, + capture_output=False, + ) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + if proc.returncode != 0: + raise RuntimeError("Startup probe failed") + samples_ms.append(elapsed_ms) + return samples_ms + + +def _profile_startup_scenario( + *, + iterations: int, + profile_dir: Path, + top: int, +) -> StartupProfileResult: + baseline_env = _repo_env() + baseline_env["ROAR_WRAP"] = "0" + baseline_env.pop("ROAR_LOG_FILE", None) + + wrapped_env = _repo_env(inject_dir=True) + wrapped_env["ROAR_WRAP"] = "1" + wrapped_env.pop("ROAR_LOG_FILE", None) + + wrapped_with_log_env = _repo_env(inject_dir=True) + wrapped_with_log_env["ROAR_WRAP"] = "1" + wrapped_with_log_env["ROAR_LOG_FILE"] = str(profile_dir / "startup-log.json") + + baseline_samples = _measure_startup_command(baseline_env, iterations) + wrapped_samples = _measure_startup_command(wrapped_env, iterations) + wrapped_with_log_samples = _measure_startup_command(wrapped_with_log_env, iterations) + + importtime_path = profile_dir / "startup_wrap.importtime.txt" + importtime_proc = _run( + [sys.executable, "-X", "importtime", "-c", "pass"], + cwd=REPO_ROOT, + env=wrapped_env, + timeout=60, + capture_output=True, + ) + importtime_path.write_text(importtime_proc.stderr, encoding="utf-8") + if 
importtime_proc.returncode != 0: + raise RuntimeError( + f"Importtime startup probe failed:\nstdout:\n{importtime_proc.stdout}\n\nstderr:\n{importtime_proc.stderr}" + ) + + baseline_mean = statistics.fmean(baseline_samples) + wrapped_mean = statistics.fmean(wrapped_samples) + wrapped_with_log_mean = statistics.fmean(wrapped_with_log_samples) + top_imports = _parse_importtime(importtime_proc.stderr, top) + + return StartupProfileResult( + name="startup_wrap", + kind="startup", + iterations=iterations, + baseline_mean_ms=baseline_mean, + wrapped_mean_ms=wrapped_mean, + wrapped_with_log_mean_ms=wrapped_with_log_mean, + import_overhead_ms=wrapped_mean - baseline_mean, + atexit_overhead_ms=wrapped_with_log_mean - wrapped_mean, + total_overhead_ms=wrapped_with_log_mean - baseline_mean, + importtime_stderr_path=str(importtime_path), + top_imports=top_imports, + ) + + +def _scenario_specs() -> dict[ + str, + tuple[ + Literal["cli", "startup"], list[str] | None, Literal["none", "initialized", "seeded"] | None + ], +]: + return { + "cli_help": ("cli", ["--help"], "none"), + "cli_run_simple": ( + "cli", + ["run", sys.executable, "transform.py", "input.txt", "output.txt"], + "initialized", + ), + "cli_status_active": ("cli", ["status"], "seeded"), + "cli_show_session": ("cli", ["show", "--session"], "seeded"), + "cli_register_dry_run": ("cli", ["register", "--dry-run", "@1"], "seeded"), + "cli_put_dry_run": ( + "cli", + ["put", "@1", "s3://benchmark-bucket/profiles", "-m", "profile run", "--dry-run"], + "seeded", + ), + "startup_wrap": ("startup", None, None), + } + + +def _markdown_report( + *, + metadata: dict[str, Any], + command_results: list[CommandProfileResult], + startup_result: StartupProfileResult | None, +) -> str: + lines = [ + "# Roar Profiling Report", + "", + f"- Timestamp: `{metadata['timestamp']}`", + f"- Git commit: `{metadata['git_commit']}`", + f"- Python: `{metadata['python_version']}`", + "", + "## CLI Scenarios", + "", + ] + + for result in 
command_results: + lines.extend( + [ + f"### `{result.name}`", + "", + f"- Mean wall time: `{result.mean_ms:.1f}ms`", + f"- Median wall time: `{result.median_ms:.1f}ms`", + f"- Std dev: `{result.stdev_ms:.1f}ms`", + f"- Command: `{subprocess.list2cmdline(result.command)}`", + f"- Profile: `{result.profile_path}`", + "", + "Top cumulative hotspots:", + ] + ) + for hotspot in result.top_cumulative[:5]: + lines.append( + f"- `{hotspot.label}`: cumulative `{hotspot.cumulative_ms:.1f}ms`, internal `{hotspot.internal_ms:.1f}ms`" + ) + lines.append("") + + if startup_result is not None: + lines.extend( + [ + "## Startup Scenario", + "", + f"- Baseline: `{startup_result.baseline_mean_ms:.1f}ms`", + f"- ROAR_WRAP=1: `{startup_result.wrapped_mean_ms:.1f}ms`", + f"- ROAR_WRAP=1 + LOG_FILE: `{startup_result.wrapped_with_log_mean_ms:.1f}ms`", + f"- Import overhead: `{startup_result.import_overhead_ms:.1f}ms`", + f"- Atexit overhead: `{startup_result.atexit_overhead_ms:.1f}ms`", + f"- Total overhead: `{startup_result.total_overhead_ms:.1f}ms`", + "", + "Top imports by cumulative time:", + ] + ) + for hotspot in startup_result.top_imports[:8]: + lines.append( + f"- `{hotspot.module}`: cumulative `{hotspot.cumulative_ms:.1f}ms`, self `{hotspot.self_ms:.1f}ms`" + ) + lines.append("") + + return "\n".join(lines) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Profile representative roar CLI and runtime scenarios." + ) + parser.add_argument( + "--scenario", + action="append", + choices=sorted(_scenario_specs()), + help="Run only the selected scenario(s). 
Can be passed multiple times.", + ) + parser.add_argument( + "--iterations", + type=int, + default=3, + help="Wall-time iterations per scenario (default: 3).", + ) + parser.add_argument( + "--top", + type=int, + default=12, + help="Number of hotspots to keep per scenario (default: 12).", + ) + parser.add_argument( + "--output", + type=Path, + default=RESULTS_DIR / "profile_suite_latest.json", + help="JSON summary output path.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + scenario_names = args.scenario or list(DEFAULT_SCENARIOS) + specs = _scenario_specs() + + run_dir = PROFILE_ROOT / _slug_timestamp() + run_dir.mkdir(parents=True, exist_ok=True) + + command_results: list[CommandProfileResult] = [] + startup_result: StartupProfileResult | None = None + + for name in scenario_names: + kind, cli_args, workspace_mode = specs[name] + print(f"[profile] running {name}...", flush=True) + if kind == "cli": + assert cli_args is not None + assert workspace_mode is not None + result = _profile_command_scenario( + name=name, + args=cli_args, + workspace_mode=workspace_mode, + iterations=args.iterations, + profile_dir=run_dir, + top=args.top, + ) + command_results.append(result) + print( + f" mean={result.mean_ms:.1f}ms median={result.median_ms:.1f}ms stdev={result.stdev_ms:.1f}ms", + flush=True, + ) + else: + startup_result = _profile_startup_scenario( + iterations=args.iterations, + profile_dir=run_dir, + top=args.top, + ) + print( + " baseline=" + f"{startup_result.baseline_mean_ms:.1f}ms " + f"wrap={startup_result.wrapped_mean_ms:.1f}ms " + f"wrap+log={startup_result.wrapped_with_log_mean_ms:.1f}ms", + flush=True, + ) + + metadata = { + "timestamp": _now_iso_utc(), + "git_commit": _git_commit(), + "python_version": sys.version.split()[0], + "iterations": args.iterations, + "top": args.top, + "profile_artifacts_dir": str(run_dir), + } + + payload = { + "metadata": metadata, + "scenarios": [asdict(result) for result in command_results], + 
"startup": asdict(startup_result) if startup_result is not None else None, + } + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(json.dumps(payload, indent=2), encoding="utf-8") + + markdown_path = args.output.with_suffix(".md") + markdown_path.write_text( + _markdown_report( + metadata=metadata, + command_results=command_results, + startup_result=startup_result, + ), + encoding="utf-8", + ) + + latest_profile_dir = PROFILE_ROOT / "latest" + if latest_profile_dir.exists() or latest_profile_dir.is_symlink(): + if latest_profile_dir.is_symlink() or latest_profile_dir.is_file(): + latest_profile_dir.unlink() + else: + shutil.rmtree(latest_profile_dir) + shutil.copytree(run_dir, latest_profile_dir) + + print(f"[profile] wrote JSON summary to {args.output}", flush=True) + print(f"[profile] wrote Markdown summary to {markdown_path}", flush=True) + print(f"[profile] stored profile artifacts in {run_dir}", flush=True) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/application/publish/test_collection.py b/tests/application/publish/test_collection.py index 03884915..99e133af 100644 --- a/tests/application/publish/test_collection.py +++ b/tests/application/publish/test_collection.py @@ -176,6 +176,55 @@ def test_collect_register_lineage_for_job_uses_representative_hash(tmp_path: Pat ) +def test_collect_register_lineage_step_reference_dry_run_uses_read_only_collector( + tmp_path: Path, +) -> None: + collector = MagicMock() + collector.collect_step_read_only.return_value = LineageData( + jobs=[{"id": 1, "job_uid": "job-1"}], + artifacts=[], + artifact_hashes={"c" * 64}, + pipeline={"id": 7}, + ) + + with ( + patch("roar.application.publish.collection.create_database_context") as create_db_ctx, + patch( + "roar.application.publish.collection.create_query_database_context" + ) as create_query_db, + ): + db_ctx = MagicMock() + db_ctx.__enter__ = MagicMock(return_value=db_ctx) + db_ctx.__exit__ = 
MagicMock(return_value=None) + db_ctx.sessions.get_active.return_value = {"id": 7} + create_query_db.return_value = db_ctx + + collected, error = collect_register_lineage( + target=ResolvedRegisterTarget(kind="step_reference", value="@1"), + roar_dir=tmp_path / ".roar", + cwd=tmp_path, + lineage_collector=collector, + session_service=MagicMock(), + logger=MagicMock(), + dry_run=True, + ) + + assert error is None + assert collected == CollectedRegisterLineage( + lineage=collector.collect_step_read_only.return_value, + session_id=7, + artifact_hash="c" * 64, + session_hash_override=None, + ) + create_db_ctx.assert_not_called() + collector.collect_step_read_only.assert_called_once_with( + session_id=7, + step_number=1, + roar_dir=tmp_path / ".roar", + job_type=None, + ) + + def test_select_representative_hash_returns_empty_string_for_multiple_hashes() -> None: lineage = LineageData( jobs=[], diff --git a/tests/application/publish/test_lineage.py b/tests/application/publish/test_lineage.py index 4bdf884c..ff247561 100644 --- a/tests/application/publish/test_lineage.py +++ b/tests/application/publish/test_lineage.py @@ -8,6 +8,7 @@ _extract_primary_digest, compute_io_signature, ) +from roar.db.context import create_database_context from roar.db.schema import SCHEMA, run_migrations @@ -501,3 +502,52 @@ def test_collect_task_output_includes_parent_and_sibling_ray_tasks(tmp_path): assert jobs_by_uid["task-a"]["parent_job_uid"] == "driver-main" assert jobs_by_uid["task-b"]["parent_job_uid"] == "driver-main" assert jobs_by_uid["task-c"]["parent_job_uid"] == "driver-main" + + +def test_collect_step_read_only_matches_collect_step_for_simple_session(tmp_path): + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + + input_path = tmp_path / "input.txt" + output_path = tmp_path / "output.txt" + input_path.write_text("hello\n", encoding="utf-8") + output_path.write_text("world\n", encoding="utf-8") + + with create_database_context(roar_dir) as db_ctx: + session_id = 
db_ctx.sessions.create( + git_repo="/repo", + git_commit="abc123", + make_active=True, + ) + input_artifact_id, _ = db_ctx.artifacts.register( + {"blake3": "1" * 64}, + size=input_path.stat().st_size, + path=str(input_path), + ) + output_artifact_id, _ = db_ctx.artifacts.register( + {"blake3": "a" * 64}, + size=output_path.stat().st_size, + path=str(output_path), + ) + job_id, job_uid = db_ctx.jobs.create( + "python train.py", + 1700000000.0, + session_id=session_id, + step_number=1, + duration_seconds=1.5, + exit_code=0, + ) + db_ctx.jobs.add_input(job_id, input_artifact_id, str(input_path)) + db_ctx.jobs.add_output(job_id, output_artifact_id, str(output_path)) + + collector = LineageCollector() + sql_lineage = collector.collect_step(session_id, 1, roar_dir) + query_lineage = collector.collect_step_read_only(session_id, 1, roar_dir) + + assert query_lineage.pipeline == sql_lineage.pipeline + assert query_lineage.artifact_hashes == sql_lineage.artifact_hashes + assert [job["job_uid"] for job in query_lineage.jobs] == [job_uid] + assert [job["job_uid"] for job in sql_lineage.jobs] == [job_uid] + assert {artifact["id"] for artifact in query_lineage.artifacts} == { + artifact["id"] for artifact in sql_lineage.artifacts + } diff --git a/tests/application/publish/test_register_preparation.py b/tests/application/publish/test_register_preparation.py index 0a833cd1..892d0004 100644 --- a/tests/application/publish/test_register_preparation.py +++ b/tests/application/publish/test_register_preparation.py @@ -38,7 +38,7 @@ def test_prepare_register_execution_builds_session_git_and_tag_plan(tmp_path: Pa "roar.application.publish.register_preparation.prepare_publish_session", return_value=prepared_session, ) as prepare_session, - patch("roar.application.publish.register_preparation.config_get", return_value=True), + patch("roar.integrations.config.config_get", return_value=True), ): prepared = prepare_register_execution( runtime=runtime, @@ -92,7 +92,7 @@ def 
test_prepare_register_execution_skips_git_tagging_and_glaas_on_dry_run(tmp_p "roar.application.publish.register_preparation.prepare_publish_session", return_value=prepared_session, ) as prepare_session, - patch("roar.application.publish.register_preparation.config_get", return_value=True), + patch("roar.integrations.config.config_get") as config_get, ): prepared = prepare_register_execution( runtime=runtime, @@ -106,6 +106,7 @@ def test_prepare_register_execution_skips_git_tagging_and_glaas_on_dry_run(tmp_p assert prepared.git_tag_name is None assert prepared.git_tag_repo_root is None + config_get.assert_not_called() ensure_clean.assert_not_called() assert prepare_session.call_args.kwargs["register_with_glaas"] is False @@ -122,7 +123,7 @@ def test_prepare_register_execution_propagates_dirty_repo_error(tmp_path: Path) "roar.application.publish.register_preparation.ensure_clean_git_repo", side_effect=ValueError("dirty repo"), ), - patch("roar.application.publish.register_preparation.config_get", return_value=True), + patch("roar.integrations.config.config_get", return_value=True), pytest.raises(ValueError, match="dirty repo"), ): prepare_register_execution( @@ -148,7 +149,7 @@ def test_prepare_register_execution_propagates_session_preparation_error(tmp_pat "roar.application.publish.register_preparation.prepare_publish_session", side_effect=ValueError("GLaaS not configured"), ), - patch("roar.application.publish.register_preparation.config_get", return_value=False), + patch("roar.integrations.config.config_get", return_value=False), pytest.raises(ValueError, match="GLaaS not configured"), ): prepare_register_execution( diff --git a/tests/application/publish/test_service.py b/tests/application/publish/test_service.py index 66ad48e9..bb4b791f 100644 --- a/tests/application/publish/test_service.py +++ b/tests/application/publish/test_service.py @@ -10,14 +10,14 @@ from roar.application.publish.register_execution import RegisterResult from 
roar.application.publish.register_preparation import PreparedRegisterExecution from roar.application.publish.requests import PutRequest, RegisterLineageRequest -from roar.application.publish.results import PutResponse, RegisterLineageResponse +from roar.application.publish.results import PutDryRunItem, PutResponse, RegisterLineageResponse from roar.application.publish.service import put_artifacts, register_lineage_target from roar.application.publish.targets import ResolvedRegisterTarget from roar.core.interfaces.lineage import LineageData -def test_register_lineage_target_collects_and_registers(tmp_path: Path) -> None: - expected = RegisterResult(success=True, session_hash="a" * 64) +def test_register_lineage_target_uses_local_preview_path_for_dry_run(tmp_path: Path) -> None: + expected = RegisterLineageResponse(success=True, session_hash="a" * 64) runtime = MagicMock() logger = MagicMock() collected = MagicMock() @@ -37,10 +37,9 @@ def test_register_lineage_target_collects_and_registers(tmp_path: Path) -> None: ) with ( - patch("roar.application.publish.service.build_publish_runtime", return_value=runtime), patch( - "roar.application.publish.service.get_glaas_url", - return_value="http://localhost:3001", + "roar.application.publish.service.build_register_preview_runtime", + return_value=runtime, ), patch("roar.application.publish.service.get_logger", return_value=logger), patch( @@ -52,13 +51,15 @@ def test_register_lineage_target_collects_and_registers(tmp_path: Path) -> None: return_value=(collected, None), ) as collect_lineage, patch( - "roar.application.publish.service.prepare_register_execution", + "roar.application.publish.service.prepare_register_preview_execution", return_value=prepared, ) as prepare_register, + patch( + "roar.application.publish.service.preview_register_lineage", + return_value=expected, + ) as preview_register, patch("roar.application.publish.service.RegisterService") as mock_cls, ): - 
mock_cls.return_value.register_prepared_lineage.return_value = expected - response = register_lineage_target( RegisterLineageRequest( target="model.pt", @@ -68,11 +69,8 @@ def test_register_lineage_target_collects_and_registers(tmp_path: Path) -> None: ) ) - assert response == RegisterLineageResponse(success=True, session_hash="a" * 64) - mock_cls.assert_called_once_with( - glaas_client=runtime.glaas_client, - coordinator=runtime.registration_coordinator, - ) + assert response == expected + mock_cls.assert_not_called() collect_lineage.assert_called_once_with( target=ResolvedRegisterTarget(kind="artifact_path", value="model.pt"), roar_dir=tmp_path / ".roar", @@ -80,25 +78,23 @@ def test_register_lineage_target_collects_and_registers(tmp_path: Path) -> None: lineage_collector=runtime.lineage_collector, session_service=runtime.session_service, logger=logger, + dry_run=True, ) prepare_register.assert_called_once_with( runtime=runtime, roar_dir=tmp_path / ".roar", cwd=tmp_path, session_id=7, - dry_run=True, session_hash_override=None, logger=logger, ) - mock_cls.return_value.register_prepared_lineage.assert_called_once_with( + preview_register.assert_called_once_with( lineage=collected.lineage, - roar_dir=tmp_path / ".roar", artifact_hash="a" * 64, - dry_run=True, - as_blake3=False, + prepared=prepared, + cwd=tmp_path, skip_confirmation=False, confirm_callback=None, - prepared=prepared, ) @@ -107,10 +103,6 @@ def test_register_lineage_target_returns_collection_error(tmp_path: Path) -> Non with ( patch("roar.application.publish.service.build_publish_runtime", return_value=runtime), - patch( - "roar.application.publish.service.get_glaas_url", - return_value="http://localhost:3001", - ), patch( "roar.application.publish.service.resolve_register_lineage_target", return_value=ResolvedRegisterTarget(kind="artifact_path", value="missing.csv"), @@ -147,10 +139,6 @@ def test_register_lineage_target_returns_preparation_error(tmp_path: Path) -> No with ( 
patch("roar.application.publish.service.build_publish_runtime", return_value=runtime), patch("roar.application.publish.service.get_logger", return_value=logger), - patch( - "roar.application.publish.service.get_glaas_url", - return_value="http://localhost:3001", - ), patch( "roar.application.publish.service.resolve_register_lineage_target", return_value=ResolvedRegisterTarget(kind="artifact_path", value="model.pt"), @@ -204,10 +192,6 @@ def test_register_lineage_target_creates_git_tag_after_success(tmp_path: Path) - with ( patch("roar.application.publish.service.build_publish_runtime", return_value=runtime), patch("roar.application.publish.service.get_logger", return_value=logger), - patch( - "roar.application.publish.service.get_glaas_url", - return_value="http://localhost:3001", - ), patch( "roar.application.publish.service.resolve_register_lineage_target", return_value=ResolvedRegisterTarget(kind="artifact_path", value="model.pt"), @@ -419,27 +403,25 @@ def test_put_artifacts_continues_when_git_preflight_warns(tmp_path: Path) -> Non def test_put_artifacts_returns_preparation_error_before_service(tmp_path: Path) -> None: db_ctx = MagicMock() + db_ctx.sessions.get_active.return_value = None with ( patch("roar.application.publish.service.bootstrap"), patch("roar.application.publish.service.get_logger", return_value=MagicMock()), patch( - "roar.application.publish.service.create_database_context", + "roar.application.publish.service.create_query_database_context", return_value=nullcontext(db_ctx), ), - patch( - "roar.application.publish.service.resolve_publish_storage_backend", - return_value=MagicMock(), - ), patch( "roar.application.publish.service.prepare_put_git", return_value=MagicMock(git_commit=None, expected_tag=None, warnings=()), ), - patch("roar.application.publish.service.build_publish_runtime", return_value=MagicMock()), + patch("roar.application.publish.service.create_database_context") as create_db_ctx, patch( - 
"roar.application.publish.service.prepare_put_execution", - side_effect=ValueError("No active session"), - ), + "roar.application.publish.service.resolve_publish_storage_backend" + ) as resolve_backend, + patch("roar.application.publish.service.build_publish_runtime") as build_runtime, + patch("roar.application.publish.service.prepare_put_execution") as prepare_put_execution, patch("roar.application.publish.service.PutService") as mock_put_cls, pytest.raises(ValueError, match="No active session"), ): @@ -455,4 +437,60 @@ def test_put_artifacts_returns_preparation_error_before_service(tmp_path: Path) ) ) + create_db_ctx.assert_not_called() + resolve_backend.assert_not_called() + build_runtime.assert_not_called() + prepare_put_execution.assert_not_called() mock_put_cls.return_value.put_prepared.assert_not_called() + + +def test_put_artifacts_dry_run_skips_backend_runtime_and_service(tmp_path: Path) -> None: + db_ctx = MagicMock() + db_ctx.sessions.get_active.return_value = {"id": 7} + model = tmp_path / "model.pt" + model.write_bytes(b"model") + + with ( + patch("roar.application.publish.service.bootstrap") as bootstrap, + patch("roar.application.publish.service.get_logger", return_value=MagicMock()), + patch( + "roar.application.publish.service.create_query_database_context", + return_value=nullcontext(db_ctx), + ), + patch("roar.application.publish.service.prepare_put_git") as prepare_put_git, + patch("roar.application.publish.service.create_database_context") as create_db_ctx, + patch( + "roar.application.publish.service.resolve_publish_storage_backend" + ) as resolve_backend, + patch("roar.application.publish.service.build_publish_runtime") as build_runtime, + patch("roar.application.publish.service.prepare_put_execution") as prepare_put_execution, + patch("roar.application.publish.service.PutService") as mock_put_cls, + patch("roar.application.publish.service.finalize_put_git") as finalize_put_git, + ): + response = put_artifacts( + PutRequest( + 
roar_dir=tmp_path / ".roar", + cwd=tmp_path, + repo_root=tmp_path, + sources=["model.pt"], + destination="s3://bucket/prefix", + message="publish", + dry_run=True, + ) + ) + + assert response == PutResponse( + success=True, + destination="s3://bucket/prefix", + dry_run=True, + would_upload=[PutDryRunItem(path=str(model.resolve()), exists=True)], + warnings=[], + ) + create_db_ctx.assert_not_called() + resolve_backend.assert_not_called() + build_runtime.assert_not_called() + prepare_put_execution.assert_not_called() + mock_put_cls.assert_not_called() + bootstrap.assert_not_called() + prepare_put_git.assert_not_called() + finalize_put_git.assert_not_called() diff --git a/tests/application/query/test_log.py b/tests/application/query/test_log.py index 5e47c6ff..75116215 100644 --- a/tests/application/query/test_log.py +++ b/tests/application/query/test_log.py @@ -5,8 +5,10 @@ from pathlib import Path from unittest.mock import MagicMock, patch +import pytest + import roar.application.query.log as log_module -from roar.application.query import LogQueryRequest, render_log +from roar.application.query import LogQueryRequest from roar.application.query.results import LogSummary @@ -17,7 +19,7 @@ def _request(tmp_path: Path, *, use_color: bool = False) -> LogQueryRequest: def test_build_log_summary_returns_typed_jobs_in_display_order(tmp_path: Path) -> None: - with patch.object(log_module, "create_database_context") as mock_db: + with patch.object(log_module, "create_query_database_context") as mock_db: db_ctx = MagicMock() mock_db.return_value.__enter__.return_value = db_ctx db_ctx.sessions.get_active.return_value = {"id": 11} @@ -49,12 +51,14 @@ def test_build_log_summary_returns_typed_jobs_in_display_order(tmp_path: Path) - assert summary.jobs[0].command == "python preprocess.py" -def test_render_log_without_active_session_returns_message(tmp_path: Path) -> None: - with patch.object(log_module, "create_database_context") as mock_db: +def 
test_render_log_without_active_session_raises_query_error(tmp_path: Path) -> None: + with patch.object(log_module, "create_query_database_context") as mock_db: db_ctx = MagicMock() mock_db.return_value.__enter__.return_value = db_ctx db_ctx.sessions.get_active.return_value = None - rendered = render_log(_request(tmp_path)) - - assert rendered == "No active session." + with pytest.raises( + log_module.LogQueryError, + match=r"No active session\. Run 'roar run' to create a session first\.", + ): + log_module.render_log(_request(tmp_path)) diff --git a/tests/application/query/test_show.py b/tests/application/query/test_show.py index a9d9cdda..afa6663c 100644 --- a/tests/application/query/test_show.py +++ b/tests/application/query/test_show.py @@ -5,24 +5,28 @@ from pathlib import Path from unittest.mock import MagicMock, patch +import pytest + import roar.application.query.show as show_module -from roar.application.query import ShowQueryRequest, render_show +from roar.application.query import ShowQueryRequest from roar.application.query.results import ShowArtifactSummary, ShowJobSummary, ShowSessionSummary -def _request(tmp_path: Path, ref: str | None) -> ShowQueryRequest: +def _request( + tmp_path: Path, + ref: str | None, + *, + selector: str = "auto", +) -> ShowQueryRequest: roar_dir = tmp_path / ".roar" roar_dir.mkdir() - return ShowQueryRequest(roar_dir=roar_dir, cwd=tmp_path, ref=ref) + return ShowQueryRequest(roar_dir=roar_dir, cwd=tmp_path, ref=ref, selector=selector) def test_render_show_artifact_by_full_hash(tmp_path: Path) -> None: full_hash = "a1b2c3d4e5f67890" * 4 - with ( - patch.object(show_module, "bootstrap"), - patch.object(show_module, "create_database_context") as mock_db, - ): + with patch.object(show_module, "create_query_database_context") as mock_db: db_ctx = MagicMock() mock_db.return_value.__enter__.return_value = db_ctx db_ctx.jobs.get_by_uid.return_value = None @@ -52,25 +56,73 @@ def 
test_render_show_relative_path_resolves_to_absolute_lookup(tmp_path: Path) - rel_path = "./data/model.pkl" expected_abs_path = str(tmp_path / "data" / "model.pkl") - with ( - patch.object(show_module, "bootstrap"), - patch.object(show_module, "create_database_context") as mock_db, - ): + with patch.object(show_module, "create_query_database_context") as mock_db: db_ctx = MagicMock() mock_db.return_value.__enter__.return_value = db_ctx db_ctx.artifacts.get_by_path.return_value = None - rendered = render_show(_request(tmp_path, rel_path)) + with pytest.raises( + show_module.ShowQueryError, match=f"No artifact found for path: {rel_path}" + ): + show_module.render_show(_request(tmp_path, rel_path)) + + db_ctx.artifacts.get_by_path.assert_called_once_with(expected_abs_path) + + +def test_render_show_bare_filename_uses_absolute_path_lookup(tmp_path: Path) -> None: + ref = "data.bin" + expected_abs_path = str(tmp_path / ref) + + with patch.object(show_module, "create_query_database_context") as mock_db: + db_ctx = MagicMock() + mock_db.return_value.__enter__.return_value = db_ctx + db_ctx.artifacts.get_by_path.return_value = { + "id": "artifact-456", + "size": 12, + "first_seen_at": 1700000000.0, + "first_seen_path": expected_abs_path, + "metadata": None, + "hashes": [{"algorithm": "blake3", "digest": "a" * 64}], + } + db_ctx.artifacts.get_locations.return_value = [{"path": expected_abs_path}] + db_ctx.artifacts.get_jobs.return_value = {"produced_by": [], "consumed_by": []} + + summary = show_module.build_show_summary(_request(tmp_path, ref)) db_ctx.artifacts.get_by_path.assert_called_once_with(expected_abs_path) - assert rendered == f"No artifact found for path: {rel_path}" + assert isinstance(summary, ShowArtifactSummary) + assert summary.id == "artifact-456" + + +def test_render_show_explicit_path_selector_bypasses_hex_auto_detection(tmp_path: Path) -> None: + ref = "deadbeef" + expected_abs_path = str(tmp_path / ref) + + with patch.object(show_module, 
"create_query_database_context") as mock_db: + db_ctx = MagicMock() + mock_db.return_value.__enter__.return_value = db_ctx + db_ctx.artifacts.get_by_path.return_value = { + "id": "artifact-path", + "size": 7, + "first_seen_at": 1700000000.0, + "first_seen_path": expected_abs_path, + "metadata": None, + "hashes": [{"algorithm": "blake3", "digest": "b" * 64}], + } + db_ctx.artifacts.get_locations.return_value = [{"path": expected_abs_path}] + db_ctx.artifacts.get_jobs.return_value = {"produced_by": [], "consumed_by": []} + + summary = show_module.build_show_summary(_request(tmp_path, ref, selector="path")) + + db_ctx.artifacts.get_by_path.assert_called_once_with(expected_abs_path) + db_ctx.jobs.get_by_uid.assert_not_called() + db_ctx.artifacts.get_by_hash.assert_not_called() + assert isinstance(summary, ShowArtifactSummary) + assert summary.id == "artifact-path" def test_render_show_job_uid_takes_precedence_for_short_hex_refs(tmp_path: Path) -> None: - with ( - patch.object(show_module, "bootstrap"), - patch.object(show_module, "create_database_context") as mock_db, - ): + with patch.object(show_module, "create_query_database_context") as mock_db: db_ctx = MagicMock() mock_db.return_value.__enter__.return_value = db_ctx db_ctx.jobs.get_by_uid.return_value = { @@ -95,11 +147,35 @@ def test_render_show_job_uid_takes_precedence_for_short_hex_refs(tmp_path: Path) assert summary.command == "python train.py" +def test_render_show_explicit_artifact_selector_bypasses_job_uid_precedence( + tmp_path: Path, +) -> None: + full_hash = "deadbeef" + + with patch.object(show_module, "create_query_database_context") as mock_db: + db_ctx = MagicMock() + mock_db.return_value.__enter__.return_value = db_ctx + db_ctx.artifacts.get_by_hash.return_value = { + "id": "artifact-hash", + "size": 33, + "first_seen_at": 1700000000.0, + "first_seen_path": "/data/model.pkl", + "metadata": None, + "hashes": [{"algorithm": "blake3", "digest": "c" * 64}], + } + 
db_ctx.artifacts.get_locations.return_value = [{"path": "/data/model.pkl"}] + db_ctx.artifacts.get_jobs.return_value = {"produced_by": [], "consumed_by": []} + + summary = show_module.build_show_summary(_request(tmp_path, full_hash, selector="artifact")) + + db_ctx.artifacts.get_by_hash.assert_called_once_with(full_hash) + db_ctx.jobs.get_by_uid.assert_not_called() + assert isinstance(summary, ShowArtifactSummary) + assert summary.id == "artifact-hash" + + def test_render_show_without_ref_returns_active_session_summary(tmp_path: Path) -> None: - with ( - patch.object(show_module, "bootstrap"), - patch.object(show_module, "create_database_context") as mock_db, - ): + with patch.object(show_module, "create_query_database_context") as mock_db: db_ctx = MagicMock() mock_db.return_value.__enter__.return_value = db_ctx db_ctx.sessions.get_active.return_value = { @@ -127,15 +203,14 @@ def test_render_show_without_ref_returns_active_session_summary(tmp_path: Path) assert summary.jobs[0].command == "python preprocess.py" -def test_render_show_job_step_without_active_session_returns_message(tmp_path: Path) -> None: - with ( - patch.object(show_module, "bootstrap"), - patch.object(show_module, "create_database_context") as mock_db, - ): +def test_render_show_job_step_without_active_session_raises_query_error(tmp_path: Path) -> None: + with patch.object(show_module, "create_query_database_context") as mock_db: db_ctx = MagicMock() mock_db.return_value.__enter__.return_value = db_ctx db_ctx.sessions.get_active.return_value = None - rendered = render_show(_request(tmp_path, "@1")) - - assert rendered == "No active session." + with pytest.raises( + show_module.ShowQueryError, + match=r"No active session\. 
Run 'roar run' to create a session first\.", + ): + show_module.render_show(_request(tmp_path, "@1")) diff --git a/tests/application/query/test_status.py b/tests/application/query/test_status.py index 3c6340ec..53035a89 100644 --- a/tests/application/query/test_status.py +++ b/tests/application/query/test_status.py @@ -5,8 +5,10 @@ from pathlib import Path from types import SimpleNamespace +import pytest + from roar.application.query import StatusQueryRequest -from roar.application.query.status import build_status_summary +from roar.application.query.status import StatusQueryError, build_status_summary def test_build_status_summary_groups_steps_and_artifacts(tmp_path: Path) -> None: @@ -48,10 +50,7 @@ def test_build_status_summary_groups_steps_and_artifacts(tmp_path: Path) -> None from unittest.mock import patch - with ( - patch("roar.application.query.status.bootstrap"), - patch("roar.application.query.status.create_database_context") as mock_db, - ): + with patch("roar.application.query.status.create_query_database_context") as mock_db: mock_db.return_value.__enter__.return_value = db_ctx summary = build_status_summary(request) @@ -62,17 +61,18 @@ def test_build_status_summary_groups_steps_and_artifacts(tmp_path: Path) -> None assert [artifact.present for artifact in summary.artifacts] == [True, False] -def test_build_status_summary_returns_none_without_active_session(tmp_path: Path) -> None: +def test_build_status_summary_without_active_session_raises_query_error(tmp_path: Path) -> None: request = StatusQueryRequest(roar_dir=tmp_path / ".roar") request.roar_dir.mkdir() from unittest.mock import patch - with ( - patch("roar.application.query.status.bootstrap"), - patch("roar.application.query.status.create_database_context") as mock_db, - ): + with patch("roar.application.query.status.create_query_database_context") as mock_db: mock_db.return_value.__enter__.return_value = SimpleNamespace( sessions=SimpleNamespace(get_active=lambda: None) ) - assert 
build_status_summary(request) is None + with pytest.raises( + StatusQueryError, + match=r"No active session\. Run 'roar run' to create a session first\.", + ): + build_status_summary(request) diff --git a/tests/integration/test_put_cli_integration.py b/tests/integration/test_put_cli_integration.py index 872c1834..f1aad9ef 100644 --- a/tests/integration/test_put_cli_integration.py +++ b/tests/integration/test_put_cli_integration.py @@ -78,8 +78,9 @@ def test_put_registers_lineage_with_fake_glaas_and_updates_local_dag( assert result.returncode == 0 assert "Published 1 file(s) to s3://test-bucket/models" in result.stdout assert "model.pt -> s3://test-bucket/models/model.pt" in result.stdout - assert "Job created: step 2" in result.stdout - assert f"View: {fake_glaas_publish_server.base_url}/dag/" in result.stdout + assert "Job step: @2" in result.stdout + assert "GLaaS:" in result.stdout + assert f"{fake_glaas_publish_server.base_url}/dag/" in result.stdout dag = _get_dag(roar_cli) assert dag["total_steps"] == 2 @@ -133,10 +134,9 @@ def test_put_dry_run_does_not_create_local_or_remote_publish_jobs( ) assert result.returncode == 0 - assert "Dry run - would upload:" in result.stdout + assert "Dry run: would upload 2 file(s) to s3://bucket/test" in result.stdout assert "model.pt" in result.stdout assert "metrics.json" in result.stdout - assert "Total: 2 file(s)" in result.stdout dag_after = _get_dag(roar_cli) assert dag_after["total_steps"] == dag_before["total_steps"] == 1 diff --git a/tests/integration/test_register_dry_run_cli.py b/tests/integration/test_register_dry_run_cli.py index 686f62b4..c34b08c1 100644 --- a/tests/integration/test_register_dry_run_cli.py +++ b/tests/integration/test_register_dry_run_cli.py @@ -82,8 +82,8 @@ def test_register_dry_run_resolves_artifact_step_and_session_targets( for result in (artifact_result, step_result, session_result): assert result.returncode == 0 - assert "Dry run - would register:" in result.stdout - assert "View on GLaaS:" 
in result.stdout + assert "Dry run: would register lineage for:" in result.stdout + assert "GLaaS:" in result.stdout published_session_hashes = { _parse_session_hash(artifact_result.stdout), @@ -134,9 +134,9 @@ def test_register_publishes_local_lineage_with_fake_glaas( assert "Jobs: 1" in result.stdout assert "Artifacts:" in result.stdout assert "Links:" in result.stdout - assert "To reproduce this artifact:" in result.stdout + assert "Next:" in result.stdout assert "roar reproduce " in result.stdout - assert "View on GLaaS:" in result.stdout + assert "GLaaS:" in result.stdout assert fake_glaas_publish_server.health_checks >= 1 assert len(fake_glaas_publish_server.session_registrations) == 1 diff --git a/tests/integrations/config/test_access.py b/tests/integrations/config/test_access.py index 6881c4c6..2dc0ac96 100644 --- a/tests/integrations/config/test_access.py +++ b/tests/integrations/config/test_access.py @@ -1,5 +1,6 @@ """Tests for proxy configuration model and config get/set.""" +from types import SimpleNamespace from unittest.mock import patch from roar.integrations.config import ( @@ -45,8 +46,8 @@ def test_config_get_returns_false_by_default(self, tmp_path): roar_dir = tmp_path / ".roar" roar_dir.mkdir() - with patch("roar.integrations.config.access.load_config") as mock_load: - mock_load.return_value = RoarConfig().to_dict() + with patch("roar.integrations.config.access.load_settings") as mock_load: + mock_load.return_value = SimpleNamespace(proxy=SimpleNamespace(enabled=False)) result = config_get("proxy.enabled") assert result is False diff --git a/tests/integrations/glaas/test_client.py b/tests/integrations/glaas/test_client.py index 8d775592..041e10f3 100644 --- a/tests/integrations/glaas/test_client.py +++ b/tests/integrations/glaas/test_client.py @@ -33,6 +33,14 @@ def test_health_check_raises_not_configured_when_no_url(self): assert "not configured" in str(exc_info.value).lower() + def 
test_empty_base_url_does_not_fall_back_to_config_lookup(self): + """An explicit empty URL should stay unconfigured without reading config.""" + with patch("roar.integrations.glaas.client.get_glaas_url") as get_glaas_url: + client = GlaasClient(base_url="") + + get_glaas_url.assert_not_called() + assert client.base_url is None + def test_health_check_raises_connection_error_on_network_failure(self): """health_check should raise GlaasConnectionError on network errors.""" client = GlaasClient(base_url="http://localhost:9999") diff --git a/tests/unit/put/test_put_cli.py b/tests/unit/put/test_put_cli.py index eee90476..43623f8e 100644 --- a/tests/unit/put/test_put_cli.py +++ b/tests/unit/put/test_put_cli.py @@ -53,7 +53,11 @@ def test_put_uses_service_session_url_for_dag_link(tmp_path: Path) -> None: with ( patch.object(put_module, "put_artifacts", return_value=response), - patch.object(put_module, "config_get", return_value="https://glaas.example"), + patch.object( + put_module, + "_resolve_glaas_web_url", + return_value="https://glaas.example", + ), ): result = runner.invoke( put, @@ -84,7 +88,11 @@ def test_put_falls_back_to_web_url_plus_service_session_hash(tmp_path: Path) -> with ( patch.object(put_module, "put_artifacts", return_value=response), - patch.object(put_module, "config_get", return_value="https://glaas.example"), + patch.object( + put_module, + "_resolve_glaas_web_url", + return_value="https://glaas.example", + ), ): result = runner.invoke( put, @@ -127,7 +135,11 @@ def test_put_prints_registered_composite_summary(tmp_path: Path) -> None: with ( patch.object(put_module, "put_artifacts", return_value=response), - patch.object(put_module, "config_get", return_value="https://glaas.example"), + patch.object( + put_module, + "_resolve_glaas_web_url", + return_value="https://glaas.example", + ), ): result = runner.invoke( put, @@ -174,7 +186,11 @@ def test_put_warns_when_local_composite_persistence_fails(tmp_path: Path) -> Non with ( patch.object(put_module, 
"put_artifacts", return_value=response), - patch.object(put_module, "config_get", return_value="https://glaas.example"), + patch.object( + put_module, + "_resolve_glaas_web_url", + return_value="https://glaas.example", + ), ): result = runner.invoke( put, diff --git a/tests/unit/test_auth_cli.py b/tests/unit/test_auth_cli.py new file mode 100644 index 00000000..8e99634f --- /dev/null +++ b/tests/unit/test_auth_cli.py @@ -0,0 +1,45 @@ +from unittest.mock import patch + +from click.testing import CliRunner + +from roar.cli.commands.auth import auth + + +def test_auth_help_lists_key_not_register() -> None: + result = CliRunner().invoke(auth, ["--help"]) + + assert result.exit_code == 0, result.output + assert "roar auth key" in result.output + assert "register" not in result.output + + +@patch("roar.cli.commands.auth._find_ssh_pubkey") +def test_auth_key_outputs_public_key(mock_find_ssh_pubkey) -> None: + mock_find_ssh_pubkey.return_value = ( + "ssh-ed25519", + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey user@example", + "/tmp/id_ed25519.pub", + ) + + result = CliRunner().invoke(auth, ["key"]) + + assert result.exit_code == 0, result.output + assert "Your SSH public key:" in result.output + assert "/tmp/id_ed25519.pub" in result.output + + +@patch("roar.cli.commands.auth._find_ssh_pubkey") +def test_auth_register_alias_matches_key_output(mock_find_ssh_pubkey) -> None: + mock_find_ssh_pubkey.return_value = ( + "ssh-ed25519", + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey user@example", + "/tmp/id_ed25519.pub", + ) + + runner = CliRunner() + key_result = runner.invoke(auth, ["key"]) + alias_result = runner.invoke(auth, ["register"]) + + assert key_result.exit_code == 0, key_result.output + assert alias_result.exit_code == 0, alias_result.output + assert alias_result.output == key_result.output diff --git a/tests/unit/test_bootstrap_config_path.py b/tests/unit/test_bootstrap_config_path.py new file mode 100644 index 00000000..cf72ecc5 --- /dev/null +++ 
b/tests/unit/test_bootstrap_config_path.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path +from unittest.mock import patch + +from roar.core.bootstrap import _configure_core_logging +from roar.integrations.config import config_get + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def test_configure_core_logging_uses_explicit_roar_dir(tmp_path, monkeypatch) -> None: + repo = tmp_path / "repo" + roar_dir = repo / ".roar" + roar_dir.mkdir(parents=True) + (roar_dir / "config.toml").write_text( + '[logging]\nlevel = "debug"\nconsole = true\nfile = false\n', + encoding="utf-8", + ) + + elsewhere = tmp_path / "elsewhere" + elsewhere.mkdir() + monkeypatch.chdir(elsewhere) + + with patch("roar.core.bootstrap.configure_logger") as configure_logger: + _configure_core_logging(roar_dir) + + configure_logger.assert_called_once_with( + level="debug", + console_enabled=True, + file_enabled=False, + ) + + +def test_bootstrap_does_not_import_execution_registry_for_core_logging(tmp_path: Path) -> None: + repo = tmp_path / "repo" + roar_dir = repo / ".roar" + roar_dir.mkdir(parents=True) + (roar_dir / "config.toml").write_text('[logging]\nlevel = "info"\n', encoding="utf-8") + + env = os.environ.copy() + pythonpath_entries = [str(REPO_ROOT)] + existing = env.get("PYTHONPATH") + if existing: + pythonpath_entries.append(existing) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_entries) + + proc = subprocess.run( + [ + sys.executable, + "-c", + """ +import sys +from pathlib import Path + +from roar.core.bootstrap import bootstrap, reset + +roar_dir = Path(sys.argv[1]) +reset() +bootstrap(roar_dir) +assert "roar.execution.framework.registry" not in sys.modules +assert "roar.backends.ray.plugin" not in sys.modules +print("ok") +""", + str(roar_dir), + ], + cwd=REPO_ROOT, + env=env, + text=True, + capture_output=True, + check=False, + ) + + assert proc.returncode == 0, proc.stderr + assert 
proc.stdout.strip() == "ok" + + +def test_config_get_returns_nested_model_as_dict(tmp_path: Path) -> None: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + (roar_dir / "config.toml").write_text( + "[registration.omit]\nenabled = false\n", + encoding="utf-8", + ) + + omit_config = config_get("registration.omit", start_dir=str(tmp_path)) + + assert isinstance(omit_config, dict) + assert omit_config["enabled"] is False + + +def test_config_get_core_key_does_not_import_execution_registry(tmp_path: Path) -> None: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + (roar_dir / "config.toml").write_text('[logging]\nlevel = "debug"\n', encoding="utf-8") + + env = os.environ.copy() + pythonpath_entries = [str(REPO_ROOT)] + existing = env.get("PYTHONPATH") + if existing: + pythonpath_entries.append(existing) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_entries) + + proc = subprocess.run( + [ + sys.executable, + "-c", + """ +import sys +from roar.integrations.config import config_get + +value = config_get("logging.level", start_dir=sys.argv[1]) +assert value == "debug" +assert "roar.execution.framework.registry" not in sys.modules +assert "roar.backends.ray.plugin" not in sys.modules +print("ok") +""", + str(tmp_path), + ], + cwd=REPO_ROOT, + env=env, + text=True, + capture_output=True, + check=False, + ) + + assert proc.returncode == 0, proc.stderr + assert proc.stdout.strip() == "ok" diff --git a/tests/unit/test_cli_init.py b/tests/unit/test_cli_init.py new file mode 100644 index 00000000..3166ba34 --- /dev/null +++ b/tests/unit/test_cli_init.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import os +import subprocess +from pathlib import Path + +from click.testing import CliRunner + +from roar.cli import cli + + +def _init_git_repo(path: Path) -> None: + subprocess.run(["git", "init", str(path)], check=True, capture_output=True) + + +def test_init_path_uses_target_repo_for_gitignore_updates(tmp_path: Path) -> None: + caller_repo = tmp_path / 
"caller-repo" + target_repo = tmp_path / "target-repo" + caller_repo.mkdir() + target_repo.mkdir() + + _init_git_repo(caller_repo) + _init_git_repo(target_repo) + + caller_gitignore = caller_repo / ".gitignore" + target_gitignore = target_repo / ".gitignore" + caller_gitignore.write_text(".roar/\n") + target_gitignore.write_text("") + + runner = CliRunner() + original_cwd = Path.cwd() + try: + os.chdir(caller_repo) + result = runner.invoke(cli, ["init", "--path", str(target_repo), "-y"]) + finally: + os.chdir(original_cwd) + + assert result.exit_code == 0, result.output + assert "Added .roar/ to .gitignore" in result.output + assert ".roar is already in .gitignore" not in result.output + assert caller_gitignore.read_text() == ".roar/\n" + assert ".roar/" in target_gitignore.read_text().splitlines() + assert (target_repo / ".roar").is_dir() diff --git a/tests/unit/test_cli_query_errors.py b/tests/unit/test_cli_query_errors.py new file mode 100644 index 00000000..01fd6495 --- /dev/null +++ b/tests/unit/test_cli_query_errors.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from click.testing import CliRunner + +from roar.application.query import ShowQueryRequest +from roar.cli.commands.log import log +from roar.cli.commands.show import show +from roar.cli.commands.status import status + + +def _ctx(tmp_path): + ctx = MagicMock() + ctx.roar_dir = tmp_path / ".roar" + ctx.roar_dir.mkdir() + ctx.cwd = tmp_path + ctx.is_initialized = True + return ctx + + +def test_log_cli_exits_non_zero_without_active_session(tmp_path) -> None: + result = CliRunner().invoke(log, obj=_ctx(tmp_path)) + + assert result.exit_code != 0 + assert "No active session." in result.output + + +def test_status_cli_exits_non_zero_without_active_session(tmp_path) -> None: + result = CliRunner().invoke(status, obj=_ctx(tmp_path)) + + assert result.exit_code != 0 + assert "No active session." 
in result.output + + +def test_show_cli_exits_non_zero_for_missing_path_lookup(tmp_path) -> None: + result = CliRunner().invoke(show, ["artifact.bin"], obj=_ctx(tmp_path)) + + assert result.exit_code != 0 + assert "No artifact found for path: artifact.bin" in result.output + + +def test_show_cli_path_selector_builds_explicit_request(tmp_path) -> None: + ctx = _ctx(tmp_path) + + with patch("roar.cli.commands.show.render_show", return_value="ok") as render_show: + result = CliRunner().invoke(show, ["--path", "deadbeef"], obj=ctx) + + assert result.exit_code == 0, result.output + assert result.output == "ok\n" + render_show.assert_called_once_with( + ShowQueryRequest(roar_dir=ctx.roar_dir, cwd=ctx.cwd, ref="deadbeef", selector="path") + ) + + +def test_show_cli_rejects_multiple_explicit_selectors(tmp_path) -> None: + result = CliRunner().invoke(show, ["--path", "a", "--job", "@1"], obj=_ctx(tmp_path)) + + assert result.exit_code == 2 + assert "Specify only one of --path, --job, --artifact, or --session." 
in result.output + + +def test_show_cli_rejects_positional_ref_with_explicit_selector(tmp_path) -> None: + result = CliRunner().invoke(show, ["--artifact", "deadbeef", "other"], obj=_ctx(tmp_path)) + + assert result.exit_code == 2 + assert "Positional REF cannot be combined" in result.output diff --git a/tests/unit/test_cli_registry.py b/tests/unit/test_cli_registry.py index bce8739f..86ced9a4 100644 --- a/tests/unit/test_cli_registry.py +++ b/tests/unit/test_cli_registry.py @@ -1,5 +1,7 @@ """Unit tests for top-level CLI command registry behavior.""" +from unittest.mock import patch + from click.testing import CliRunner from roar.cli import LAZY_COMMANDS, cli @@ -17,9 +19,36 @@ def test_help_does_not_list_composite_command() -> None: assert "composite" not in result.output +def test_help_groups_commands_by_workflow() -> None: + runner = CliRunner() + result = runner.invoke(cli, ["--help"]) + + assert result.exit_code == 0, result.output + assert "Start Here:" in result.output + assert "Inspect Local Lineage:" in result.output + assert "Share and Publish:" in result.output + assert "Setup and Admin:" in result.output + assert "Track a command with provenance" in result.output + assert "Publish artifacts and register lineage" in result.output + + def test_cli_rejects_removed_composite_command() -> None: runner = CliRunner() result = runner.invoke(cli, ["composite"]) assert result.exit_code == 2 assert "No such command 'composite'" in result.output + + +def test_subcommand_help_reports_import_errors_cleanly() -> None: + runner = CliRunner() + missing = ModuleNotFoundError("No module named 'pydantic'") + missing.name = "pydantic" + + with patch("roar.cli.import_module", side_effect=missing): + result = runner.invoke(cli, ["run", "--help"]) + + assert result.exit_code != 0 + assert "Failed to load 'run'" in result.output + assert "pydantic" in result.output + assert "Traceback" not in result.output diff --git a/tests/unit/test_core_models_lazy_imports.py 
b/tests/unit/test_core_models_lazy_imports.py new file mode 100644 index 00000000..67203acd --- /dev/null +++ b/tests/unit/test_core_models_lazy_imports.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +def test_core_models_package_lazy_imports_heavy_submodules() -> None: + repo_root = Path(__file__).resolve().parents[2] + script = """ +import json +import sys + + +def loaded(): + names = [ + "roar.core.models.artifact", + "roar.core.models.base", + "roar.core.models.dag", + "roar.core.models.glaas", + "roar.core.models.job", + "roar.core.models.lineage", + "roar.core.models.provenance", + "roar.core.models.run", + "roar.core.models.session", + "roar.core.models.telemetry", + "roar.core.models.vcs", + ] + return {name: name in sys.modules for name in names} + + +import roar.core.models as models +states = {"after_package_import": loaded()} + +from roar.core.models import RoarBaseModel + +states["after_base_export_import"] = loaded() + +import roar.core.models.base +states["after_base_module_import"] = loaded() + +print(json.dumps(states)) +""" + proc = subprocess.run( + [sys.executable, "-c", script], + cwd=repo_root, + capture_output=True, + text=True, + check=True, + ) + states = json.loads(proc.stdout) + + assert states["after_package_import"] == { + "roar.core.models.artifact": False, + "roar.core.models.base": False, + "roar.core.models.dag": False, + "roar.core.models.glaas": False, + "roar.core.models.job": False, + "roar.core.models.lineage": False, + "roar.core.models.provenance": False, + "roar.core.models.run": False, + "roar.core.models.session": False, + "roar.core.models.telemetry": False, + "roar.core.models.vcs": False, + } + assert states["after_base_export_import"] == { + "roar.core.models.artifact": False, + "roar.core.models.base": True, + "roar.core.models.dag": False, + "roar.core.models.glaas": False, + "roar.core.models.job": False, + 
"roar.core.models.lineage": False, + "roar.core.models.provenance": False, + "roar.core.models.run": False, + "roar.core.models.session": False, + "roar.core.models.telemetry": False, + "roar.core.models.vcs": False, + } + assert states["after_base_module_import"] == states["after_base_export_import"] diff --git a/tests/unit/test_db_lazy_imports.py b/tests/unit/test_db_lazy_imports.py new file mode 100644 index 00000000..13e71c2e --- /dev/null +++ b/tests/unit/test_db_lazy_imports.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + +from roar.db.context import create_database_context + + +def test_db_package_lazy_imports_heavy_submodules() -> None: + repo_root = Path(__file__).resolve().parents[2] + script = """ +import json +import sys +from pathlib import Path + + +def loaded(): + names = [ + "roar.db.context", + "roar.db.engine", + "roar.db.models", + "roar.db.repositories", + "roar.db.services", + "sqlalchemy", + ] + return {name: name in sys.modules for name in names} + + +import roar.db as db +states = {"after_package_import": loaded()} + +from roar.db import create_database_context + +states["after_factory_import"] = loaded() +_ctx = create_database_context(Path(".roar")) +states["after_factory_call"] = loaded() + +print(json.dumps(states)) +""" + proc = subprocess.run( + [sys.executable, "-c", script], + cwd=repo_root, + capture_output=True, + text=True, + check=True, + ) + states = json.loads(proc.stdout) + + assert states["after_package_import"] == { + "roar.db.context": False, + "roar.db.engine": False, + "roar.db.models": False, + "roar.db.repositories": False, + "roar.db.services": False, + "sqlalchemy": False, + } + assert states["after_factory_import"] == { + "roar.db.context": True, + "roar.db.engine": False, + "roar.db.models": False, + "roar.db.repositories": False, + "roar.db.services": False, + "sqlalchemy": False, + } + assert states["after_factory_call"] == 
states["after_factory_import"] + + +def test_database_context_initializes_repositories_and_services_lazily(tmp_path: Path) -> None: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + + with create_database_context(roar_dir) as db_ctx: + assert db_ctx._artifact_repo is None + assert db_ctx._job_repo is None + assert db_ctx._session_repo is None + assert db_ctx._collection_repo is None + assert db_ctx._composite_repo is None + assert db_ctx._label_repo is None + assert db_ctx._hashing_service is None + assert db_ctx._session_service is None + assert db_ctx._lineage_service is None + assert db_ctx._job_recording_service is None + + session_repo = db_ctx.sessions + assert db_ctx._session_repo is session_repo + assert db_ctx._artifact_repo is None + assert db_ctx._job_repo is None + assert db_ctx._hashing_service is None + + job_repo = db_ctx.jobs + assert db_ctx._job_repo is job_repo + assert db_ctx._artifact_repo is not None + assert db_ctx._hashing_service is None + assert db_ctx._session_service is None + + session_service = db_ctx.session_service + assert db_ctx._session_service is session_service + assert db_ctx._hashing_service is None + assert db_ctx._lineage_service is None + assert db_ctx._job_recording_service is None + + job_recording = db_ctx.job_recording + assert db_ctx._job_recording_service is job_recording + assert db_ctx._hashing_service is not None + assert db_ctx._session_service is session_service + assert db_ctx._lineage_service is None + assert db_ctx._collection_repo is None + assert db_ctx._composite_repo is None + assert db_ctx._label_repo is None diff --git a/tests/unit/test_presenters_lazy_imports.py b/tests/unit/test_presenters_lazy_imports.py new file mode 100644 index 00000000..6786a2e8 --- /dev/null +++ b/tests/unit/test_presenters_lazy_imports.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +def test_presenters_package_lazy_imports_heavy_submodules() 
-> None: + repo_root = Path(__file__).resolve().parents[2] + script = """ +import json +import sys + + +def loaded(): + names = [ + "roar.presenters.console", + "roar.presenters.dag_data_builder", + "roar.presenters.dag_renderer", + "roar.presenters.formatting", + "roar.presenters.null", + "roar.presenters.show_renderer", + "roar.core.models", + ] + return {name: name in sys.modules for name in names} + + +import roar.presenters as presenters +states = {"after_package_import": loaded()} + +import roar.presenters.formatting +states["after_formatting_import"] = loaded() + +from roar.presenters import ShowRenderer + +states["after_show_renderer_import"] = loaded() + +print(json.dumps(states)) +""" + proc = subprocess.run( + [sys.executable, "-c", script], + cwd=repo_root, + capture_output=True, + text=True, + check=True, + ) + states = json.loads(proc.stdout) + + assert states["after_package_import"] == { + "roar.presenters.console": False, + "roar.presenters.dag_data_builder": False, + "roar.presenters.dag_renderer": False, + "roar.presenters.formatting": False, + "roar.presenters.null": False, + "roar.presenters.show_renderer": False, + "roar.core.models": False, + } + assert states["after_formatting_import"] == { + "roar.presenters.console": False, + "roar.presenters.dag_data_builder": False, + "roar.presenters.dag_renderer": False, + "roar.presenters.formatting": True, + "roar.presenters.null": False, + "roar.presenters.show_renderer": False, + "roar.core.models": False, + } + assert states["after_show_renderer_import"] == { + "roar.presenters.console": False, + "roar.presenters.dag_data_builder": False, + "roar.presenters.dag_renderer": False, + "roar.presenters.formatting": True, + "roar.presenters.null": False, + "roar.presenters.show_renderer": True, + "roar.core.models": False, + } diff --git a/tests/unit/test_profile_roar.py b/tests/unit/test_profile_roar.py new file mode 100644 index 00000000..02ed080d --- /dev/null +++ b/tests/unit/test_profile_roar.py @@ 
-0,0 +1,33 @@ +from __future__ import annotations + +from scripts.profile_roar import _parse_importtime, _summary_stats + + +def test_parse_importtime_returns_top_cumulative_modules() -> None: + stderr = "\n".join( + [ + "import time: 100 | 100 | _io", + "import time: 200 | 500 | pathlib", + "import time: 50 | 900 | roar.execution.runtime.inject.sitecustomize", + "not an importtime line", + ] + ) + + results = _parse_importtime(stderr, limit=2) + + assert [item.module for item in results] == [ + "roar.execution.runtime.inject.sitecustomize", + "pathlib", + ] + assert results[0].cumulative_ms == 0.9 + assert results[1].self_ms == 0.2 + + +def test_summary_stats_handles_single_sample() -> None: + summary = _summary_stats([12.5]) + + assert summary["mean_ms"] == 12.5 + assert summary["median_ms"] == 12.5 + assert summary["min_ms"] == 12.5 + assert summary["max_ms"] == 12.5 + assert summary["stdev_ms"] == 0.0 diff --git a/tests/unit/test_publish_service_lazy_imports.py b/tests/unit/test_publish_service_lazy_imports.py new file mode 100644 index 00000000..d206956c --- /dev/null +++ b/tests/unit/test_publish_service_lazy_imports.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +def test_publish_service_import_stays_lightweight() -> None: + repo_root = Path(__file__).resolve().parents[2] + script = """ +import json +import sys + + +def loaded(): + names = [ + "roar.application.publish.collection", + "roar.application.publish.put_execution", + "roar.application.publish.register_execution", + "roar.application.publish.runtime", + "roar.application.publish.targets", + "roar.db.context", + "roar.db.engine", + "roar.db.models", + "roar.backends.ray.collector", + "roar.backends.ray.submit", + "sqlalchemy", + ] + return {name: name in sys.modules for name in names} + + +import roar.application.publish.service as service +states = {"after_module_import": loaded()} + +from 
roar.application.publish.service import put_artifacts, register_lineage_target + +states["after_entrypoint_import"] = loaded() + +print(json.dumps(states)) +""" + proc = subprocess.run( + [sys.executable, "-c", script], + cwd=repo_root, + capture_output=True, + text=True, + check=True, + ) + states = json.loads(proc.stdout) + + expected = { + "roar.application.publish.collection": False, + "roar.application.publish.put_execution": False, + "roar.application.publish.register_execution": False, + "roar.application.publish.runtime": False, + "roar.application.publish.targets": False, + "roar.db.context": False, + "roar.db.engine": False, + "roar.db.models": False, + "roar.backends.ray.collector": False, + "roar.backends.ray.submit": False, + "sqlalchemy": False, + } + assert states["after_module_import"] == expected + assert states["after_entrypoint_import"] == expected diff --git a/tests/unit/test_put_cli.py b/tests/unit/test_put_cli.py new file mode 100644 index 00000000..d7391df5 --- /dev/null +++ b/tests/unit/test_put_cli.py @@ -0,0 +1,103 @@ +"""Unit tests for the put CLI output surface.""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from click.testing import CliRunner + +from roar.application.publish.results import PutDryRunItem, PutResponse, PutUploadedFile +from roar.cli.commands.put import put + + +def _mock_context(tmp_path: Path) -> MagicMock: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + ctx = MagicMock() + ctx.roar_dir = roar_dir + ctx.repo_root = tmp_path + ctx.cwd = tmp_path + ctx.is_initialized = True + return ctx + + +def test_put_cli_prints_structured_success_summary(tmp_path: Path) -> None: + runner = CliRunner() + response = PutResponse( + success=True, + destination="s3://bucket/release", + job_id=7, + job_uid="putjob1234", + session_hash="0123456789abcdef0123456789abcdef", + session_url="https://glaas.example/dag/0123456789abcdef0123456789abcdef", + uploaded_files=[ + PutUploadedFile( + 
local_path=str(tmp_path / "model.pt"), + remote_url="s3://bucket/release/model.pt", + ) + ], + git_tag="roar/0123456789ab", + ) + + with patch("roar.cli.commands.put.put_artifacts", return_value=response): + result = runner.invoke( + put, + ["model.pt", "s3://bucket/release", "-m", "publish release"], + obj=_mock_context(tmp_path), + ) + + assert result.exit_code == 0, result.output + assert "Published 1 file(s) to s3://bucket/release" in result.output + assert "Session: 0123456789ab..." in result.output + assert "Job step: @7" in result.output + assert "Job UID: putjob1234" in result.output + assert "Git tag: roar/0123456789ab" in result.output + assert "GLaaS:" in result.output + assert "https://glaas.example/dag/0123456789abcdef0123456789abcdef" in result.output + assert "Next:" in result.output + assert "roar show --job putjob1234" in result.output + assert "roar show --session" in result.output + assert "Created git tag" not in result.output + + +def test_put_cli_dry_run_mentions_destination_and_count(tmp_path: Path) -> None: + runner = CliRunner() + response = PutResponse( + success=True, + destination="s3://bucket/release", + dry_run=True, + would_upload=[PutDryRunItem(path=str(tmp_path / "model.pt"), exists=True)], + ) + + with patch("roar.cli.commands.put.put_artifacts", return_value=response): + result = runner.invoke( + put, + ["model.pt", "s3://bucket/release", "-m", "publish release", "--dry-run"], + obj=_mock_context(tmp_path), + ) + + assert result.exit_code == 0, result.output + assert "Dry run: would upload 1 file(s) to s3://bucket/release" in result.output + assert str(tmp_path / "model.pt") in result.output + + +def test_put_cli_dry_run_does_not_load_glaas_web_url(tmp_path: Path) -> None: + runner = CliRunner() + response = PutResponse( + success=True, + destination="s3://bucket/release", + dry_run=True, + would_upload=[PutDryRunItem(path=str(tmp_path / "model.pt"), exists=True)], + ) + + with ( + patch("roar.cli.commands.put.put_artifacts", 
return_value=response), + patch("roar.cli.commands.put._resolve_glaas_web_url") as resolve_web_url, + ): + result = runner.invoke( + put, + ["model.pt", "s3://bucket/release", "-m", "publish release", "--dry-run"], + obj=_mock_context(tmp_path), + ) + + assert result.exit_code == 0, result.output + resolve_web_url.assert_not_called() diff --git a/tests/unit/test_query_db_context.py b/tests/unit/test_query_db_context.py new file mode 100644 index 00000000..cd654e0b --- /dev/null +++ b/tests/unit/test_query_db_context.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path + +from roar.db.context import create_database_context +from roar.db.query_context import create_query_database_context + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def test_query_database_context_imports_without_sqlalchemy() -> None: + env = os.environ.copy() + pythonpath_entries = [str(REPO_ROOT)] + existing = env.get("PYTHONPATH") + if existing: + pythonpath_entries.append(existing) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_entries) + + proc = subprocess.run( + [ + sys.executable, + "-c", + """ +import sys +from pathlib import Path + +from roar.db.query_context import create_query_database_context + +assert "sqlalchemy" not in sys.modules +_ctx = create_query_database_context(Path(".roar")) +assert "sqlalchemy" not in sys.modules +print("ok") +""", + ], + cwd=REPO_ROOT, + env=env, + text=True, + capture_output=True, + check=False, + ) + + assert proc.returncode == 0, proc.stderr + assert proc.stdout.strip() == "ok" + + +def test_query_database_context_reads_hot_query_data(tmp_path: Path) -> None: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + + input_path = tmp_path / "input.txt" + output_path = tmp_path / "output.txt" + input_path.write_text("hello\n", encoding="utf-8") + output_path.write_text("world\n", encoding="utf-8") + + with create_database_context(roar_dir) as db_ctx: + session_id = 
db_ctx.sessions.create( + git_repo="/repo", + git_commit="abc123", + make_active=True, + ) + input_artifact_id, _ = db_ctx.artifacts.register( + {"blake3": "1" * 64}, + size=input_path.stat().st_size, + path=str(input_path), + ) + output_artifact_id, _ = db_ctx.artifacts.register( + {"blake3": "a" * 64, "sha256": "b" * 64}, + size=output_path.stat().st_size, + path=str(output_path), + metadata='{"dataset":{"dataset_id":"demo"}}', + ) + job_id, job_uid = db_ctx.jobs.create( + "python train.py", + 1700000000.0, + session_id=session_id, + step_number=1, + duration_seconds=1.5, + exit_code=0, + metadata='{"epoch":1}', + telemetry='{"backend":"local"}', + ) + db_ctx.jobs.add_input(job_id, input_artifact_id, str(input_path)) + db_ctx.jobs.add_output(job_id, output_artifact_id, str(output_path)) + db_ctx.labels.create_version("dag", {"owner": "alice"}, session_id=session_id) + db_ctx.labels.create_version("job", {"stage": "train"}, job_id=job_id) + db_ctx.labels.create_version("artifact", {"split": "train"}, artifact_id=output_artifact_id) + + with create_query_database_context(roar_dir) as query_db: + active_session = query_db.sessions.get_active() + assert active_session is not None + assert active_session["id"] == session_id + assert active_session["git_commit_start"] == "abc123" + + session_steps = query_db.sessions.get_steps(session_id) + assert [job["job_uid"] for job in session_steps] == [job_uid] + + jobs = query_db.jobs.get_by_session(session_id, limit=10) + assert [job["job_uid"] for job in jobs] == [job_uid] + + resolved_job = query_db.jobs.get_by_uid(job_uid[:6]) + assert resolved_job is not None + assert resolved_job["id"] == job_id + assert resolved_job["metadata"] == '{"epoch":1}' + assert resolved_job["telemetry"] == '{"backend":"local"}' + + inputs = query_db.jobs.get_inputs(job_id) + assert len(inputs) == 1 + assert inputs[0]["path"] == str(input_path) + assert inputs[0]["artifact_hash"] == "1" * 64 + + outputs = query_db.jobs.get_outputs(job_id) + 
assert len(outputs) == 1 + assert outputs[0]["path"] == str(output_path) + assert outputs[0]["artifact_hash"] == "a" * 64 + + distinct_outputs = query_db.jobs.get_distinct_outputs_by_session(session_id) + assert len(distinct_outputs) == 1 + assert distinct_outputs[0]["artifact_id"] == output_artifact_id + + artifact = query_db.artifacts.get_by_hash("a" * 8) + assert artifact is not None + assert artifact["id"] == output_artifact_id + assert artifact["metadata"] == '{"dataset":{"dataset_id":"demo"}}' + + by_path = query_db.artifacts.get_by_path(str(output_path)) + assert by_path is not None + assert by_path["id"] == output_artifact_id + + locations = query_db.artifacts.get_locations(output_artifact_id) + assert locations == [{"path": str(output_path)}] + + related_jobs = query_db.artifacts.get_jobs(output_artifact_id) + assert [job["job_uid"] for job in related_jobs["produced_by"]] == [job_uid] + assert related_jobs["consumed_by"] == [] + + dag_label = query_db.labels.get_current("dag", session_id=session_id) + assert dag_label is not None + assert dag_label["metadata"] == {"owner": "alice"} + + job_label = query_db.labels.get_current("job", job_id=job_id) + assert job_label is not None + assert job_label["metadata"] == {"stage": "train"} + + artifact_label = query_db.labels.get_current("artifact", artifact_id=output_artifact_id) + assert artifact_label is not None + assert artifact_label["metadata"] == {"split": "train"} + + +def test_step_lookup_prioritizes_host_over_task_phase_and_noise(tmp_path: Path) -> None: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + + with create_database_context(roar_dir) as db_ctx: + session_id = db_ctx.sessions.create(make_active=True) + _noise_id, _ = db_ctx.jobs.create( + "ray_task:shutdown", + 1.0, + session_id=session_id, + step_number=7, + execution_backend="ray", + execution_role="noise", + job_type="ray_task", + ) + _task_id, _ = db_ctx.jobs.create( + "ray_task:train", + 2.0, + session_id=session_id, + step_number=7, + 
execution_backend="ray", + execution_role="task", + job_type="ray_task", + ) + _phase_id, _ = db_ctx.jobs.create( + "ray_task:phase_train", + 3.0, + session_id=session_id, + step_number=7, + execution_backend="ray", + execution_role="phase", + job_type="ray_task", + ) + host_id, host_uid = db_ctx.jobs.create( + "python train.py", + 4.0, + session_id=session_id, + step_number=7, + execution_backend="local", + execution_role="host", + job_type=None, + ) + host_step = db_ctx.sessions.get_step_by_number(session_id, 7) + + assert host_step is not None + assert host_step["id"] == host_id + assert host_step["job_uid"] == host_uid + + with create_query_database_context(roar_dir) as query_db: + host_step_query = query_db.sessions.get_step_by_number(session_id, 7) + + assert host_step_query is not None + assert host_step_query["id"] == host_id + assert host_step_query["job_uid"] == host_uid + + +def test_step_lookup_uses_legacy_ray_command_fallback_when_role_missing(tmp_path: Path) -> None: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + + with create_database_context(roar_dir) as db_ctx: + session_id = db_ctx.sessions.create(make_active=True) + _task_id, _ = db_ctx.jobs.create( + "ray_task:process", + 1.0, + session_id=session_id, + step_number=9, + execution_backend="ray", + execution_role=None, + job_type="ray_task", + ) + _noise_id, _ = db_ctx.jobs.create( + "ray_task:shutdown", + 2.0, + session_id=session_id, + step_number=9, + execution_backend="ray", + execution_role=None, + job_type="ray_task", + ) + driver_id, driver_uid = db_ctx.jobs.create( + "python driver.py", + 3.0, + session_id=session_id, + step_number=9, + execution_backend=None, + execution_role=None, + job_type=None, + ) + driver_step = db_ctx.sessions.get_step_by_number(session_id, 9) + + assert driver_step is not None + assert driver_step["id"] == driver_id + assert driver_step["job_uid"] == driver_uid + + with create_query_database_context(roar_dir) as query_db: + driver_step_query = 
query_db.sessions.get_step_by_number(session_id, 9) + + assert driver_step_query is not None + assert driver_step_query["id"] == driver_id + assert driver_step_query["job_uid"] == driver_uid diff --git a/tests/unit/test_query_lazy_imports.py b/tests/unit/test_query_lazy_imports.py new file mode 100644 index 00000000..408b40c7 --- /dev/null +++ b/tests/unit/test_query_lazy_imports.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +def test_query_package_lazy_imports_heavy_submodules() -> None: + repo_root = Path(__file__).resolve().parents[2] + script = """ +import json +import sys + + +def loaded(): + names = [ + "roar.application.query.requests", + "roar.application.query.results", + "roar.application.query.status", + "roar.application.query.show", + "roar.application.query.log", + "roar.application.query.dag", + "roar.application.query.label", + "roar.application.query.lineage", + ] + return {name: name in sys.modules for name in names} + + +import roar.application.query as query +states = {"after_package_import": loaded()} + +from roar.application.query import StatusQueryRequest + +states["after_request_import"] = loaded() + +from roar.application.query import render_status + +states["after_render_import"] = loaded() + +print(json.dumps(states)) +""" + proc = subprocess.run( + [sys.executable, "-c", script], + cwd=repo_root, + capture_output=True, + text=True, + check=True, + ) + states = json.loads(proc.stdout) + + assert states["after_package_import"] == { + "roar.application.query.requests": False, + "roar.application.query.results": False, + "roar.application.query.status": False, + "roar.application.query.show": False, + "roar.application.query.log": False, + "roar.application.query.dag": False, + "roar.application.query.label": False, + "roar.application.query.lineage": False, + } + assert states["after_request_import"] == { + "roar.application.query.requests": True, + 
"roar.application.query.results": False, + "roar.application.query.status": False, + "roar.application.query.show": False, + "roar.application.query.log": False, + "roar.application.query.dag": False, + "roar.application.query.label": False, + "roar.application.query.lineage": False, + } + assert states["after_render_import"] == { + "roar.application.query.requests": True, + "roar.application.query.results": True, + "roar.application.query.status": True, + "roar.application.query.show": False, + "roar.application.query.log": False, + "roar.application.query.dag": False, + "roar.application.query.label": False, + "roar.application.query.lineage": False, + } diff --git a/tests/unit/test_raw_config.py b/tests/unit/test_raw_config.py new file mode 100644 index 00000000..41fab47d --- /dev/null +++ b/tests/unit/test_raw_config.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from pathlib import Path + +from roar.integrations.config import config_get +from roar.integrations.config.raw import ( + get_raw_glaas_web_url, + get_raw_registration_omit_config, +) + + +def test_raw_registration_omit_matches_full_config_defaults(tmp_path: Path) -> None: + assert get_raw_registration_omit_config(start_dir=str(tmp_path)) == config_get( + "registration.omit", + start_dir=str(tmp_path), + ) + + +def test_raw_registration_omit_matches_file_and_env_overrides( + tmp_path: Path, + monkeypatch, +) -> None: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + (roar_dir / "config.toml").write_text( + """ +[registration.omit] +enabled = false + +[registration.omit.secrets] +values = ["literal-secret"] + +[registration.omit.allowlist] +patterns = ["safe-.*"] +""".strip() + ) + monkeypatch.setenv( + "ROAR_REGISTRATION__OMIT__ENV_VARS__NAMES", + '["LOCAL_API_TOKEN", "SECOND_TOKEN"]', + ) + + assert get_raw_registration_omit_config(start_dir=str(tmp_path)) == config_get( + "registration.omit", + start_dir=str(tmp_path), + ) + + +def test_raw_glaas_web_url_prefers_env_over_file(tmp_path: 
Path, monkeypatch) -> None: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + (roar_dir / "config.toml").write_text( + """ +[glaas] +web_url = "https://glaas.example/app/" +""".strip() + ) + + assert get_raw_glaas_web_url(start_dir=str(tmp_path)) == "https://glaas.example/app" + + monkeypatch.setenv("ROAR_GLAAS__WEB_URL", "https://override.example/ui/") + assert get_raw_glaas_web_url(start_dir=str(tmp_path)) == "https://override.example/ui" diff --git a/tests/unit/test_register_cli.py b/tests/unit/test_register_cli.py index c4f36454..8f033994 100644 --- a/tests/unit/test_register_cli.py +++ b/tests/unit/test_register_cli.py @@ -1,5 +1,8 @@ +"""Unit tests for the register CLI output surface.""" + from __future__ import annotations +from pathlib import Path from unittest.mock import MagicMock, patch from click.testing import CliRunner @@ -8,7 +11,17 @@ from roar.cli.commands.register import register -def _fake_result(): +def _mock_context(tmp_path: Path) -> MagicMock: + roar_dir = tmp_path / ".roar" + roar_dir.mkdir() + ctx = MagicMock() + ctx.roar_dir = roar_dir + ctx.cwd = tmp_path + ctx.is_initialized = True + return ctx + + +def _fake_result() -> RegisterLineageResponse: return RegisterLineageResponse( success=True, aborted_by_user=False, @@ -23,25 +36,74 @@ def _fake_result(): ) -def test_register_cli_accepts_s3_uri(tmp_path): +def test_register_cli_accepts_s3_uri(tmp_path: Path) -> None: runner = CliRunner() - ctx = MagicMock() - ctx.roar_dir = tmp_path / ".roar" - ctx.roar_dir.mkdir() - ctx.cwd = tmp_path - ctx.is_initialized = True - - with patch("roar.cli.commands.register.register_lineage_target") as mock_register: + with ( + patch("roar.cli.commands.register.register_lineage_target") as mock_register, + patch( + "roar.cli.commands.register._resolve_glaas_web_url", + return_value="https://glaas.local", + ), + ): mock_register.return_value = _fake_result() - with patch("roar.cli.commands.register.config_get", return_value="https://glaas.local"): - result 
= runner.invoke( - register, - ["s3://output-bucket/results/run123/final_report.json", "--yes"], - obj=ctx, - ) + result = runner.invoke( + register, + ["s3://output-bucket/results/run123/final_report.json", "--yes"], + obj=_mock_context(tmp_path), + ) assert result.exit_code == 0, result.output mock_register.assert_called_once() request = mock_register.call_args.args[0] assert request.target == "s3://output-bucket/results/run123/final_report.json" + + +def test_register_cli_prints_next_steps_for_artifacts(tmp_path: Path) -> None: + runner = CliRunner() + artifact_hash = "abcdef0123456789abcdef0123456789" + response = RegisterLineageResponse( + success=True, + session_hash="0123456789abcdef0123456789abcdef", + artifact_hash=artifact_hash, + jobs_registered=3, + artifacts_registered=4, + links_created=5, + ) + + with ( + patch("roar.cli.commands.register.register_lineage_target", return_value=response), + patch( + "roar.cli.commands.register._resolve_glaas_web_url", + return_value="https://glaas.example", + ), + ): + result = runner.invoke(register, ["model.pt"], obj=_mock_context(tmp_path)) + + assert result.exit_code == 0, result.output + assert "Registered lineage for: model.pt" in result.output + assert "Session: 0123456789ab..." 
in result.output + assert "GLaaS:" in result.output + assert "https://glaas.example/dag/0123456789abcdef0123456789abcdef" in result.output + assert "https://glaas.example/artifact/abcdef0123456789abcdef0123456789" in result.output + assert "Next:" in result.output + assert f"roar show --artifact {artifact_hash}" in result.output + assert f"roar reproduce {artifact_hash}" in result.output + + +def test_register_cli_dry_run_mentions_target(tmp_path: Path) -> None: + runner = CliRunner() + response = RegisterLineageResponse( + success=True, + session_hash="0123456789abcdef0123456789abcdef", + jobs_registered=2, + artifacts_registered=3, + links_created=4, + ) + + with patch("roar.cli.commands.register.register_lineage_target", return_value=response): + result = runner.invoke(register, ["model.pt", "--dry-run"], obj=_mock_context(tmp_path)) + + assert result.exit_code == 0, result.output + assert "Dry run: would register lineage for: model.pt" in result.output + assert "Session: 0123456789ab..." 
in result.output diff --git a/tests/unit/test_register_execution_lazy_imports.py b/tests/unit/test_register_execution_lazy_imports.py new file mode 100644 index 00000000..23db57aa --- /dev/null +++ b/tests/unit/test_register_execution_lazy_imports.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +def test_register_execution_import_stays_lightweight_until_real_registration() -> None: + repo_root = Path(__file__).resolve().parents[2] + script = """ +import json +import sys + + +def loaded(): + names = [ + "roar.application.publish.blake3_upgrade", + "roar.application.publish.composite_builder", + "roar.application.publish.lineage_composites", + "roar.application.publish.registration", + "roar.db.context", + "roar.db.engine", + "roar.db.models", + "sqlalchemy", + ] + return {name: name in sys.modules for name in names} + + +import roar.application.publish.register_execution as register_execution +states = {"after_module_import": loaded()} + +service = register_execution.RegisterService() +states["after_service_init"] = loaded() + +print(json.dumps(states)) +""" + proc = subprocess.run( + [sys.executable, "-c", script], + cwd=repo_root, + capture_output=True, + text=True, + check=True, + ) + states = json.loads(proc.stdout) + + expected = { + "roar.application.publish.blake3_upgrade": False, + "roar.application.publish.composite_builder": False, + "roar.application.publish.lineage_composites": False, + "roar.application.publish.registration": False, + "roar.db.context": False, + "roar.db.engine": False, + "roar.db.models": False, + "sqlalchemy": False, + } + assert states["after_module_import"] == expected + assert states["after_service_init"] == expected diff --git a/tests/unit/test_register_preview_jobs.py b/tests/unit/test_register_preview_jobs.py new file mode 100644 index 00000000..d2c9b306 --- /dev/null +++ b/tests/unit/test_register_preview_jobs.py @@ -0,0 +1,86 @@ +from 
__future__ import annotations + +from roar.application.publish.job_preparation import ( + estimate_links as estimate_links_full, +) +from roar.application.publish.job_preparation import ( + normalize_jobs_for_registration as normalize_jobs_full, +) +from roar.application.publish.job_preparation import ( + order_jobs_for_registration as order_jobs_full, +) +from roar.application.publish.register_preview_jobs import ( + estimate_links as estimate_links_preview, +) +from roar.application.publish.register_preview_jobs import ( + normalize_jobs_for_registration as normalize_jobs_preview, +) +from roar.application.publish.register_preview_jobs import ( + order_jobs_for_registration as order_jobs_preview, +) + + +def test_register_preview_jobs_match_full_helpers_for_ray_submit_lineage() -> None: + jobs = [ + { + "id": 1, + "job_uid": "local-submit", + "step_number": 1, + "timestamp": 10.0, + "command": "ray job submit --address http://localhost:8265 -- python main.py", + "job_type": None, + "_outputs": [{"artifact_id": "driver-output"}], + }, + { + "id": 2, + "job_uid": "noise-job", + "parent_job_uid": "local-submit", + "step_number": 2, + "timestamp": 20.0, + "command": "ray_task:unknown", + "job_type": "ray_task", + "_outputs": [{"artifact_id": "noise-output"}], + }, + { + "id": 3, + "job_uid": "phase-job", + "parent_job_uid": "noise-job", + "step_number": 3, + "timestamp": 30.0, + "command": "ray_task:process_shard", + "job_type": "ray_task", + "_inputs": [{"artifact_id": "driver-output"}], + "_outputs": [{"artifact_id": "task-output"}], + }, + ] + + normalized_full = normalize_jobs_full(jobs) + normalized_preview = normalize_jobs_preview(jobs) + + assert normalized_preview == normalized_full + assert order_jobs_preview(normalized_preview) == order_jobs_full(normalized_full) + assert estimate_links_preview(normalized_preview) == estimate_links_full(normalized_full) + + +def test_register_preview_jobs_match_full_helpers_for_legacy_driver_lineage() -> None: + jobs = [ + 
{ + "id": 1, + "job_uid": "driver", + "step_number": 1, + "timestamp": 10.0, + "command": "python train.py", + "job_type": None, + }, + { + "id": 2, + "job_uid": "worker", + "parent_job_uid": "missing-driver", + "step_number": 2, + "timestamp": 20.0, + "command": "ray_task:process_batch", + "job_type": "ray_task", + }, + ] + + assert normalize_jobs_preview(jobs) == normalize_jobs_full(jobs) diff --git a/tests/unit/test_register_preview_lazy_imports.py b/tests/unit/test_register_preview_lazy_imports.py new file mode 100644 index 00000000..904a64d2 --- /dev/null +++ b/tests/unit/test_register_preview_lazy_imports.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +def test_register_dry_run_avoids_registration_and_full_config_imports() -> None: + repo_root = Path(__file__).resolve().parents[2] + script = """ +import json +import sys +from pathlib import Path +from types import SimpleNamespace + +import roar.application.publish.service as service +from roar.application.publish.requests import RegisterLineageRequest +from roar.application.publish.targets import ResolvedRegisterTarget +from roar.core.interfaces.lineage import LineageData + + +def loaded(): + names = [ + "roar.application.publish.register_execution", + "roar.application.publish.register_preparation", + "roar.application.publish.runtime", + "roar.integrations.config.access", + "roar.integrations.config.loader", + "roar.integrations.config.schema", + ] + return {name: name in sys.modules for name in names} + + +states = {"after_service_import": loaded()} + +runtime = SimpleNamespace( + glaas_client=object(), + session_service=object(), + lineage_collector=object(), +) +collected = SimpleNamespace( + lineage=LineageData(jobs=[], artifacts=[], artifact_hashes=set(), pipeline={"id": 7}), + session_id=7, + artifact_hash="a" * 64, + session_hash_override=None, +) +prepared = SimpleNamespace( + 
git_context=SimpleNamespace(repo="https://github.com/test/repo", commit="abc123", branch="main"), + session_id=7, + session_hash="a" * 64, + session_url=None, + git_tag_name=None, + git_tag_repo_root=None, +) + +service.build_register_preview_runtime = lambda: runtime +service.resolve_register_lineage_target = ( + lambda target, cwd, roar_dir: ResolvedRegisterTarget(kind="artifact_path", value="model.pt") +) +service.collect_register_lineage = lambda **kwargs: (collected, None) +service.prepare_register_preview_execution = lambda **kwargs: prepared + +response = service.register_lineage_target( + RegisterLineageRequest( + target="model.pt", + roar_dir=Path.cwd() / ".roar", + cwd=Path.cwd(), + dry_run=True, + skip_confirmation=True, + ) +) +states["after_dry_run"] = loaded() + +print(json.dumps({"success": response.success, "states": states})) +""" + proc = subprocess.run( + [sys.executable, "-c", script], + cwd=repo_root, + capture_output=True, + text=True, + check=True, + ) + result = json.loads(proc.stdout) + + assert result["success"] is True + assert result["states"]["after_service_import"] == { + "roar.application.publish.register_execution": False, + "roar.application.publish.register_preparation": False, + "roar.application.publish.runtime": False, + "roar.integrations.config.access": False, + "roar.integrations.config.loader": False, + "roar.integrations.config.schema": False, + } + assert result["states"]["after_dry_run"] == result["states"]["after_service_import"] diff --git a/tests/unit/test_run_report.py b/tests/unit/test_run_report.py new file mode 100644 index 00000000..2537b91a --- /dev/null +++ b/tests/unit/test_run_report.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Any + +from roar.core.models.run import RunResult +from roar.presenters.run_report import RunReportPresenter + + +class _CapturePresenter: + def __init__(self) -> None: + self.messages: list[str] = [] + + def print(self, message: str) -> None: + 
self.messages.append(message) + + def print_error(self, message: str) -> None: + self.messages.append(message) + + def print_table(self, headers: list[str], rows: list[list[str]]) -> None: + return None + + def print_job(self, job: dict[str, Any], verbose: bool = False) -> None: + return None + + def print_artifact(self, artifact: dict[str, Any]) -> None: + return None + + def print_dag( + self, + summary: dict[str, Any], + stale_steps: set[int] | None = None, + ) -> None: + return None + + def confirm(self, message: str, default: bool = False) -> bool: + return default + + +def test_interrupted_report_references_pop_not_clean() -> None: + presenter = _CapturePresenter() + report = RunReportPresenter(presenter) + + report.show_report( + RunResult( + exit_code=130, + job_id=1, + job_uid="job12345", + duration=0.5, + inputs=[], + outputs=[{"path": "/tmp/out.txt", "size": 1, "hashes": []}], + interrupted=True, + is_build=False, + ), + ["python", "train.py"], + ) + + rendered = "\n".join(presenter.messages) + assert "roar pop" in rendered + assert "roar clean" not in rendered + assert "roar show --job job12345" in rendered + + +def test_successful_report_suggests_show_and_dag() -> None: + presenter = _CapturePresenter() + report = RunReportPresenter(presenter) + + report.show_report( + RunResult( + exit_code=0, + job_id=2, + job_uid="job67890", + duration=1.0, + inputs=[], + outputs=[], + interrupted=False, + is_build=False, + ), + ["python", "train.py"], + ) + + rendered = "\n".join(presenter.messages) + assert "Next:" in rendered + assert "roar show --job job67890" in rendered + assert "roar dag" in rendered diff --git a/tests/unit/test_tracer_data_loader.py b/tests/unit/test_tracer_data_loader.py index cf955c1e..90812b44 100644 --- a/tests/unit/test_tracer_data_loader.py +++ b/tests/unit/test_tracer_data_loader.py @@ -12,6 +12,10 @@ def _write_msgpack(path: Path, payload: dict) -> None: path.write_bytes(msgpack.packb(payload, use_bin_type=True)) +def 
_write_json(path: Path, payload: dict) -> None: + path.write_text(json.dumps(payload)) + + class TestDataLoaderService: def test_loads_json_report_when_tracer_writes_json(self, tmp_path: Path) -> None: report = { @@ -25,7 +29,7 @@ def test_loads_json_report_when_tracer_writes_json(self, tmp_path: Path) -> None "end_time": 2.5, } report_path = tmp_path / "trace.msgpack" - report_path.write_text(json.dumps(report)) + _write_json(report_path, report) data = DataLoaderService().load_tracer_data(str(report_path))