Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 168 additions & 60 deletions common/src/buttercup/common/sarif_store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import logging
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, computed_field
from redis import Redis

logger = logging.getLogger(__name__)


class SARIFBroadcastDetail(BaseModel):
"""Model for SARIF broadcast details, matches the model in types.py"""
Expand All @@ -19,8 +22,111 @@ class SARIFBroadcastDetail(BaseModel):
task_id: str


class Finding(BaseModel):
    """Individual vulnerability finding extracted from a SARIF result."""

    rule_id: str
    level: str
    message: str
    file_uri: str
    start_line: int
    end_line: int
    start_column: int | None = None
    tool_name: str
    sarif_id: str
    task_id: str

    @computed_field
    @property
    def fingerprint(self) -> str:
        """Stable deduplication key for this finding.

        Includes every location-identifying field (rule, level, file, line
        span, column, tool) so that two findings differing only in, e.g.,
        start_column or reporting tool do not collide. sarif_id and task_id
        are deliberately excluded: the same finding re-broadcast in a later
        SARIF upload for the same task should still deduplicate.
        """
        return (
            f"{self.rule_id}:{self.level}:{self.file_uri}:"
            f"{self.start_line}:{self.end_line}:{self.start_column}:{self.tool_name}"
        )
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not do a fingerprint of all fields in the finding? There could be a fingerprint collision if e.g., only the start_column differs.



def _extract_finding_from_result(
    result: dict[str, Any],
    tool_name: str,
    sarif_id: str,
    task_id: str,
) -> Finding | None:
    """Extract a Finding from a single SARIF result entry.

    Args:
        result: One entry of a SARIF ``runs[].results[]`` array.
        tool_name: Name of the tool driver that produced the result.
        sarif_id: ID of the SARIF broadcast this result came from.
        task_id: Task the SARIF broadcast belongs to.

    Returns:
        The extracted Finding, or None (with a warning log) if the result
        is malformed or missing required fields.
    """
    rule_id = result.get("ruleId", "")
    if not rule_id:
        # Some SARIF producers nest the rule id under a "rule" object
        # instead of the top-level "ruleId" property.
        rule = result.get("rule", {})
        rule_id = rule.get("id", "unknown")

    level = result.get("level", "warning")
    message_obj = result.get("message", {})
    message = message_obj.get("text", "") if isinstance(message_obj, dict) else str(message_obj)

    locations = result.get("locations", [])
    if not locations:
        logger.warning(
            "Failed to extract finding: no locations (sarif_id=%s, rule=%s)", sarif_id, rule_id
        )
        return None

    physical = locations[0].get("physicalLocation", {})
    artifact_location = physical.get("artifactLocation", {})
    file_uri = artifact_location.get("uri", "")
    if not file_uri:
        logger.warning(
            "Failed to extract finding: no artifact URI (sarif_id=%s, rule=%s)", sarif_id, rule_id
        )
        return None

    region = physical.get("region", {})
    start_line = region.get("startLine", 0)
    if start_line == 0:
        logger.warning(
            "Failed to extract finding: no startLine (sarif_id=%s, rule=%s)", sarif_id, rule_id
        )
        return None

    # endLine is optional in SARIF; a single-line finding omits it.
    end_line = region.get("endLine", start_line)
    start_column = region.get("startColumn")

    return Finding(
        rule_id=rule_id,
        level=level,
        message=message,
        file_uri=file_uri,
        start_line=start_line,
        end_line=end_line,
        start_column=start_column,
        tool_name=tool_name,
        sarif_id=sarif_id,
        task_id=task_id,
    )


def extract_findings(sarif_detail: SARIFBroadcastDetail) -> list[Finding]:
    """Extract individual findings from a SARIF broadcast detail.

    Walks every entry of ``runs[].results[]`` in the SARIF document and
    converts it into a Finding; malformed entries are skipped gracefully.
    """
    collected: list[Finding] = []

    for run in sarif_detail.sarif.get("runs", []):
        driver_name = run.get("tool", {}).get("driver", {}).get("name", "unknown")

        for entry in run.get("results", []):
            parsed = _extract_finding_from_result(
                entry,
                driver_name,
                sarif_detail.sarif_id,
                sarif_detail.task_id,
            )
            if parsed is not None:
                collected.append(parsed)

    return collected


class SARIFStore:
"""Store and retrieve SARIF objects in Redis"""
"""Store and retrieve SARIF objects and extracted findings in Redis."""

SARIF_PREFIX = "sarif:"
FINDING_PREFIX = "findings:"
FINDING_SEEN_PREFIX = "findings_seen:"

def __init__(self, redis: Redis):
"""Initialize the SARIF store with a Redis connection.
Expand All @@ -30,85 +136,67 @@ def __init__(self, redis: Redis):

"""
self.redis = redis
self.key_prefix = "sarif:"
# Keep for backward compat with code using self.key_prefix
self.key_prefix = self.SARIF_PREFIX

def _get_key(self, task_id: str) -> str:
    """Return the Redis key for a task's SARIF object list."""
    return f"{self.SARIF_PREFIX}{task_id.lower()}"

def _get_finding_key(self, task_id: str) -> str:
    """Return the Redis key for a task's extracted-finding list."""
    return f"{self.FINDING_PREFIX}{task_id.lower()}"

def _get_finding_seen_key(self, task_id: str) -> str:
    """Return the Redis key for a task's seen-fingerprint set (dedup guard)."""
    return f"{self.FINDING_SEEN_PREFIX}{task_id.lower()}"

def _decode_key(self, key: str | bytes) -> str:
"""Decode a Redis key if it's bytes, otherwise return as is.

Args:
key: Redis key, either bytes or string

Returns:
Decoded key as string

"""
if isinstance(key, bytes):
return key.decode("utf-8")
return key

def store(self, sarif_detail: SARIFBroadcastDetail) -> None:
    """Store a SARIF broadcast detail and its extracted findings in Redis.

    Args:
        sarif_detail: The SARIF broadcast detail to store
    """
    task_id = sarif_detail.task_id

    # A Redis list per task lets multiple SARIF objects accumulate
    # under the same task_id.
    sarif_key = self._get_key(task_id)
    self.redis.rpush(sarif_key, sarif_detail.model_dump_json())

    # Extract individual findings and add them to the per-task finding
    # pool, deduplicated by fingerprint.
    findings = extract_findings(sarif_detail)
    self._store_findings(task_id, findings)

def get_all(self) -> list[SARIFBroadcastDetail]:
"""Retrieve all SARIF objects from Redis.
def _store_findings(self, task_id: str, findings: list[Finding]) -> int:
    """Store findings, deduplicating by fingerprint.

    SADD's return value is used as the atomic dedup guard: it returns 1
    only when the member was newly added. This closes the TOCTOU window a
    separate SISMEMBER check would leave between two concurrent store()
    calls for the same task, and it sets the guard *before* the list push
    so a crash between the two commands cannot leave an unguarded
    duplicate.

    Args:
        task_id: Task the findings belong to.
        findings: Candidate findings to add to the pool.

    Returns:
        Number of new (previously unseen) findings added.
    """
    finding_key = self._get_finding_key(task_id)
    seen_key = self._get_finding_seen_key(task_id)
    added = 0

    for finding in findings:
        # sadd -> 0 means the fingerprint was already present: skip.
        if not self.redis.sadd(seen_key, finding.fingerprint):
            continue
        self.redis.rpush(finding_key, finding.model_dump_json())
        added += 1

    if added > 0:
        # Report the actual pool size, not just this batch's count.
        logger.info(
            "Added %d new findings for task %s (total pool: %d)",
            added,
            task_id,
            self.redis.llen(finding_key),
        )

    return added

def get_all(self) -> list[SARIFBroadcastDetail]:
    """Retrieve all SARIF objects from Redis.

    Returns:
        List of SARIF broadcast details across every task.
    """
    details: list[SARIFBroadcastDetail] = []

    # Scan every per-task SARIF list key.
    for raw_key in self.redis.keys(f"{self.key_prefix}*"):
        task_key = self._decode_key(raw_key)
        # Each key holds a list of serialized SARIF broadcast details.
        for payload in self.redis.lrange(task_key, 0, -1):
            details.append(SARIFBroadcastDetail.model_validate_json(payload))

    return details

def get_by_task_id(self, task_id: str) -> list[SARIFBroadcastDetail]:
"""Retrieve all SARIF objects for a specific task.

Args:
task_id: Task ID

Returns:
List of SARIF broadcast details for this task

"""
"""Retrieve all SARIF objects for a specific task."""
key = self._get_key(task_id)
sarif_list = self.redis.lrange(key, 0, -1)

Expand All @@ -119,15 +207,35 @@ def get_by_task_id(self, task_id: str) -> list[SARIFBroadcastDetail]:

return result

def delete_by_task_id(self, task_id: str) -> int:
"""Remove all SARIF objects for a specific task.
def get_findings_by_task_id(self, task_id: str) -> list[Finding]:
    """Retrieve all findings for a specific task from the finding pool.

    Falls back to extracting from stored SARIFs if the finding pool
    is empty but SARIF data exists (backward compatibility).

    Args:
        task_id: Task ID

    Returns:
        Deduplicated findings for this task (possibly empty).
    """
    finding_key = self._get_finding_key(task_id)
    finding_list = self.redis.lrange(finding_key, 0, -1)

    if finding_list:
        return [Finding.model_validate_json(f) for f in finding_list]

    # Fallback: extract from old SARIF data if present.
    sarifs = self.get_by_task_id(task_id)
    if not sarifs:
        return []

    all_findings: list[Finding] = []
    for sarif_detail in sarifs:
        all_findings.extend(extract_findings(sarif_detail))

    if all_findings:
        self._store_findings(task_id, all_findings)
        # Re-read from the deduplicated pool so the fallback path returns
        # the same (deduplicated) view as the primary path. Safe from
        # infinite recursion: _store_findings populated finding_key above,
        # so the recursive call takes the primary branch.
        return self.get_findings_by_task_id(task_id)

    return []
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 — Fallback returns un-deduplicated findings.

The primary path (line 220) returns from the deduplicated Redis pool, but this fallback returns the raw all_findings list. If two stored SARIFs contain findings with the same fingerprint, the first call returns duplicates (while all subsequent calls return the deduplicated pool). This inconsistency propagates to sample_findings(), inflating both the probability calculation (len(self.findings)) and the sampling pool.

Suggested fix — re-read from Redis after storing, or deduplicate before returning:

if all_findings:
    self._store_findings(task_id, all_findings)
    # Return from the deduplicated pool to match the primary path
    return self.get_findings_by_task_id(task_id)

return []

(This is safe from infinite recursion because _store_findings populates the finding key, so the recursive call hits the primary if finding_list: path.)


def delete_by_task_id(self, task_id: str) -> int:
    """Remove all SARIF objects and findings for a specific task.

    Args:
        task_id: Task ID

    Returns:
        Number of Redis keys removed.
    """
    keys = (
        self._get_key(task_id),
        self._get_finding_key(task_id),
        self._get_finding_seen_key(task_id),
    )
    # One DELETE call covers the SARIF list, finding list, and seen set.
    return self.redis.delete(*keys)
Loading
Loading