Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion roar/application/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from typing import Any, Protocol

from ..db.context import DatabaseContext
from ..execution.recording.dataset_metadata import AUTO_DATASET_LABEL_KEYS

RESERVED_LABEL_KEYS = {"dataset.type", "dataset.modality"}
RESERVED_LABEL_KEYS = set(AUTO_DATASET_LABEL_KEYS)


@dataclass(frozen=True)
Expand Down
58 changes: 57 additions & 1 deletion roar/execution/recording/dataset_metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
"""Helpers for attaching dataset identity labels to composite artifact metadata."""
"""Helpers for attaching dataset identity metadata and labels to composite artifacts."""

from __future__ import annotations

from typing import Any
from urllib.parse import urlparse

from .dataset_profile import build_dataset_profile

# Dotted label paths under the "dataset." namespace that the recorder writes
# automatically for detected dataset artifacts.  These paths are system-owned:
# they are stripped and re-written on every label sync, and (per the label
# service's reserved-key check) must not be set manually.
AUTO_DATASET_LABEL_KEYS = frozenset(
    (
        "dataset.type",
        "dataset.id",
        "dataset.fingerprint",
        "dataset.fingerprint_algorithm",
        "dataset.split",
        "dataset.version_hint",
        "dataset.modality",
    )
)


def find_matching_identifier(
root_path: str, dataset_identifiers: list[dict[str, Any]]
Expand Down Expand Up @@ -50,3 +64,45 @@ def build_dataset_metadata(identifier: dict[str, Any]) -> dict[str, Any]:
if value is not None:
meta[key] = value
return meta


def build_dataset_label_metadata(
    identifier: dict[str, Any],
    *,
    components: list[dict[str, Any]] | None = None,
    component_count_total: int | None = None,
) -> dict[str, Any]:
    """Build the system-managed label document for a detected dataset artifact.

    The label payload is a deliberately small, stable subset of the full
    dataset metadata blob: it carries the artifact's dataset identity plus
    the most queryable derived characteristic (currently the modality hint),
    for local labels and future sync.

    Args:
        identifier: Detected dataset identifier mapping.  Recognized keys are
            ``dataset_id``, ``dataset_fingerprint``,
            ``dataset_fingerprint_algorithm``, ``split`` and ``version_hint``;
            absent or ``None`` values are simply omitted from the document.
        components: Component descriptors fed to the dataset profiler.
        component_count_total: Total component count, when known.

    Returns:
        A ``{"dataset": {...}}`` document suitable for the labels store.
    """
    # (identifier key, label key) pairs, in the order the label document
    # should present them.
    field_map = (
        ("dataset_id", "id"),
        ("dataset_fingerprint", "fingerprint"),
        ("dataset_fingerprint_algorithm", "fingerprint_algorithm"),
        ("split", "split"),
        ("version_hint", "version_hint"),
    )

    dataset: dict[str, Any] = {"type": "dataset"}
    for source_key, label_key in field_map:
        candidate = identifier.get(source_key)
        if candidate is not None:
            dataset[label_key] = candidate

    # Derive the modality hint from the component profile; only attach it
    # when the profiler produced a non-empty string.
    profile = build_dataset_profile(components or [], total_components=component_count_total)
    modality = profile.get("modality_hint") if isinstance(profile, dict) else None
    if isinstance(modality, str) and modality:
        dataset["modality"] = modality

    return {"dataset": dataset}
87 changes: 86 additions & 1 deletion roar/execution/recording/job_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
from ...db.context import optional_repo
from ...db.hashing import hash_files_blake3
from .dataset_identifier import DatasetIdentifierInferer
from .dataset_metadata import build_dataset_metadata, find_matching_identifier
from .dataset_metadata import (
AUTO_DATASET_LABEL_KEYS,
build_dataset_label_metadata,
build_dataset_metadata,
find_matching_identifier,
)

if TYPE_CHECKING:
from ...core.models.run import RunContext
Expand Down Expand Up @@ -349,8 +354,14 @@ def materialize(
}
}
matching = find_matching_identifier(str(root), dataset_identifiers)
dataset_label_metadata: dict[str, Any] = {}
if matching is not None:
meta_dict["dataset"] = build_dataset_metadata(matching)
dataset_label_metadata = build_dataset_label_metadata(
matching,
components=list(composite.payload.get("components") or []),
component_count_total=composite.component_count_total,
)
metadata = json.dumps(meta_dict)
artifact_id, _created = db_ctx.artifacts.register(
hashes={"composite-blake3": composite.digest},
Expand All @@ -359,6 +370,12 @@ def materialize(
source_type=composite.payload.get("source_type"),
metadata=metadata,
)
if dataset_label_metadata:
self._sync_dataset_labels(
db_ctx,
artifact_id=artifact_id,
dataset_label_metadata=dataset_label_metadata,
)
composite_repo.upsert_details(
artifact_id=artifact_id,
components=list(composite.payload.get("components") or []),
Expand Down Expand Up @@ -479,6 +496,74 @@ def _is_path_under_root(path: Path, root: Path) -> bool:
except ValueError:
return False

@staticmethod
def _sync_dataset_labels(
    db_ctx: Any,
    *,
    artifact_id: str,
    dataset_label_metadata: dict[str, Any],
) -> None:
    """Write the system-managed dataset labels for *artifact_id*.

    Merges the auto-generated dataset label document into the artifact's
    current label metadata and records a new label version only when the
    merge actually changes the document.  A no-op when the labels repo is
    unavailable or there is nothing to write.
    """
    labels_repo = cast(Any, optional_repo(db_ctx, "labels"))
    if labels_repo is None:
        return
    if not dataset_label_metadata:
        return

    # Start from the artifact's current label document, falling back to an
    # empty one when no usable metadata dict exists yet.
    snapshot = labels_repo.get_current("artifact", artifact_id=artifact_id)
    baseline: dict[str, Any] = {}
    if isinstance(snapshot, dict):
        candidate = snapshot.get("metadata")
        if isinstance(candidate, dict):
            baseline = candidate

    merged = CompositeOutputMaterializer._merge_dataset_labels(
        baseline,
        dataset_label_metadata,
    )
    # Avoid churning the label history with byte-identical versions.
    if merged != baseline:
        labels_repo.create_version("artifact", merged, artifact_id=artifact_id)

@staticmethod
def _merge_dataset_labels(current: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
    """Return *current* with every system-managed dataset label path replaced by *patch*.

    Reserved ``dataset.*`` paths are scrubbed first so stale auto-labels never
    survive a sync; the patch is then deep-merged on top.
    """
    cls = CompositeOutputMaterializer
    scrubbed = cls._remove_label_paths(current, AUTO_DATASET_LABEL_KEYS)
    return cls._deep_merge(scrubbed, patch)

@staticmethod
def _deep_merge(current: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
    """Recursively overlay *patch* onto a deep copy of *current*.

    Dict values present on both sides are merged key-by-key; any other value
    in *patch* replaces the corresponding value wholesale.  *current* itself
    is never mutated.  The JSON round-trip used for copying assumes the
    metadata document is JSON-serializable.
    """

    def overlay(base: dict[str, Any], extra: dict[str, Any]) -> None:
        # Mutates *base* in place; *base* is already a private copy.
        for key, incoming in extra.items():
            present = base.get(key)
            if isinstance(present, dict) and isinstance(incoming, dict):
                overlay(present, incoming)
            else:
                base[key] = incoming

    merged = json.loads(json.dumps(current))
    overlay(merged, patch)
    return merged

@staticmethod
def _remove_label_paths(metadata: dict[str, Any], reserved_paths: set[str] | frozenset[str]) -> dict[str, Any]:
    """Return a deep copy of *metadata* with every dotted *reserved_paths* entry deleted."""
    # JSON round-trip yields an independent copy, so the caller's document
    # is never mutated in place.
    scrubbed = json.loads(json.dumps(metadata))
    for dotted_path in reserved_paths:
        segments = dotted_path.split(".")
        CompositeOutputMaterializer._remove_nested(scrubbed, segments)
    return scrubbed

@staticmethod
def _remove_nested(root: dict[str, Any], path: list[str]) -> None:
    """Delete the value addressed by dotted-*path* segments from *root*, in place.

    Traversal stops silently when a segment is missing or a non-dict value is
    reached.  Dicts along the traversed path are pruned on unwind whenever
    they are empty — note this prunes pre-existing empty branches even when
    nothing was deleted below them.
    """
    if not path:
        return
    head, rest = path[0], path[1:]
    if head not in root:
        return
    if not rest:
        # Final segment: remove the leaf value itself.
        del root[head]
        return
    branch = root[head]
    if isinstance(branch, dict):
        CompositeOutputMaterializer._remove_nested(branch, rest)
        # Prune the branch if it is (or has become) empty.
        if not branch:
            del root[head]


class ExecutionJobRecorder:
"""Persist a traced execution and return reporting payload pieces."""
Expand Down
51 changes: 51 additions & 0 deletions tests/happy_path/test_label_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,57 @@ def _artifact_label_rows_by_hash(
class TestLabelCommand:
"""CLI product-path tests for label lifecycle behavior."""

def test_detected_dataset_composite_artifact_gets_auto_labels(
    self,
    temp_git_repo,
    roar_cli,
    git_commit,
    python_exe,
):
    """End-to-end: a run emitting a dataset directory gets auto dataset labels.

    Runs a script that writes CSV parts under ``--output-dir``, then checks
    that the resulting composite artifact carries system-managed ``dataset.*``
    labels in both CLI output and the stored label rows.
    """
    # Script under test: writes a two-part CSV "dataset" under --output-dir.
    dataset_script = temp_git_repo / "emit_dataset.py"
    dataset_script.write_text(
        "\n".join(
            [
                "from __future__ import annotations",
                "",
                "from pathlib import Path",
                "import argparse",
                "",
                "parser = argparse.ArgumentParser()",
                "parser.add_argument('--output-dir', required=True)",
                "args = parser.parse_args()",
                "",
                "root = Path(args.output_dir)",
                "(root / 'train').mkdir(parents=True, exist_ok=True)",
                "(root / 'train' / 'part-00000.csv').write_text('value\\n1\\n', encoding='utf-8')",
                "(root / 'train' / 'part-00001.csv').write_text('value\\n2\\n', encoding='utf-8')",
            ]
        ),
        encoding="utf-8",
    )
    git_commit("Add dataset emitter")

    # Record an execution that materializes the dataset directory.
    result = roar_cli("run", python_exe, "emit_dataset.py", "--output-dir", "dataset")
    assert result.returncode == 0

    dataset_root = temp_git_repo / "dataset"

    # `label show` must surface the auto-generated dataset identity labels.
    label_show = _assert_ok(roar_cli("label", "show", "artifact", "dataset", check=False))
    assert f"dataset.id={dataset_root.resolve().as_uri()}" in label_show
    assert "dataset.modality=tabular" in label_show
    assert "dataset.type=dataset" in label_show

    # `show` must include the same labels in its Labels section.
    show_output = _assert_ok(roar_cli("show", "dataset", check=False))
    assert "Labels:" in show_output
    assert f"dataset.id={dataset_root.resolve().as_uri()}" in show_output
    assert "dataset.modality=tabular" in show_output
    assert "dataset.type=dataset" in show_output

    # The persisted label document mirrors what the CLI displayed.
    rows = _artifact_label_rows(temp_git_repo, dataset_root)
    assert rows[0][1]["dataset"]["type"] == "dataset"
    assert rows[0][1]["dataset"]["id"] == dataset_root.resolve().as_uri()
    assert rows[0][1]["dataset"]["modality"] == "tabular"

def test_artifact_label_set_patches_current_document_and_preserves_history(
self,
temp_git_repo,
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_dataset_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ def test_record_materializes_local_composite_outputs(tmp_path: Path):
assert composite_artifact is not None
assert composite_artifact["kind"] == "composite"

current_labels = db_ctx.labels.get_current(
"artifact",
artifact_id=composite_artifact["id"],
)
assert current_labels is not None
assert current_labels["metadata"]["dataset"]["type"] == "dataset"
assert current_labels["metadata"]["dataset"]["id"] == dataset_dir.resolve().as_uri()
assert current_labels["metadata"]["dataset"]["modality"] == "tabular"
assert current_labels["metadata"]["dataset"]["fingerprint"]

summary = db_ctx.composites.get(composite_artifact["id"])
assert summary is not None
assert summary["component_count"] == 2
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/test_dataset_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from __future__ import annotations

from roar.execution.recording.dataset_metadata import (
AUTO_DATASET_LABEL_KEYS,
build_dataset_label_metadata,
build_dataset_metadata,
find_matching_identifier,
)
Expand Down Expand Up @@ -126,3 +128,54 @@ def test_ignores_unknown_keys(self):
result = build_dataset_metadata(identifier)
assert "extra_key" not in result
assert result["dataset_id"] == "file:///data/raw"


# ---------------------------------------------------------------------------
# build_dataset_label_metadata
# ---------------------------------------------------------------------------


class TestBuildDatasetLabelMetadata:
    """Unit tests for the auto-generated dataset label document."""

    def test_includes_dataset_identity_and_modality_labels(self):
        # A full identity plus one image component: every identity field and
        # an image modality hint should land under the "dataset" key.
        detected = {
            "dataset_id": "file:///data/imagenet",
            "dataset_fingerprint": "a1b2c3d4e5f67890",
            "dataset_fingerprint_algorithm": "blake3",
            "split": "train",
            "version_hint": "v2",
        }
        component = {
            "relative_path": "train/class_a/image-0001.jpg",
            "component_size": 42,
            "component_type": "image/jpeg",
        }

        document = build_dataset_label_metadata(
            detected,
            components=[component],
            component_count_total=1,
        )

        expected = {
            "dataset": {
                "type": "dataset",
                "id": "file:///data/imagenet",
                "fingerprint": "a1b2c3d4e5f67890",
                "fingerprint_algorithm": "blake3",
                "split": "train",
                "version_hint": "v2",
                "modality": "image",
            }
        }
        assert document == expected

    def test_declares_reserved_paths_for_system_managed_dataset_labels(self):
        # Keep the reserved-key constant in lockstep with the label document
        # fields above.
        expected_keys = {
            "dataset.type",
            "dataset.id",
            "dataset.fingerprint",
            "dataset.fingerprint_algorithm",
            "dataset.split",
            "dataset.version_hint",
            "dataset.modality",
        }
        assert AUTO_DATASET_LABEL_KEYS == expected_keys
17 changes: 17 additions & 0 deletions tests/unit/test_label_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

import pytest

from roar.application.labels import LabelService


def test_reject_reserved_keys_blocks_system_managed_dataset_labels() -> None:
    """Manually patching system-managed ``dataset.*`` labels must be rejected."""
    # Both keys sit on reserved auto-label paths.
    reserved_patch = {
        "dataset": {
            "id": "file:///data/train",
            "modality": "tabular",
        }
    }
    with pytest.raises(ValueError, match="Reserved label keys cannot be set manually"):
        LabelService._reject_reserved_keys(reserved_patch)