diff --git a/roar/application/labels.py b/roar/application/labels.py
index 093e52bd..1bec06d3 100644
--- a/roar/application/labels.py
+++ b/roar/application/labels.py
@@ -14,8 +14,9 @@
 from typing import Any, Protocol
 
 from ..db.context import DatabaseContext
+from ..execution.recording.dataset_metadata import AUTO_DATASET_LABEL_KEYS
 
-RESERVED_LABEL_KEYS = {"dataset.type", "dataset.modality"}
+RESERVED_LABEL_KEYS = set(AUTO_DATASET_LABEL_KEYS)
 
 
 @dataclass(frozen=True)
diff --git a/roar/execution/recording/dataset_metadata.py b/roar/execution/recording/dataset_metadata.py
index 35e35848..8fbca5f8 100644
--- a/roar/execution/recording/dataset_metadata.py
+++ b/roar/execution/recording/dataset_metadata.py
@@ -1,10 +1,24 @@
-"""Helpers for attaching dataset identity labels to composite artifact metadata."""
+"""Helpers for attaching dataset identity metadata and labels to composite artifacts."""
 
 from __future__ import annotations
 
 from typing import Any
 from urllib.parse import urlparse
 
+from .dataset_profile import build_dataset_profile
+
+AUTO_DATASET_LABEL_KEYS = frozenset(
+    {
+        "dataset.type",
+        "dataset.id",
+        "dataset.fingerprint",
+        "dataset.fingerprint_algorithm",
+        "dataset.split",
+        "dataset.version_hint",
+        "dataset.modality",
+    }
+)
+
 
 def find_matching_identifier(
     root_path: str, dataset_identifiers: list[dict[str, Any]]
@@ -50,3 +64,41 @@ def build_dataset_metadata(identifier: dict[str, Any]) -> dict[str, Any]:
         if value is not None:
             meta[key] = value
     return meta
+
+
+def build_dataset_label_metadata(
+    identifier: dict[str, Any],
+    *,
+    components: list[dict[str, Any]] | None = None,
+    component_count_total: int | None = None,
+) -> dict[str, Any]:
+    """Build the system-managed label document for a detected dataset artifact.
+
+    The label payload is intentionally smaller and more stable than the full
+    dataset metadata blob. It captures the artifact's dataset identity and the
+    most queryable derived characteristics for local labels and future sync.
+
+    Identifier fields that are missing or ``None`` are left out. ``components``
+    and ``component_count_total`` are only used to derive the modality hint via
+    ``build_dataset_profile``; the returned document is ``{"dataset": {...}}``.
+    """
+    dataset: dict[str, Any] = {"type": "dataset"}
+
+    # (label field, identifier key) pairs; insertion order fixes the label key order.
+    for label_field, identifier_key in (
+        ("id", "dataset_id"),
+        ("fingerprint", "dataset_fingerprint"),
+        ("fingerprint_algorithm", "dataset_fingerprint_algorithm"),
+        ("split", "split"),
+        ("version_hint", "version_hint"),
+    ):
+        value = identifier.get(identifier_key)
+        if value is not None:
+            dataset[label_field] = value
+
+    profile = build_dataset_profile(components or [], total_components=component_count_total)
+    modality = profile.get("modality_hint") if isinstance(profile, dict) else None
+    if isinstance(modality, str) and modality:
+        dataset["modality"] = modality
+
+    return {"dataset": dataset}
diff --git a/roar/execution/recording/job_recording.py b/roar/execution/recording/job_recording.py
index 336a6ce3..c832d53c 100644
--- a/roar/execution/recording/job_recording.py
+++ b/roar/execution/recording/job_recording.py
@@ -27,7 +27,12 @@
 from ...db.context import optional_repo
 from ...db.hashing import hash_files_blake3
 from .dataset_identifier import DatasetIdentifierInferer
-from .dataset_metadata import build_dataset_metadata, find_matching_identifier
+from .dataset_metadata import (
+    AUTO_DATASET_LABEL_KEYS,
+    build_dataset_label_metadata,
+    build_dataset_metadata,
+    find_matching_identifier,
+)
 
 if TYPE_CHECKING:
     from ...core.models.run import RunContext
@@ -349,8 +354,14 @@ def materialize(
                 }
             }
             matching = find_matching_identifier(str(root), dataset_identifiers)
+            dataset_label_metadata: dict[str, Any] = {}
             if matching is not None:
                 meta_dict["dataset"] = build_dataset_metadata(matching)
+                dataset_label_metadata = build_dataset_label_metadata(
+                    matching,
+                    components=list(composite.payload.get("components") or []),
+                    component_count_total=composite.component_count_total,
+                )
             metadata = json.dumps(meta_dict)
             artifact_id, _created = db_ctx.artifacts.register(
                 hashes={"composite-blake3": composite.digest},
@@ -359,6 +370,12 @@ def materialize(
                 source_type=composite.payload.get("source_type"),
                 metadata=metadata,
             )
+            if dataset_label_metadata:
+                self._sync_dataset_labels(
+                    db_ctx,
+                    artifact_id=artifact_id,
+                    dataset_label_metadata=dataset_label_metadata,
+                )
             composite_repo.upsert_details(
                 artifact_id=artifact_id,
                 components=list(composite.payload.get("components") or []),
@@ -479,6 +496,98 @@ def _is_path_under_root(path: Path, root: Path) -> bool:
         except ValueError:
             return False
 
+    @staticmethod
+    def _sync_dataset_labels(
+        db_ctx: Any,
+        *,
+        artifact_id: str,
+        dataset_label_metadata: dict[str, Any],
+    ) -> None:
+        """Store the system-managed dataset labels as the artifact's current labels.
+
+        Returns without writing when ``optional_repo`` yields no labels repository,
+        when there is nothing to sync, or when the merged document equals the
+        current one, so an unchanged re-run creates no new label version.
+        """
+        labels_repo = cast(Any, optional_repo(db_ctx, "labels"))
+        if labels_repo is None or not dataset_label_metadata:
+            return
+
+        current = labels_repo.get_current("artifact", artifact_id=artifact_id)
+        current_metadata = current.get("metadata") if isinstance(current, dict) else {}
+        if not isinstance(current_metadata, dict):
+            current_metadata = {}
+
+        merged = CompositeOutputMaterializer._merge_dataset_labels(
+            current_metadata,
+            dataset_label_metadata,
+        )
+        if merged == current_metadata:
+            return
+
+        labels_repo.create_version("artifact", merged, artifact_id=artifact_id)
+
+    @staticmethod
+    def _merge_dataset_labels(current: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
+        """Replace the system-managed dataset labels in ``current`` with ``patch``.
+
+        Every path in ``AUTO_DATASET_LABEL_KEYS`` is removed from a copy of ``current``
+        before ``patch`` is merged in, so a reserved key absent from ``patch`` is dropped.
+        """
+        merged = CompositeOutputMaterializer._remove_label_paths(
+            current,
+            AUTO_DATASET_LABEL_KEYS,
+        )
+        return CompositeOutputMaterializer._deep_merge(merged, patch)
+
+    @staticmethod
+    def _deep_merge(current: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
+        """Return a copy of ``current`` with ``patch`` merged in, dict by dict.
+
+        Both inputs are copied once through a JSON round trip, so the result never
+        shares nested objects with either argument.
+        """
+
+        def merge_into(target: dict[str, Any], source: dict[str, Any]) -> None:
+            for key, value in source.items():
+                existing = target.get(key)
+                if isinstance(existing, dict) and isinstance(value, dict):
+                    merge_into(existing, value)
+                else:
+                    target[key] = value
+
+        merged = json.loads(json.dumps(current))
+        merge_into(merged, json.loads(json.dumps(patch)))
+        return merged
+
+    @staticmethod
+    def _remove_label_paths(
+        metadata: dict[str, Any], reserved_paths: set[str] | frozenset[str]
+    ) -> dict[str, Any]:
+        """Return a copy of ``metadata`` without the given dotted label paths."""
+        cleaned = json.loads(json.dumps(metadata))
+        for path in reserved_paths:
+            CompositeOutputMaterializer._remove_nested(cleaned, path.split("."))
+        return cleaned
+
+    @staticmethod
+    def _remove_nested(root: dict[str, Any], path: list[str]) -> None:
+        """Delete ``path`` from ``root`` in place, dropping parent dicts left empty."""
+        if not path:
+            return
+        key = path[0]
+        if key not in root:
+            return
+        if len(path) == 1:
+            root.pop(key, None)
+            return
+        child = root.get(key)
+        if not isinstance(child, dict):
+            return
+        CompositeOutputMaterializer._remove_nested(child, path[1:])
+        if not child:
+            root.pop(key, None)
+
 
 class ExecutionJobRecorder:
     """Persist a traced execution and return reporting payload pieces."""
diff --git a/tests/happy_path/test_label_command.py b/tests/happy_path/test_label_command.py
index 3f50b405..f321b7a2 100644
--- a/tests/happy_path/test_label_command.py
+++ b/tests/happy_path/test_label_command.py
@@ -55,6 +55,57 @@ def _artifact_label_rows_by_hash(
 class TestLabelCommand:
     """CLI product-path tests for label lifecycle behavior."""
 
+    def test_detected_dataset_composite_artifact_gets_auto_labels(
+        self,
+        temp_git_repo,
+        roar_cli,
+        git_commit,
+        python_exe,
+    ):
+        dataset_script = temp_git_repo / "emit_dataset.py"
+        dataset_script.write_text(
+            "\n".join(
+                [
+                    "from __future__ import annotations",
+                    "",
+                    "from pathlib import Path",
+                    "import argparse",
+                    "",
+                    "parser = argparse.ArgumentParser()",
+                    "parser.add_argument('--output-dir', required=True)",
+                    "args = parser.parse_args()",
+                    "",
+                    "root = Path(args.output_dir)",
+                    "(root / 'train').mkdir(parents=True, exist_ok=True)",
+                    "(root / 'train' / 'part-00000.csv').write_text('value\\n1\\n', encoding='utf-8')",
+                    "(root / 'train' / 'part-00001.csv').write_text('value\\n2\\n', encoding='utf-8')",
+                ]
+            ),
+            encoding="utf-8",
+        )
+        git_commit("Add dataset emitter")
+
+        result = roar_cli("run", python_exe, "emit_dataset.py", "--output-dir", "dataset")
+        assert result.returncode == 0
+
+        dataset_root = temp_git_repo / "dataset"
+
+        label_show = _assert_ok(roar_cli("label", "show", "artifact", "dataset", check=False))
+        assert f"dataset.id={dataset_root.resolve().as_uri()}" in label_show
+        assert "dataset.modality=tabular" in label_show
+        assert "dataset.type=dataset" in label_show
+
+        show_output = _assert_ok(roar_cli("show", "dataset", check=False))
+        assert "Labels:" in show_output
+        assert f"dataset.id={dataset_root.resolve().as_uri()}" in show_output
+        assert "dataset.modality=tabular" in show_output
+        assert "dataset.type=dataset" in show_output
+
+        rows = _artifact_label_rows(temp_git_repo, dataset_root)
+        assert rows[0][1]["dataset"]["type"] == "dataset"
+        assert rows[0][1]["dataset"]["id"] == dataset_root.resolve().as_uri()
+        assert rows[0][1]["dataset"]["modality"] == "tabular"
+
     def test_artifact_label_set_patches_current_document_and_preserves_history(
         self,
         temp_git_repo,
diff --git a/tests/unit/test_dataset_identifier.py b/tests/unit/test_dataset_identifier.py
index a0ced680..ecd78532 100644
--- a/tests/unit/test_dataset_identifier.py
+++ b/tests/unit/test_dataset_identifier.py
@@ -138,6 +138,16 @@ def test_record_materializes_local_composite_outputs(tmp_path: Path):
     assert composite_artifact is not None
     assert composite_artifact["kind"] == "composite"
 
+    current_labels = db_ctx.labels.get_current(
+        "artifact",
+        artifact_id=composite_artifact["id"],
+    )
+    assert current_labels is not None
+    assert current_labels["metadata"]["dataset"]["type"] == "dataset"
+    assert current_labels["metadata"]["dataset"]["id"] == dataset_dir.resolve().as_uri()
+    assert current_labels["metadata"]["dataset"]["modality"] == "tabular"
+    assert current_labels["metadata"]["dataset"]["fingerprint"]
+
     summary = db_ctx.composites.get(composite_artifact["id"])
     assert summary is not None
     assert summary["component_count"] == 2
diff --git a/tests/unit/test_dataset_label.py b/tests/unit/test_dataset_label.py
index 447551bd..bd69ff48 100644
--- a/tests/unit/test_dataset_label.py
+++ b/tests/unit/test_dataset_label.py
@@ -3,6 +3,8 @@
 from __future__ import annotations
 
 from roar.execution.recording.dataset_metadata import (
+    AUTO_DATASET_LABEL_KEYS,
+    build_dataset_label_metadata,
     build_dataset_metadata,
     find_matching_identifier,
 )
@@ -126,3 +128,54 @@ def test_ignores_unknown_keys(self):
         result = build_dataset_metadata(identifier)
         assert "extra_key" not in result
         assert result["dataset_id"] == "file:///data/raw"
+
+
+# ---------------------------------------------------------------------------
+# build_dataset_label_metadata
+# ---------------------------------------------------------------------------
+
+
+class TestBuildDatasetLabelMetadata:
+    def test_includes_dataset_identity_and_modality_labels(self):
+        identifier = {
+            "dataset_id": "file:///data/imagenet",
+            "dataset_fingerprint": "a1b2c3d4e5f67890",
+            "dataset_fingerprint_algorithm": "blake3",
+            "split": "train",
+            "version_hint": "v2",
+        }
+
+        result = build_dataset_label_metadata(
+            identifier,
+            components=[
+                {
+                    "relative_path": "train/class_a/image-0001.jpg",
+                    "component_size": 42,
+                    "component_type": "image/jpeg",
+                }
+            ],
+            component_count_total=1,
+        )
+
+        assert result == {
+            "dataset": {
+                "type": "dataset",
+                "id": "file:///data/imagenet",
+                "fingerprint": "a1b2c3d4e5f67890",
+                "fingerprint_algorithm": "blake3",
+                "split": "train",
+                "version_hint": "v2",
+                "modality": "image",
+            }
+        }
+
+    def test_declares_reserved_paths_for_system_managed_dataset_labels(self):
+        assert {
+            "dataset.type",
+            "dataset.id",
+            "dataset.fingerprint",
+            "dataset.fingerprint_algorithm",
+            "dataset.split",
+            "dataset.version_hint",
+            "dataset.modality",
+        } == AUTO_DATASET_LABEL_KEYS
diff --git a/tests/unit/test_label_service.py b/tests/unit/test_label_service.py
new file mode 100644
index 00000000..66403970
--- /dev/null
+++ b/tests/unit/test_label_service.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import pytest
+
+from roar.application.labels import LabelService
+
+
+def test_reject_reserved_keys_blocks_system_managed_dataset_labels() -> None:
+    with pytest.raises(ValueError, match="Reserved label keys cannot be set manually"):
+        LabelService._reject_reserved_keys(
+            {
+                "dataset": {
+                    "id": "file:///data/train",
+                    "modality": "tabular",
+                }
+            }
+        )