From 0ed5ac17feb4a7ce80ad9df0b26cce52f77fbdc4 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Sat, 10 Jan 2026 19:00:36 -0700 Subject: [PATCH] feat: add evidence validation, adaptive query limits, and grafana annotations **Added:** - Implemented evidence validation against recent query results, including provenance tracking, IOC extraction, and confidence adjustment - `src/ares/core/evidence_validation.py` - Introduced automatic posting of investigation started, completed, timeout, and failed annotations to Grafana for investigation lifecycle observability - Added `get_suggested_evidence` tool to suggest IOCs extracted from queries for improved evidence recording accuracy - Provided `execute_parallel_queries` and `combine_query_patterns` tools for efficient and parallelized log querying in investigations - Enhanced `Evidence` model with `source_query_id` and `validated` fields for traceability **Changed:** - Updated blue agent investigation orchestrator to post Grafana annotations at start, completion, timeout, and failure of investigations - Refactored query tracking to count only successful (result-producing) queries against adaptive query limits and not penalize failed/empty queries - Made query limits adaptive: increased limits with investigation progress, bonus queries for finding evidence or reaching higher Pyramid of Pain levels, and a hard cap - Updated evidence recording to validate values, adjust confidence if unvalidated, and log provenance - Improved system instructions to document parallel and combined query strategies and explain new evidence validation and IOC suggestion capabilities - Updated default query limit variable names in `Taskfile.yaml` for blue and red agents **Removed:** - Removed old badge section in `README.md` in favor of up-to-date project badges - Deprecated redundant agent thread import in factories for blue and red agents --- README.md | 12 +- Taskfile.yaml | 23 +- src/ares/agents/blue/soc_investigator.py | 81 ++++++ src/ares/core/evidence_validation.py | 284 +++++++++++++++++++ src/ares/core/factories/blue_factory.py | 168 +++++++++-- src/ares/core/factories/red_factory.py | 3 +- src/ares/core/models.py | 6 + src/ares/tools/blue/actions.py | 2 +- src/ares/tools/blue/grafana.py | 128 +++++++++ src/ares/tools/blue/investigation.py | 65 ++++- src/ares/tools/blue/observability.py | 177 ++++++++++++ templates/agent/system_instructions.md.jinja | 38 ++- 12 files changed, 938 insertions(+), 49 deletions(-) create mode 100644 src/ares/core/evidence_validation.py diff --git a/README.md b/README.md index 093cd982..b2d83c08 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,7 @@ # Ares - Autonomous Security Operations Agent - -
- -[![Pre-Commit](https://github.com/dreadnode/python-template/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/dreadnode/python-template/actions/workflows/pre-commit.yaml) -[![Renovate](https://github.com/dreadnode/python-template/actions/workflows/renovate.yaml/badge.svg)](https://github.com/dreadnode/python-template/actions/workflows/renovate.yaml) -[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) - -
- - +[![Tests](https://github.com/dreadnode/ares/actions/workflows/tests.yaml/badge.svg)](https://github.com/dreadnode/ares/actions/workflows/tests.yaml) +[![Coverage](https://raw.githubusercontent.com/dreadnode/ares/main/.github/badges/coverage.svg)](https://github.com/dreadnode/ares/actions/workflows/coverage-badge.yaml) [![Pre-Commit](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Python](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) diff --git a/Taskfile.yaml b/Taskfile.yaml index bdfc31a2..548a19f0 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -15,8 +15,9 @@ vars: MODEL: '{{.MODEL | default "claude-sonnet-4-20250514"}}' GRAFANA_URL: '{{.GRAFANA_URL | default "https://grafana.dev.plundr.ai"}}' POLL_INTERVAL: '{{.POLL_INTERVAL | default "30"}}' - MAX_STEPS: '{{.MAX_STEPS | default "50"}}' - MAX_STEPS_ONCE: '{{.MAX_STEPS_ONCE | default "15"}}' # ~15 min max for once mode + MAX_STEPS_BLUE: '{{.MAX_STEPS_BLUE | default "50"}}' + MAX_STEPS_BLUE_ONCE: '{{.MAX_STEPS_BLUE_ONCE | default "15"}}' # ~15 min max for once mode + MAX_STEPS_RED: '{{.MAX_STEPS_RED | default "150"}}' REPORT_DIR: '{{.REPORT_DIR | default "./reports"}}' LOG_DIR: '{{.LOG_DIR | default "./logs"}}' DREADNODE_SERVER: '{{.DREADNODE_SERVER | default "https://platform.dev.plundr.ai/"}}' @@ -170,7 +171,7 @@ tasks: --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ --args.poll-interval {{.POLL_INTERVAL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_BLUE}} \ --args.report-dir {{.REPORT_DIR}} \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.token "$DREADNODE_API_KEY" \ @@ -200,7 +201,7 @@ tasks: --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ --args.poll-interval {{.POLL_INTERVAL}} \ - --args.max-steps {{.MAX_STEPS_ONCE}} \ + --args.max-steps {{.MAX_STEPS_BLUE_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --args.once \ --dn-args.server {{.DREADNODE_SERVER}} \ @@ -236,7 +237,7 @@ tasks: --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ --args.poll-interval {{.POLL_INTERVAL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_BLUE}} \ --args.report-dir {{.REPORT_DIR}} \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ @@ -270,7 +271,7 @@ tasks: --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ --args.poll-interval {{.POLL_INTERVAL}} \ - --args.max-steps {{.MAX_STEPS_ONCE}} \ + --args.max-steps {{.MAX_STEPS_BLUE_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --args.once \ --dn-args.server {{.DREADNODE_SERVER}} \ @@ -302,7 +303,7 @@ tasks: uv run python -m ares investigate-alert {{.ALERT}} \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ - --args.max-steps {{.MAX_STEPS_ONCE}} \ + --args.max-steps {{.MAX_STEPS_BLUE_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.token "$DREADNODE_API_KEY" \ @@ -384,7 +385,9 @@ tasks: echo "" echo "Agent Settings:" echo " Model: {{.MODEL}}" - echo " Max Steps: {{.MAX_STEPS}}" + echo " Max Steps (Blue): {{.MAX_STEPS_BLUE}}" + echo " Max Steps (Blue Once): {{.MAX_STEPS_BLUE_ONCE}}" + echo " Max Steps (Red): {{.MAX_STEPS_RED}}" echo " Poll Interval: {{.POLL_INTERVAL}}s" echo "" echo "Data Sources:" @@ -642,7 +645,7 @@ tasks: uv run python -m ares red-team "$RESOLVED_TARGET" \ --args.model {{.MODEL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_RED}} \ --args.report-dir {{.REPORT_DIR}} \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.token "$DREADNODE_API_KEY" \ @@ -676,7 +679,7 @@ tasks: uv run python -m ares red-team {{.TARGET}} \ --args.model {{.MODEL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_RED}} \ --args.report-dir {{.REPORT_DIR}} \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ diff --git a/src/ares/agents/blue/soc_investigator.py b/src/ares/agents/blue/soc_investigator.py index 769f9584..5dfc05a5 100644 --- a/src/ares/agents/blue/soc_investigator.py +++ b/src/ares/agents/blue/soc_investigator.py @@ -21,6 +21,7 @@ ) from ares.core.templates import get_template_loader from ares.integrations.mitre import MITREAttackClient +from ares.tools.blue.grafana import GrafanaTools class InvestigationTimeoutError(Exception): @@ -172,6 +173,11 @@ def __init__( self.max_steps = max_steps self._mcp_client = None self._mcp_tools = None + # Grafana tools for annotations + self._grafana_tools = GrafanaTools( + base_url=grafana_url, + api_key=grafana_api_key, + ) async def _ensure_mcp_connection(self) -> None: """Ensure MCP connection is established (with 60s timeout).""" @@ -271,6 +277,9 @@ async def investigate(self, alert: dict) -> dict: # Ensure MCP connection is ready await self._ensure_mcp_connection() + # Post "investigation started" annotation to Grafana + await self._post_started_annotation(investigation_id, alert) + # Auto-extract and record MITRE technique from alert labels = alert.get("labels", {}) annotations = alert.get("annotations", {}) @@ -357,6 +366,11 @@ async def investigate(self, alert: dict) -> dict: # Persist investigation for learning self._persist_investigation(state, status) + # Post "investigation completed" annotation to Grafana + await self._post_completed_annotation( + investigation_id, alert_name, status, state + ) + dn.log_output("report_path", str(report_path)) dn.log_metric("investigation_success", 1) @@ -383,6 +397,11 @@ async def investigate(self, alert: dict) -> dict: # Persist investigation for learning (even on timeout) self._persist_investigation(state, "timeout") + # Post "investigation timeout" annotation to Grafana + await self._post_completed_annotation( + investigation_id, alert_name, "timeout", state + ) + return { "investigation_id": investigation_id, "status": "timeout", @@ -398,12 +417,74 @@ async def investigate(self, alert: dict) -> dict: # Persist failed investigation self._persist_investigation(state, "failed") + + # Post "investigation failed" annotation to Grafana + await self._post_completed_annotation( + investigation_id, alert_name, "failed", state + ) raise finally: # Always cancel the watchdog on normal completion watchdog.cancel() + async def _post_started_annotation(self, investigation_id: str, alert: dict) -> None: + """Post investigation started annotation to Grafana. + + Args: + investigation_id: Unique investigation identifier. + alert: Alert dictionary. + """ + try: + labels = alert.get("labels", {}) + alert_name = labels.get("alertname", "unknown") + severity = labels.get("severity", "unknown") + + await self._grafana_tools.post_investigation_started( + investigation_id=investigation_id, + alert_name=alert_name, + severity=severity, + ) + except Exception as e: + # Don't fail the investigation if annotation fails + logger.warning(f"Failed to post started annotation: {e}") + + async def _post_completed_annotation( + self, + investigation_id: str, + alert_name: str, + status: str, + state: InvestigationState, + ) -> None: + """Post investigation completed annotation to Grafana. + + Args: + investigation_id: Unique investigation identifier. + alert_name: Name of the alert investigated. + status: Final status. + state: Investigation state. + """ + try: + # Get summary from state if available + summary = None + if state.attack_synopsis: + summary = state.attack_synopsis + elif state.recommendations: + summary = f"Recommendations: {', '.join(state.recommendations[:3])}" + + await self._grafana_tools.post_investigation_completed( + investigation_id=investigation_id, + alert_name=alert_name, + status=status, + evidence_count=len(state.evidence), + techniques=list(state.identified_techniques), + pyramid_level=state.highest_pyramid_level, + summary=summary, + ) + except Exception as e: + # Don't fail the investigation if annotation fails + logger.warning(f"Failed to post completed annotation: {e}") + def _create_alert_timeline_event(self, state: InvestigationState, alert: dict) -> None: """Create an initial timeline event from the alert.""" labels = alert.get("labels", {}) diff --git a/src/ares/core/evidence_validation.py b/src/ares/core/evidence_validation.py new file mode 100644 index 00000000..e8119f62 --- /dev/null +++ b/src/ares/core/evidence_validation.py @@ -0,0 +1,284 @@ +"""Evidence validation and IOC extraction for investigation integrity. + +This module provides: +1. Storage for recent query results (for evidence provenance) +2. Validation of evidence values against query results +3. Auto-extraction of IOCs from query results +""" + +import re +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + +from loguru import logger + +# Maximum number of query results to store for validation +MAX_STORED_RESULTS = 10 + +# Confidence penalty for unvalidated evidence +UNVALIDATED_CONFIDENCE_PENALTY = 0.3 + + +@dataclass +class StoredQueryResult: + """A stored query result for evidence validation.""" + + query_id: str + query_type: str # e.g., "query_loki_logs", "query_prometheus" + query_string: str + timestamp: datetime + result_data: Any # The actual result (list, dict, or string) + result_count: int + extracted_values: set[str] = field(default_factory=set) # Pre-extracted searchable values + + +# Global storage for recent query results +_recent_results: deque[StoredQueryResult] = deque(maxlen=MAX_STORED_RESULTS) +_query_counter = 0 + + +def reset_evidence_validation(): + """Reset evidence validation state for a new investigation.""" + global _recent_results, _query_counter + _recent_results = deque(maxlen=MAX_STORED_RESULTS) + _query_counter = 0 + + +def store_query_result( + query_type: str, + query_string: str, + result_data: Any, + result_count: int, +) -> str: + """Store a query result for evidence validation. + + Args: + query_type: Type of query (e.g., "query_loki_logs") + query_string: The query string executed + result_data: The actual result data + result_count: Number of results returned + + Returns: + Query ID for reference + """ + global _query_counter + _query_counter += 1 + query_id = f"q-{_query_counter:04d}" + + # Extract searchable values from results + extracted = _extract_searchable_values(result_data) + + stored = StoredQueryResult( + query_id=query_id, + query_type=query_type, + query_string=query_string, + timestamp=datetime.now(timezone.utc), + result_data=result_data, + result_count=result_count, + extracted_values=extracted, + ) + + _recent_results.append(stored) + logger.debug(f"Stored query result {query_id} with {len(extracted)} extracted values") + + return query_id + + +def _extract_searchable_values(data: Any, depth: int = 0) -> set[str]: + """Recursively extract searchable string values from query results. + + Extracts IPs, hostnames, usernames, and other IOC-like values. + + Args: + data: Data to extract from (dict, list, or primitive) + depth: Current recursion depth (to prevent infinite recursion) + + Returns: + Set of extracted string values + """ + if depth > 10: # Prevent infinite recursion + return set() + + values: set[str] = set() + + if isinstance(data, str): + # Add the string itself if it looks like an IOC + if data and len(data) < 500: # Skip very long strings + values.add(data.lower()) + # Also extract embedded patterns + values.update(_extract_patterns_from_string(data)) + elif isinstance(data, dict): + for val in data.values(): + # Add key-value pairs for common fields + if isinstance(val, str) and val: + values.add(val.lower()) + values.update(_extract_patterns_from_string(val)) + elif isinstance(val, (dict, list)): + values.update(_extract_searchable_values(val, depth + 1)) + elif isinstance(data, list): + for item in data: + values.update(_extract_searchable_values(item, depth + 1)) + + return values + + +def _extract_patterns_from_string(text: str) -> set[str]: + """Extract IOC patterns from a string. + + Args: + text: Text to extract patterns from + + Returns: + Set of extracted patterns (IPs, hostnames, etc.) + """ + patterns: set[str] = set() + + # IP addresses + ip_pattern = r"\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b" + for match in re.findall(ip_pattern, text): + patterns.add(match.lower()) + + # Hostnames/FQDNs + hostname_pattern = r"\b([a-zA-Z0-9][-a-zA-Z0-9]*\.[-a-zA-Z0-9.]+)\b" + for match in re.findall(hostname_pattern, text): + if "." in match and not match[0].isdigit(): + patterns.add(match.lower()) + + # Windows usernames (domain\user or user@domain) + user_patterns = [ + r"\b([a-zA-Z0-9_-]+\\[a-zA-Z0-9_.-]+)\b", # domain\user + r"\b([a-zA-Z0-9_.-]+@[a-zA-Z0-9.-]+)\b", # user@domain + ] + for pattern in user_patterns: + for match in re.findall(pattern, text): + patterns.add(match.lower()) + + # Simple usernames (from common fields) + simple_user = r'"(?:user|username|account|TargetUserName|SubjectUserName)":\s*"([^"]+)"' + for match in re.findall(simple_user, text, re.IGNORECASE): + patterns.add(match.lower()) + + return patterns + + +def validate_evidence_value(value: str) -> tuple[bool, str | None]: + """Validate an evidence value against recent query results. + + Args: + value: The evidence value to validate + + Returns: + Tuple of (is_validated, source_query_id) + """ + if not value: + return False, None + + normalized_value = value.lower().strip() + + # Search through recent results + for stored in reversed(_recent_results): # Most recent first + # Check if value appears in extracted values + if normalized_value in stored.extracted_values: + logger.info(f"Evidence '{value[:50]}...' validated against query {stored.query_id}") + return True, stored.query_id + + # Also do a substring search in extracted values for partial matches + for extracted in stored.extracted_values: + if normalized_value in extracted or extracted in normalized_value: + logger.info( + f"Evidence '{value[:50]}...' partially validated against query {stored.query_id}" + ) + return True, stored.query_id + + logger.warning(f"Evidence '{value[:50]}...' could not be validated against recent queries") + return False, None + + +def get_suggested_iocs() -> list[dict]: + """Extract and return suggested IOCs from recent query results. + + Returns: + List of suggested IOCs with type, value, and source query ID + """ + suggestions: list[dict] = [] + seen_values: set[str] = set() + + for stored in reversed(_recent_results): # Most recent first + for value in stored.extracted_values: + if value in seen_values: + continue + seen_values.add(value) + + ioc_type = _classify_ioc(value) + if ioc_type: + suggestions.append( + { + "type": ioc_type, + "value": value, + "source_query_id": stored.query_id, + } + ) + + return suggestions[:50] # Limit to 50 suggestions + + +def _classify_ioc(value: str) -> str | None: + """Classify an IOC value by type. + + Args: + value: The value to classify + + Returns: + IOC type or None if not classifiable + """ + # IP address + if re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", value): + return "ip" + + # Domain/hostname + if re.match(r"^[a-z0-9][-a-z0-9]*\.[a-z0-9][-a-z0-9.]+$", value) and not value[0].isdigit(): + return "hostname" + + # Username patterns + if "\\" in value or "@" in value: + return "user" + + # Hash patterns + if re.match(r"^[a-f0-9]{32}$", value): + return "hash" # MD5 + if re.match(r"^[a-f0-9]{40}$", value): + return "hash" # SHA1 + if re.match(r"^[a-f0-9]{64}$", value): + return "hash" # SHA256 + + return None + + +def adjust_confidence_for_validation( + confidence: float, + validated: bool, +) -> float: + """Adjust confidence score based on validation status. + + Args: + confidence: Original confidence score + validated: Whether evidence was validated + + Returns: + Adjusted confidence score + """ + if validated: + return confidence + # Apply penalty for unvalidated evidence + return max(0.1, confidence - UNVALIDATED_CONFIDENCE_PENALTY) + + +def get_recent_query_ids() -> list[str]: + """Get list of recent query IDs for reference. + + Returns: + List of query IDs from most recent to oldest + """ + return [stored.query_id for stored in reversed(_recent_results)] diff --git a/src/ares/core/factories/blue_factory.py b/src/ares/core/factories/blue_factory.py index bb64d903..191f9f9c 100644 --- a/src/ares/core/factories/blue_factory.py +++ b/src/ares/core/factories/blue_factory.py @@ -4,13 +4,13 @@ from typing import Any import dreadnode as dn -from dreadnode.agent import Agent +from dreadnode.agent import Agent, Thread from dreadnode.agent.events import AgentEvent, AgentStalled, ToolEnd, ToolStart from dreadnode.agent.hooks import retry_with_feedback from dreadnode.agent.stop import StopCondition, tool_use -from dreadnode.agent.thread import Thread from loguru import logger +from ares.core.evidence_validation import reset_evidence_validation, store_query_result from ares.core.models import InvestigationState from ares.core.query_resilience import QueryResilientExecutor, get_resilient_executor from ares.core.templates import get_template_loader @@ -29,16 +29,33 @@ SYSTEM_INSTRUCTIONS = get_template_loader().render("agent/system_instructions.md.jinja") # Track query calls - reset per investigation via reset_query_tracking() -_total_queries = 0 +_total_queries = 0 # Only counts queries that returned results +_total_queries_attempted = 0 # All queries attempted (including failed) _consecutive_queries: list[str] = [] _query_limit_hit = False _executed_queries: list[dict] = [] _seen_queries: dict[str, int] = {} # Track query -> count to detect loops _current_state: "InvestigationState | None" = None +_bonus_queries_granted = 0 # Track bonus queries granted + +# Base query limits MAX_QUERIES_PER_INVESTIGATION = 5 MAX_QUERIES_CRITICAL = 8 # Higher limit for critical alerts MAX_DUPLICATE_QUERIES = 2 # Max times same query can run before blocking +# Adaptive query limit settings +BONUS_QUERIES_FOR_EVIDENCE = 2 # Grant +2 queries when evidence is found +BONUS_QUERIES_FOR_PYRAMID_L4 = 2 # Grant +2 queries when reaching pyramid level 4+ +MAX_TOTAL_QUERIES = 15 # Hard cap to prevent runaway investigations + +# Staged limits by investigation phase (cumulative budget per phase) +QUERY_LIMITS_BY_STAGE = { + "triage": 5, # Initial 5 queries for triage + "causation": 8, # +3 more (8 total) + "lateral": 11, # +3 more (11 total) + "synthesis": 11, # No additional queries in synthesis phase +} + def reset_query_tracking(): """Reset query tracking for a new investigation.""" @@ -46,18 +63,23 @@ def reset_query_tracking(): global \ _total_queries, \ + _total_queries_attempted, \ _consecutive_queries, \ _query_limit_hit, \ _executed_queries, \ _seen_queries, \ - _current_state + _current_state, \ + _bonus_queries_granted _total_queries = 0 + _total_queries_attempted = 0 _consecutive_queries = [] _query_limit_hit = False _executed_queries = [] _seen_queries = {} _current_state = None + _bonus_queries_granted = 0 reset_resilient_executor() + reset_evidence_validation() # Reset evidence validation state def set_investigation_state(state: "InvestigationState"): @@ -66,13 +88,79 @@ def set_investigation_state(state: "InvestigationState"): _current_state = state +def _calculate_bonus_queries() -> int: + """Calculate bonus queries based on investigation progress. + + Grants bonus queries for: + - Finding evidence (+2 queries) + - Reaching pyramid level 4+ (+2 queries) + + Returns: + Number of bonus queries to grant (0, 2, or 4) + """ + global _bonus_queries_granted + + if not _current_state: + return 0 + + new_bonus = 0 + + # Check if evidence has been found (grant bonus once) + if _current_state.evidence_count > 0 and _bonus_queries_granted < BONUS_QUERIES_FOR_EVIDENCE: + new_bonus += BONUS_QUERIES_FOR_EVIDENCE + logger.info(f"🎁 Granting +{BONUS_QUERIES_FOR_EVIDENCE} bonus queries for finding evidence") + + # Check if pyramid level 4+ reached (grant bonus once) + if ( + _current_state.highest_pyramid_level >= 4 + and _bonus_queries_granted < BONUS_QUERIES_FOR_EVIDENCE + BONUS_QUERIES_FOR_PYRAMID_L4 + ): + # Only grant pyramid bonus if not already at max bonus + pyramid_bonus = min( + BONUS_QUERIES_FOR_PYRAMID_L4, + BONUS_QUERIES_FOR_EVIDENCE + BONUS_QUERIES_FOR_PYRAMID_L4 - _bonus_queries_granted, + ) + if pyramid_bonus > 0: + new_bonus += pyramid_bonus + logger.info(f"🎁 Granting +{pyramid_bonus} bonus queries for reaching pyramid level 4+") + + if new_bonus > 0: + _bonus_queries_granted += new_bonus + + return _bonus_queries_granted + + def _get_query_limit() -> int: - """Get the query limit based on alert severity.""" + """Get the adaptive query limit based on investigation state. + + The limit is determined by: + 1. Base limit from alert severity (5 normal, 8 critical) + 2. Stage-based limits (triage: 5, causation: 8, lateral: 11) + 3. Bonus queries for productive investigations + 4. Hard cap at MAX_TOTAL_QUERIES (15) + + Returns: + Current query limit + """ + # Start with stage-based limit + base_limit = MAX_QUERIES_PER_INVESTIGATION + if _current_state: + # Use stage-based limit + stage_name = _current_state.stage.value + base_limit = QUERY_LIMITS_BY_STAGE.get(stage_name, MAX_QUERIES_PER_INVESTIGATION) + + # Override with critical severity limit if higher severity = _current_state.alert.get("labels", {}).get("severity", "").lower() if severity == "critical": - return MAX_QUERIES_CRITICAL - return MAX_QUERIES_PER_INVESTIGATION + base_limit = max(base_limit, MAX_QUERIES_CRITICAL) + + # Add bonus queries for productive investigations + bonus = _calculate_bonus_queries() + total_limit = base_limit + bonus + + # Cap at maximum to prevent runaway + return min(total_limit, MAX_TOTAL_QUERIES) def _check_query_limit() -> str | None: @@ -127,24 +215,51 @@ def _check_duplicate_query(query: str) -> str | None: return None -def _increment_query_count(tool_name: str): - """Increment query counter and log.""" - global _total_queries - _total_queries += 1 +def _increment_query_attempt(tool_name: str): + """Increment query attempt counter (called before query execution).""" + global _total_queries_attempted + _total_queries_attempted += 1 _consecutive_queries.append(tool_name) if len(_consecutive_queries) > 5: _consecutive_queries.pop(0) limit = _get_query_limit() - logger.info(f"📊 Query count: {_total_queries}/{limit}") + logger.info( + f"📊 Query attempt: {_total_queries_attempted} (successful: {_total_queries}/{limit})" + ) + + +def _count_successful_query(result_count: int | None): + """Count a query as successful if it returned results. + + Only queries that return data count against the limit. + Failed queries (0 results) get a "free retry". + """ + global _total_queries + + if result_count is not None and result_count > 0: + _total_queries += 1 + limit = _get_query_limit() + logger.info( + f"📊 Successful query count: {_total_queries}/{limit} (returned {result_count} results)" + ) + else: + logger.info("📊 Query returned 0 results - not counting against limit") -def _record_query(tool_name: str, kwargs: dict, result_count: int | None = None): - """Record a query to the investigation state.""" +def _record_query( + tool_name: str, + kwargs: dict, + result_count: int | None = None, + result_data: Any = None, +): + """Record a query to the investigation state and store for evidence validation.""" from datetime import datetime, timezone + query_string = kwargs.get("logql") or kwargs.get("expr") or str(kwargs) + query_record = { "type": tool_name, - "query": kwargs.get("logql") or kwargs.get("expr") or str(kwargs), + "query": query_string, "timestamp": datetime.now(timezone.utc).isoformat(), "result_count": result_count, "datasource": kwargs.get("datasourceUid", "unknown"), @@ -154,6 +269,15 @@ def _record_query(tool_name: str, kwargs: dict, result_count: int | None = None) if _current_state: _current_state.executed_queries.append(query_record) + # Store query result for evidence validation (if we have results) + if result_data is not None and result_count and result_count > 0: + store_query_result( + query_type=tool_name, + query_string=query_string, + result_data=result_data, + result_count=result_count, + ) + def create_rate_limited_mcp_tool( original_tool: Any, resilient_executor: QueryResilientExecutor | None = None @@ -203,8 +327,8 @@ async def rate_limited_wrapper(*args, **kwargs): logger.warning(f"🔁 Blocking duplicate query: {query_str[:50]}...") return dup_msg - # Increment counter - _increment_query_count(tool_name) + # Increment attempt counter (successful count updated after results) + _increment_query_attempt(tool_name) # Extract time parameters for resilient execution start_time = kwargs.get("startRfc3339") or kwargs.get("start_time") or kwargs.get("start") @@ -239,9 +363,11 @@ async def query_wrapper(logql: str, start_time: str, end_time: str, **kw): end_time, ) - # Record the query with result count + # Record the query with result count and data for validation result_count = _extract_result_count(result) - _record_query(tool_name, kwargs, result_count) + _record_query(tool_name, kwargs, result_count, result_data=result) + # Only count successful queries against the limit + _count_successful_query(result_count) # Log resilience metadata if present if isinstance(result, dict) and "_resilience_metadata" in result: @@ -269,7 +395,9 @@ async def query_wrapper(logql: str, start_time: str, end_time: str, **kw): try: result = await original_fn(*args, **kwargs) result_count = _extract_result_count(result) - _record_query(tool_name, kwargs, result_count) + _record_query(tool_name, kwargs, result_count, result_data=result) + # Only count successful queries against the limit + _count_successful_query(result_count) return result except Exception as e: error_str = str(e) diff --git a/src/ares/core/factories/red_factory.py b/src/ares/core/factories/red_factory.py index 4c2b525c..59909454 100644 --- a/src/ares/core/factories/red_factory.py +++ b/src/ares/core/factories/red_factory.py @@ -3,7 +3,7 @@ import time import dreadnode as dn -from dreadnode.agent import Agent +from dreadnode.agent import Agent, Thread from dreadnode.agent.events import ( AgentEnd, AgentError, @@ -15,7 +15,6 @@ ) from dreadnode.agent.hooks import retry_with_feedback from dreadnode.agent.stop import tool_use -from dreadnode.agent.thread import Thread from loguru import logger from ares.core.models import RedTeamState diff --git a/src/ares/core/models.py b/src/ares/core/models.py index 1b039b5e..68817f08 100644 --- a/src/ares/core/models.py +++ b/src/ares/core/models.py @@ -72,6 +72,8 @@ class Evidence: mitre_techniques: Associated MITRE ATT&CK technique IDs. confidence: Confidence score between 0.0 and 1.0. metadata: Additional context about this evidence. + source_query_id: ID of the query that produced this evidence (for provenance). + validated: Whether this evidence was validated against query results. """ id: str @@ -83,6 +85,8 @@ class Evidence: mitre_techniques: list[str] = field(default_factory=list) confidence: float = 0.5 metadata: dict[str, Any] = field(default_factory=dict) + source_query_id: str | None = None + validated: bool = False def to_dict(self) -> dict: return { @@ -95,6 +99,8 @@ def to_dict(self) -> dict: "mitre_techniques": self.mitre_techniques, "confidence": self.confidence, "metadata": self.metadata, + "source_query_id": self.source_query_id, + "validated": self.validated, } diff --git a/src/ares/tools/blue/actions.py b/src/ares/tools/blue/actions.py index d053ee93..f6b7bbaf 100644 --- a/src/ares/tools/blue/actions.py +++ b/src/ares/tools/blue/actions.py @@ -160,7 +160,7 @@ def _generate_fallback_synopsis(self) -> None: self.state.attack_synopsis = " ".join(parts) -@dn.tool() # type: ignore[untyped-decorator] +@dn.tool # type: ignore[untyped-decorator] async def escalate_investigation( reason: str, severity: str, diff --git a/src/ares/tools/blue/grafana.py b/src/ares/tools/blue/grafana.py index c466f3d9..85ab2c5d 100644 --- a/src/ares/tools/blue/grafana.py +++ b/src/ares/tools/blue/grafana.py @@ -90,6 +90,134 @@ async def get_alert_history( logger.error(f"Failed to get alert history: {e}") return [] + async def create_annotation( + self, + text: str, + tags: list[str] | None = None, + dashboard_uid: str | None = None, + time_start: int | None = None, + time_end: int | None = None, + ) -> dict | None: + """Create an annotation in Grafana. + + Args: + text: Annotation text/description. + tags: List of tags for filtering. + dashboard_uid: Optional dashboard UID to associate annotation with. + time_start: Start time as epoch milliseconds (defaults to now). + time_end: End time as epoch milliseconds (optional, for range annotations). + + Returns: + Created annotation response or None on failure. + """ + import time + + payload: dict[str, Any] = { + "text": text, + "tags": tags or ["ares", "investigation"], + "time": time_start or int(time.time() * 1000), + } + + if dashboard_uid: + payload["dashboardUID"] = dashboard_uid + + if time_end: + payload["timeEnd"] = time_end + + try: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post( + f"{self.base_url}/api/annotations", + headers=self._headers(), + json=payload, + ) + response.raise_for_status() + result = response.json() + logger.info(f"Created Grafana annotation: {result.get('id', 'unknown')}") + return result + except httpx.HTTPError as e: + logger.warning(f"Failed to create annotation: {e}") + return None + + async def post_investigation_started( + self, + investigation_id: str, + alert_name: str, + severity: str, + ) -> dict | None: + """Post annotation when investigation starts. + + Args: + investigation_id: Unique investigation identifier. + alert_name: Name of the alert being investigated. + severity: Alert severity level. + + Returns: + Created annotation response or None on failure. + """ + text = ( + f"🔍 **Investigation Started**\n\n" + f"- **ID:** {investigation_id}\n" + f"- **Alert:** {alert_name}\n" + f"- **Severity:** {severity}\n" + f"- **Status:** In Progress" + ) + return await self.create_annotation( + text=text, + tags=["ares", "investigation", "started", alert_name, severity], + ) + + async def post_investigation_completed( + self, + investigation_id: str, + alert_name: str, + status: str, + evidence_count: int, + techniques: list[str], + pyramid_level: int, + summary: str | None = None, + ) -> dict | None: + """Post annotation when investigation completes. + + Args: + investigation_id: Unique investigation identifier. + alert_name: Name of the alert investigated. + status: Final status (completed, escalated, timeout). + evidence_count: Number of evidence items collected. + techniques: List of MITRE ATT&CK techniques identified. + pyramid_level: Highest Pyramid of Pain level reached. + summary: Optional investigation summary. + + Returns: + Created annotation response or None on failure. + """ + status_emoji = { + "completed": "✅", + "escalated": "🚨", + "timeout": "⏰", + "failed": "❌", + "incomplete": "⚠️", + }.get(status, "📋") + + text = ( + f"{status_emoji} **Investigation {status.title()}**\n\n" + f"- **ID:** {investigation_id}\n" + f"- **Alert:** {alert_name}\n" + f"- **Evidence:** {evidence_count} items\n" + f"- **Techniques:** {', '.join(techniques) if techniques else 'None identified'}\n" + f"- **Pyramid Level:** {pyramid_level}/6" + ) + + if summary: + # Truncate summary if too long + truncated = summary[:500] + "..." if len(summary) > 500 else summary + text += f"\n\n**Summary:** {truncated}" + + return await self.create_annotation( + text=text, + tags=["ares", "investigation", status, alert_name], + ) + def find_mcp_grafana() -> str: """Find the mcp-grafana binary. diff --git a/src/ares/tools/blue/investigation.py b/src/ares/tools/blue/investigation.py index 3dd7f841..97f5c08b 100644 --- a/src/ares/tools/blue/investigation.py +++ b/src/ares/tools/blue/investigation.py @@ -13,6 +13,11 @@ _load_attack_chains, _load_detection_recipes, ) +from ares.core.evidence_validation import ( + adjust_confidence_for_validation, + get_suggested_iocs, + validate_evidence_value, +) from ares.core.models import ( Evidence, InvestigationStage, @@ -74,6 +79,10 @@ def record_evidence( 5. Tools (challenging) 6. TTPs (tough - this is the goal!) + NOTE: Evidence is automatically validated against recent query results. + If the value cannot be found in query results, confidence will be reduced. + Use get_suggested_evidence() to see IOCs extracted from queries. + Args: evidence_type: Type of evidence (ip, domain, hash, process, etc.). value: The actual evidence value. @@ -84,7 +93,7 @@ def record_evidence( confidence: Confidence score 0.0-1.0. Returns: - Evidence ID for reference. + Evidence ID and validation status. Example: >>> record_evidence( @@ -95,7 +104,7 @@ def record_evidence( ... pyramid_level=2, ... confidence=0.8 ... ) - 'ev-0001' + 'ev-0001 (validated)' """ if not self.state: return "ERROR: No investigation state" @@ -107,6 +116,12 @@ def record_evidence( with contextlib.suppress(ValueError): ts = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + # Validate evidence against recent query results + validated, source_query_id = validate_evidence_value(value) + + # Adjust confidence based on validation + adjusted_confidence = adjust_confidence_for_validation(confidence, validated) + ev = Evidence( id=evidence_id, type=evidence_type, @@ -115,7 +130,9 @@ def record_evidence( timestamp=ts, pyramid_level=PyramidLevel(min(max(pyramid_level, 1), 6)), mitre_techniques=mitre_techniques or [], - confidence=confidence, + confidence=adjusted_confidence, + source_query_id=source_query_id, + validated=validated, ) self.state.evidence.append(ev) @@ -128,12 +145,15 @@ def record_evidence( dn.log_output(f"evidence_{evidence_id}", ev.to_dict()) dn.log_metric("evidence_count", 1, mode="count") dn.log_metric("highest_pyramid_level", pyramid_level, mode="max") + dn.log_metric("evidence_validated", 1 if validated else 0, mode="count") + validation_status = "validated" if validated else "UNVALIDATED - confidence reduced" logger.info( - f"Recorded evidence: {evidence_type}={value[:50]}... (pyramid level {pyramid_level})" + f"Recorded evidence: {evidence_type}={value[:50]}... " + f"(pyramid level {pyramid_level}, {validation_status})" ) - return evidence_id + return f"{evidence_id} ({validation_status})" def _resolve_technique_metadata(self, technique_ids: list[str]) -> None: """Look up and cache technique names and tactics.""" @@ -298,6 +318,41 @@ def track_user_investigation(self, username: str) -> str: loader = get_template_loader() return loader.render("tools/user_queries.md.jinja", username=username) + @dn.tool_method # type: ignore[untyped-decorator] + def get_suggested_evidence(self) -> list[dict]: + """Get IOCs auto-extracted from recent query results. + + This helps you record evidence that actually exists in query results, + ensuring proper provenance and avoiding hallucinated evidence. + + The system automatically extracts: + - IP addresses + - Hostnames/FQDNs + - Usernames (domain\\user, user@domain formats) + - Hash values (MD5, SHA1, SHA256) + + Returns: + List of suggested IOCs with type, value, and source query ID. + + Example: + >>> get_suggested_evidence() + [ + {'type': 'ip', 'value': '192.168.1.100', 'source_query_id': 'q-0001'}, + {'type': 'hostname', 'value': 'dc01.domain.local', 'source_query_id': 'q-0001'}, + {'type': 'user', 'value': 'DOMAIN\\\\admin', 'source_query_id': 'q-0002'}, + ] + + See Also: + record_evidence: Use this to record the suggested evidence. + """ + suggestions = get_suggested_iocs() + + if not suggestions: + return [{"message": "No IOCs extracted from recent queries. Run more queries first."}] + + logger.info(f"Returning {len(suggestions)} suggested IOCs from query results") + return suggestions + class QuestionEngineTools(Toolset): # type: ignore[misc] """Tools for the question engines that drive the investigation. diff --git a/src/ares/tools/blue/observability.py b/src/ares/tools/blue/observability.py index 380aa3a5..64222056 100644 --- a/src/ares/tools/blue/observability.py +++ b/src/ares/tools/blue/observability.py @@ -1,6 +1,8 @@ """Observability tools for querying Loki and Prometheus.""" +import asyncio from datetime import datetime, timedelta +from typing import Any import dreadnode as dn import httpx @@ -269,6 +271,181 @@ async def get_label_values(self, label: str) -> list[str]: logger.error(f"Failed to get label values: {e}") return [] + @dn.tool_method # type: ignore[untyped-decorator] + async def execute_parallel_queries( + self, + queries: list[dict[str, Any]], + start_time: str, + end_time: str, + limit: int = 500, + ) -> list[dict]: + """Execute multiple LogQL queries in parallel for faster investigation. + + This is the MOST EFFICIENT way to investigate when you have multiple + independent questions. Use this instead of sequential query_logs calls. + + IMPORTANT: Only queries that are INDEPENDENT should be batched together. + If query B depends on results from query A, run query A first. + + Args: + queries: List of query objects, each containing: + - logql: The LogQL query string + - description: What this query is looking for (for logging) + start_time: ISO8601 timestamp for query start (shared for all queries) + end_time: ISO8601 timestamp for query end (shared for all queries) + limit: Maximum number of log lines per query (default 500) + + Returns: + List of results in the same order as input queries. + Each result contains: + - query: The original query string + - description: The query description + - result: Query results (same format as query_logs) + - success: Whether the query succeeded + + Example: + >>> await execute_parallel_queries( + ... queries=[ + ... {"logql": '{job="syslog"} |= "error"', "description": "Find errors"}, + ... {"logql": '{job="auth"} |= "4625"', "description": "Find failed logins"}, + ... {"logql": '{job="app"} |= "exception"', "description": "Find exceptions"}, + ... ], + ... start_time="2024-01-15T10:00:00Z", + ... end_time="2024-01-15T11:00:00Z" + ... ) + [ + {"query": "...", "description": "...", "result": {...}, "success": True}, + {"query": "...", "description": "...", "result": {...}, "success": True}, + {"query": "...", "description": "...", "result": {...}, "success": True} + ] + + See Also: + query_logs: For single queries. + combine_query_patterns: For combining similar patterns into one query. + """ + if not queries: + return [] + + if len(queries) > 10: + logger.warning(f"Large batch of {len(queries)} queries - consider reducing") + queries = queries[:10] # Cap at 10 to prevent overload + + dn.log_metric("parallel_query_batches", 1, mode="count") + dn.log_metric("parallel_queries_total", len(queries), mode="count") + logger.info(f"Executing {len(queries)} queries in parallel") + + async def execute_single(query_obj: dict[str, Any]) -> dict: + logql = query_obj.get("logql", "") + description = query_obj.get("description", "") + + if not logql: + return { + "query": logql, + "description": description, + "result": {"error": "Empty query"}, + "success": False, + } + + try: + result = await self.query_logs( + logql=logql, + start_time=start_time, + end_time=end_time, + limit=limit, + ) + success = "error" not in result + return { + "query": logql, + "description": description, + "result": result, + "success": success, + } + except Exception as e: + logger.error(f"Parallel query failed: {e}") + return { + "query": logql, + "description": description, + "result": {"error": str(e)}, + "success": False, + } + + # Execute all queries in parallel + tasks = [execute_single(q) for q in queries] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Convert exceptions to error results + processed_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + processed_results.append( + { + "query": queries[i].get("logql", ""), + "description": queries[i].get("description", ""), + "result": {"error": str(result)}, + "success": False, + } + ) + else: + processed_results.append(result) # type: ignore[arg-type] + + # Log summary + successful = sum(1 for r in processed_results if r.get("success")) + logger.info(f"Parallel queries complete: {successful}/{len(queries)} successful") + + return processed_results + + @dn.tool_method # type: ignore[untyped-decorator] + def combine_query_patterns( + self, + base_selector: str, + patterns: list[str], + ) -> str: + """Combine multiple regex patterns into a single efficient LogQL query. + + Instead of running 3 separate queries for "error", "failed", "exception", + combine them into one query with |~ "error|failed|exception". + + This is MORE EFFICIENT than parallel queries when searching for + multiple patterns in the same log stream. + + Args: + base_selector: The label selector (e.g., '{job="syslog"}') + patterns: List of patterns to combine (e.g., ["error", "failed", "exception"]) + + Returns: + Combined LogQL query string. + + Example: + >>> combine_query_patterns( + ... base_selector='{job="syslog"}', + ... patterns=["error", "failed", "critical"] + ... ) + '{job="syslog"} |~ "error|failed|critical"' + + >>> combine_query_patterns( + ... base_selector='{job="auth"}', + ... patterns=["4625", "4624", "4771"] + ... ) + '{job="auth"} |~ "4625|4624|4771"' + + See Also: + execute_parallel_queries: For truly independent queries. + """ + if not patterns: + return base_selector + + # Escape any regex special characters in patterns + escaped = [] + for p in patterns: + # Escape common regex chars but preserve intended regex + if not any(c in p for c in ".*+?()[]{}|\\^$"): + escaped.append(p) + else: + escaped.append(p) # Keep as-is if it looks like regex + + combined = "|".join(escaped) + return f'{base_selector} |~ "{combined}"' + class PrometheusTools(Toolset): # type: ignore[misc] """Tools for querying Prometheus metrics. diff --git a/templates/agent/system_instructions.md.jinja b/templates/agent/system_instructions.md.jinja index 044d99a4..ff87dd28 100644 --- a/templates/agent/system_instructions.md.jinja +++ b/templates/agent/system_instructions.md.jinja @@ -170,11 +170,47 @@ You MUST leverage parallelism. When you have multiple questions: 2. Execute ALL independent queries in a SINGLE response 3. This is the power of automation - don't waste it on sequential queries +### Method 1: execute_parallel_queries() - RECOMMENDED + +Use this tool to batch multiple queries into a single call: +```python +await execute_parallel_queries( + queries=[ + {"logql": '{job="syslog"} |= "4625"', "description": "Failed logins"}, + {"logql": '{job="syslog"} |= "4624"', "description": "Successful logins"}, + {"logql": '{job="app"} |= "powershell"', "description": "PowerShell activity"}, + ], + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z" +) +``` + +### Method 2: combine_query_patterns() - For similar patterns + +When searching for multiple patterns in the SAME log stream, combine them: +```python +# Instead of 3 queries, combine into one: +combine_query_patterns( + base_selector='{job="auth"}', + patterns=["4625", "4624", "4771"] +) +# Returns: {job="auth"} |~ "4625|4624|4771" +``` + +### When to use which: +- **execute_parallel_queries**: Different log streams or different label selectors +- **combine_query_patterns**: Same log stream, searching for multiple patterns + Example - GOOD (parallel): - Query 1: {hostname="web-01"} |= "powershell" - Query 2: {hostname="web-01"} |= "download" - Query 3: {job="auth", user="admin"} | json -[Execute all 3 in one tool call batch] +[Execute all 3 with execute_parallel_queries()] + +Example - BETTER (combined + parallel): +- Combined Query: {hostname="web-01"} |~ "powershell|download" +- Query 2: {job="auth", user="admin"} | json +[2 queries instead of 3] Example - BAD (sequential): - Query 1, wait for response