Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions roar/core/interfaces/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def get_hashes(self, artifact_id: str) -> list[dict[str, Any]]:
"""Get all hashes for an artifact."""
...

def get_hashes_batch(self, artifact_ids: list[str]) -> dict[str, list[dict[str, Any]]]:
    """Get hashes for multiple artifacts in a single query.

    Args:
        artifact_ids: Artifact IDs to look up.

    Returns:
        Mapping of artifact_id to its list of hash dicts
        (``{"algorithm": ..., "digest": ...}``). Implementations are
        expected to include a key for every requested ID, with an
        empty list when no hashes are recorded.
    """
    ...

def get_locations(self, artifact_id: str) -> list[dict[str, str]]:
"""Get all known locations for an artifact."""
...
Expand Down
99 changes: 99 additions & 0 deletions roar/db/repositories/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,79 @@ def register(
self._session.flush()
return artifact_id, True

def register_batch(self, items: list[tuple[dict[str, str], int, str | None]]) -> list[str]:
    """Register multiple artifacts at once. Returns list of artifact_ids.

    items = [(hashes_dict, size, path), ...]

    Note: unlike ``register()``, this does not accept ``source_type``,
    ``source_url``, or ``metadata``, and does not backfill missing hash
    algorithms on existing artifacts. Intended for the post-run
    registration hot path where those fields are not needed.
    """
    if not items:
        return []

    # Collect all digests for the primary algorithm to check for existing
    # artifacts. The "primary" algorithm is whichever algorithm appears
    # first in the first item that has any hashes at all.
    all_digests: dict[str, int] = {}  # digest -> index in items
    primary_algo: str | None = None
    for i, (hashes, _size, _path) in enumerate(items):
        for algo, digest in hashes.items():
            if primary_algo is None:
                primary_algo = algo
            if algo == primary_algo:
                all_digests[digest.lower()] = i

    # Bulk lookup existing artifacts by primary hash
    existing_map: dict[str, str] = {}  # digest -> artifact_id
    if primary_algo and all_digests:
        rows = self._session.execute(
            select(ArtifactHash.digest, ArtifactHash.artifact_id).where(
                ArtifactHash.algorithm == primary_algo,
                ArtifactHash.digest.in_(list(all_digests.keys())),
            )
        ).all()
        for digest, artifact_id in rows:
            existing_map[digest] = artifact_id

    # Process items
    artifact_ids: list[str] = []
    new_artifacts: list[Artifact] = []
    new_hashes: list[ArtifactHash] = []
    for hashes, size, path in items:
        primary_digest = hashes.get(primary_algo, "").lower() if primary_algo else ""
        # Only deduplicate on a real digest. An item that lacks the primary
        # algorithm yields primary_digest == "", and using "" as a dedup key
        # would wrongly collapse all such items onto a single artifact.
        if primary_digest and primary_digest in existing_map:
            artifact_ids.append(existing_map[primary_digest])
            continue

        artifact_id = secrets.token_hex(16)
        artifact_ids.append(artifact_id)
        new_artifacts.append(
            Artifact(
                id=artifact_id,
                size=size,
                first_seen_at=time.time(),
                first_seen_path=path,
            )
        )
        for algo, digest in hashes.items():
            new_hashes.append(
                ArtifactHash(
                    artifact_id=artifact_id,
                    algorithm=algo,
                    digest=digest.lower(),
                )
            )
        if primary_digest:
            existing_map[primary_digest] = artifact_id  # prevent dupes within batch

    if new_artifacts:
        self._session.add_all(new_artifacts)
    if new_hashes:
        self._session.add_all(new_hashes)
    if new_artifacts or new_hashes:
        self._session.flush()

    return artifact_ids

def get(self, artifact_id: str) -> dict[str, Any] | None:
"""
Get artifact by ID.
Expand Down Expand Up @@ -148,6 +221,32 @@ def get_hashes(self, artifact_id: str) -> list[dict[str, Any]]:
for h in hashes
]

def get_hashes_batch(self, artifact_ids: list[str]) -> dict[str, list[dict[str, Any]]]:
    """
    Get hashes for multiple artifacts in a single query.

    Args:
        artifact_ids: List of artifact UUIDs

    Returns:
        Dict mapping artifact_id to list of hash dicts.
    """
    if not artifact_ids:
        return {}

    # Pre-seed with empty lists so every requested ID appears as a key,
    # even when no hashes are stored for it.
    grouped: dict[str, list[dict[str, Any]]] = {aid: [] for aid in artifact_ids}

    stmt = select(ArtifactHash).where(ArtifactHash.artifact_id.in_(artifact_ids))
    for record in self._session.execute(stmt).scalars():
        grouped[record.artifact_id].append(
            {"algorithm": record.algorithm, "digest": record.digest}
        )
    return grouped

def get_by_hash(self, digest: str, algorithm: str | None = None) -> dict[str, Any] | None:
"""
Get artifact by hash digest.
Expand Down
54 changes: 52 additions & 2 deletions roar/db/repositories/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,22 @@ def add_output(
self._session.add(job_output)
self._session.flush()

def add_inputs_batch(self, job_id: int, items: list[tuple[str, str]]) -> None:
    """Bulk-insert input records. items = [(artifact_id, path), ...]"""
    if not items:
        return
    rows = [
        JobInput(job_id=job_id, artifact_id=artifact_id, path=path)
        for artifact_id, path in items
    ]
    self._session.add_all(rows)
    self._session.flush()

def add_outputs_batch(self, job_id: int, items: list[tuple[str, str]]) -> None:
    """Bulk-insert output records. items = [(artifact_id, path), ...]"""
    if not items:
        return
    rows = [
        JobOutput(job_id=job_id, artifact_id=artifact_id, path=path)
        for artifact_id, path in items
    ]
    self._session.add_all(rows)
    self._session.flush()

def has_input_path(self, job_id: int, path: str) -> bool:
"""Check whether an input row already exists for a job/path pair."""
existing = self._session.execute(
Expand All @@ -300,6 +316,32 @@ def has_output_path(self, job_id: int, path: str) -> bool:
).scalar_one_or_none()
return existing is not None

def existing_input_paths(self, job_id: int, paths: list[str]) -> set[str]:
    """Return the subset of *paths* that already have input rows for *job_id*."""
    if not paths:
        return set()
    stmt = select(JobInput.path).where(
        JobInput.job_id == job_id,
        JobInput.path.in_(paths),
    )
    return set(self._session.execute(stmt).scalars())

def existing_output_paths(self, job_id: int, paths: list[str]) -> set[str]:
    """Return the subset of *paths* that already have output rows for *job_id*."""
    if not paths:
        return set()
    stmt = select(JobOutput.path).where(
        JobOutput.job_id == job_id,
        JobOutput.path.in_(paths),
    )
    return set(self._session.execute(stmt).scalars())

def get_inputs(self, job_id: int) -> list[dict[str, Any]]:
"""
Get input artifacts for a job.
Expand All @@ -325,9 +367,13 @@ def get_inputs(self, job_id: int) -> list[dict[str, Any]]:
)
rows = self._session.execute(query).all()

# Batch-fetch all hashes in one query
artifact_ids = list({row[1] for row in rows})
all_hashes = self._artifact_repository.get_hashes_batch(artifact_ids)

results = []
for path, artifact_id, byte_ranges, size, first_seen_path, kind, component_count in rows:
hashes = self._artifact_repository.get_hashes(artifact_id)
hashes = all_hashes.get(artifact_id, [])
results.append(
{
"path": path or first_seen_path, # Use artifact path as fallback
Expand Down Expand Up @@ -369,9 +415,13 @@ def get_outputs(self, job_id: int) -> list[dict[str, Any]]:
)
rows = self._session.execute(query).all()

# Batch-fetch all hashes in one query
artifact_ids = list({row[1] for row in rows})
all_hashes = self._artifact_repository.get_hashes_batch(artifact_ids)

results = []
for path, artifact_id, byte_ranges, size, first_seen_path, kind, component_count in rows:
hashes = self._artifact_repository.get_hashes(artifact_id)
hashes = all_hashes.get(artifact_id, [])
results.append(
{
"path": path or first_seen_path, # Use artifact path as fallback
Expand Down
33 changes: 23 additions & 10 deletions roar/db/services/job_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,31 +223,44 @@ def _register_artifacts(
is_input: bool,
) -> None:
"""Register artifacts and link them to the job."""
# Batch-check which paths already have edges for this job
if is_input:
already_linked = self._job_repo.existing_input_paths(job_id, file_paths)
else:
already_linked = self._job_repo.existing_output_paths(job_id, file_paths)

# Build batch items, skipping paths that already have edges
batch_items: list[tuple[dict[str, str], int, str | None]] = []
valid_paths: list[str] = []
for path in file_paths:
if is_input and self._job_repo.has_input_path(job_id, path):
continue
if not is_input and self._job_repo.has_output_path(job_id, path):
if path in already_linked:
continue

path_hashes = hashes_by_path.get(path)
if not path_hashes:
continue

hashes = {algo: digest for algo in hash_algorithms if (digest := path_hashes.get(algo))}
if not hashes:
continue

try:
size = os.path.getsize(path)
except OSError:
size = 0
batch_items.append((hashes, size, path))
valid_paths.append(path)

if not batch_items:
return

artifact_id, _ = self._artifact_repo.register(hashes, size, path)
# Batch register artifacts
artifact_ids = self._artifact_repo.register_batch(batch_items)

if is_input:
self._job_repo.add_input(job_id, artifact_id, path)
else:
self._job_repo.add_output(job_id, artifact_id, path)
# Batch create edges
edges = list(zip(artifact_ids, valid_paths, strict=True))
if is_input:
self._job_repo.add_inputs_batch(job_id, edges)
else:
self._job_repo.add_outputs_batch(job_id, edges)

@staticmethod
def _unique_paths(paths: list[str]) -> list[str]:
Expand Down
25 changes: 25 additions & 0 deletions roar/execution/runtime/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def stop_proxy_if_running() -> list:
from ...integrations.config import load_config
from ..provenance import ProvenanceService

emit_timing = os.environ.get("ROAR_TIMING") == "1"
t_postrun_start = time.perf_counter()

bootstrap(ctx.roar_dir)
config = load_config(start_dir=ctx.repo_root)

Expand Down Expand Up @@ -230,12 +233,14 @@ def stop_proxy_if_running() -> list:
)
roar_dir = os.path.join(ctx.repo_root, ".roar")
provenance_service = ProvenanceService(cache_dir=roar_dir)
t_prov_start = time.perf_counter()
prov = provenance_service.collect(
ctx.repo_root,
tracer_result.tracer_log_path,
inject_log,
config,
)
t_prov_end = time.perf_counter()
self.logger.debug(
"Provenance collected: read_files=%d, written_files=%d",
len(prov.get("data", {}).get("read_files", [])),
Expand All @@ -249,6 +254,7 @@ def stop_proxy_if_running() -> list:

# Record in database
self.logger.debug("Recording job in database")
t_record_start = time.perf_counter()
job_id, job_uid, read_file_info, written_file_info, stale_upstream, stale_downstream = (
self._record_job(
ctx,
Expand All @@ -260,6 +266,7 @@ def stop_proxy_if_running() -> list:
run_job_uid=run_job_uid,
)
)
t_record_end = time.perf_counter()
self.logger.debug(
"Job recorded: id=%d, uid=%s, inputs=%d, outputs=%d",
job_id,
Expand All @@ -274,6 +281,24 @@ def stop_proxy_if_running() -> list:
self.logger.debug("Cleaning up temporary log files")
self._cleanup_logs(tracer_result.tracer_log_path, tracer_result.inject_log_path)

t_postrun_end = time.perf_counter()

if emit_timing:
import json as _json

n_inputs = len(prov.get("data", {}).get("read_files", []))
n_outputs = len(prov.get("data", {}).get("written_files", []))
timing = {
"roar_timing": True,
"tracer_ms": round(tracer_result.duration * 1000, 2),
"postrun_ms": round((t_postrun_end - t_postrun_start) * 1000, 2),
"provenance_ms": round((t_prov_end - t_prov_start) * 1000, 2),
"record_ms": round((t_record_end - t_record_start) * 1000, 2),
"n_inputs": n_inputs,
"n_outputs": n_outputs,
}
print(f"ROAR_TIMING:{_json.dumps(timing)}", file=sys.stderr, flush=True)

self.logger.debug(
"RunCoordinator.execute completed: exit_code=%d, duration=%.2fs",
tracer_result.exit_code,
Expand Down
37 changes: 29 additions & 8 deletions roar/filters/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,20 @@ def _build_package_file_map(self) -> tuple[dict, dict]:
# file_to_pkg is intentionally empty; classify() uses path extraction.
return {}, pkg_versions

def classify(self, path: str) -> tuple[str, str | None]:
def _get_git_tracked_files(self) -> set[str]:
    """Run ``git ls-files`` once and return a set of repo-relative paths.

    Uses ``-z`` (NUL-separated) output so paths containing spaces,
    quotes, or non-ASCII bytes are returned verbatim. Without it, git
    C-quotes such paths (core.quotePath), and the quoted form would
    never match ``str(Path.relative_to(...))`` in ``classify()``.

    Returns:
        Set of tracked paths relative to the repo root; empty set when
        git is unavailable or the command fails (e.g. not a git repo).
    """
    try:
        output = subprocess.check_output(
            ["git", "ls-files", "-z"],
            cwd=str(self.repo_root),
            text=True,
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return set()
    # split("\0") leaves a trailing empty element; drop empties.
    return {p for p in output.split("\0") if p}

def classify(self, path: str, git_tracked: set[str] | None = None) -> tuple[str, str | None]:
"""
Classify a file into one of:
- "repo": tracked in the git repo
Expand Down Expand Up @@ -224,12 +237,18 @@ def classify(self, path: str) -> tuple[str, str | None]:
else:
try:
rel = Path(path_str).relative_to(self.repo_root)
subprocess.check_output(
["git", "ls-files", "--error-unmatch", str(rel)],
cwd=str(self.repo_root),
stderr=subprocess.DEVNULL,
)
return ("repo", None)
if git_tracked is not None:
if str(rel) in git_tracked:
return ("repo", None)
else:
return ("unmanaged", None)
else:
subprocess.check_output(
["git", "ls-files", "--error-unmatch", str(rel)],
cwd=str(self.repo_root),
stderr=subprocess.DEVNULL,
)
return ("repo", None)
except subprocess.CalledProcessError:
# In repo but not tracked - could be generated file
return ("unmanaged", None)
Expand Down Expand Up @@ -360,10 +379,12 @@ def classify_all(self, paths: list[str]) -> dict:
"skip": 0,
}

git_tracked = self._get_git_tracked_files()

for path in paths:
if not path:
continue
classification, pkg_name = self.classify(path)
classification, pkg_name = self.classify(path, git_tracked=git_tracked)
stats[classification] = stats.get(classification, 0) + 1

if classification == "repo":
Expand Down
Loading