diff --git a/.taskfiles/ec2/Taskfile.yaml b/.taskfiles/ec2/Taskfile.yaml index cd2b1f98..83b4e3f6 100644 --- a/.taskfiles/ec2/Taskfile.yaml +++ b/.taskfiles/ec2/Taskfile.yaml @@ -954,101 +954,145 @@ tasks: exit 1 fi - # Step 1: Generate report on EC2 and get the operation ID + report content - REPORT_CMD="set -e; RUST_LOG=error ares ops report" - {{if ne .OPERATION_ID ""}}REPORT_CMD="$REPORT_CMD {{.OPERATION_ID}}"{{end}} - {{if eq .LATEST "true"}}REPORT_CMD="$REPORT_CMD --latest"{{end}} - {{if eq .REGENERATE "true"}}REPORT_CMD="$REPORT_CMD --regenerate"{{end}} - REPORT_CMD="$REPORT_CMD --output-dir /tmp/reports 2>&1" - # After generation, output the report file content with a marker - REPORT_CMD="$REPORT_CMD; echo '===REPORT_FILES==='; find /tmp/reports/red -name '*.md' -type f 2>/dev/null | while read f; do echo \"FILE:red/\$(basename \$f)\"; cat \"\$f\"; done" - - PARAMS_FILE=$(mktemp) - trap "rm -f $PARAMS_FILE" EXIT - jq -n --arg cmd "$REPORT_CMD" '{"commands": [$cmd]}' > "$PARAMS_FILE" - - echo -e "{{.INFO}} Generating report on EC2..." - - CMD_ID=$(aws ssm send-command \ - --profile "{{.EC2_PROFILE}}" \ - --region "{{.EC2_REGION}}" \ - --instance-ids "$INSTANCE_ID" \ - --document-name "AWS-RunShellScript" \ - --parameters "file://$PARAMS_FILE" \ - --timeout-seconds 120 \ - --query "Command.CommandId" --output text) + run_ssm_cmd() { + CMD_PAYLOAD="$1" + TIMEOUT="${2:-120}" + PARAMS_FILE=$(mktemp) + jq -n --arg cmd "$CMD_PAYLOAD" '{"commands": [$cmd]}' > "$PARAMS_FILE" - for i in $(seq 1 120); do - STATUS=$(aws ssm get-command-invocation \ + CMD_ID=$(aws ssm send-command \ --profile "{{.EC2_PROFILE}}" \ --region "{{.EC2_REGION}}" \ - --command-id "$CMD_ID" \ - --instance-id "$INSTANCE_ID" \ - --query "Status" --output text 2>/dev/null) || true - case "$STATUS" in - Success|Failed|Cancelled|TimedOut) break ;; - esac - sleep 1 - done + --instance-ids "$INSTANCE_ID" \ + --document-name "AWS-RunShellScript" \ + --parameters "file://$PARAMS_FILE" \ + --timeout-seconds "$TIMEOUT" \ + --query "Command.CommandId" --output text) - OUTPUT=$(aws ssm get-command-invocation \ - --profile "{{.EC2_PROFILE}}" \ - --region "{{.EC2_REGION}}" \ - --command-id "$CMD_ID" \ - --instance-id "$INSTANCE_ID" \ - --query "StandardOutputContent" --output text) + rm -f "$PARAMS_FILE" - if [ "$STATUS" != "Success" ]; then - DETAILS=$(aws ssm get-command-invocation \ + STATUS="" + for i in $(seq 1 "$TIMEOUT"); do + STATUS=$(aws ssm get-command-invocation \ + --profile "{{.EC2_PROFILE}}" \ + --region "{{.EC2_REGION}}" \ + --command-id "$CMD_ID" \ + --instance-id "$INSTANCE_ID" \ + --query "Status" --output text 2>/dev/null) || true + case "$STATUS" in + Success|Failed|Cancelled|TimedOut) break ;; + esac + sleep 1 + done + + OUTPUT=$(aws ssm get-command-invocation \ --profile "{{.EC2_PROFILE}}" \ --region "{{.EC2_REGION}}" \ --command-id "$CMD_ID" \ --instance-id "$INSTANCE_ID" \ - --query "StatusDetails" --output text 2>/dev/null) - echo -e "{{.ERROR}} Report generation failed (status: $STATUS, details: $DETAILS)" >&2 - if [ "$DETAILS" = "Undeliverable" ]; then - echo -e "{{.ERROR}} SSM could not deliver the command to $INSTANCE_ID (PingStatus likely ConnectionLost)." >&2 - echo -e "{{.ERROR}} Recovery: reboot the instance ('aws ec2 reboot-instances --instance-ids $INSTANCE_ID')." >&2 + --query "StandardOutputContent" --output text) + + if [ "$STATUS" != "Success" ]; then + DETAILS=$(aws ssm get-command-invocation \ + --profile "{{.EC2_PROFILE}}" \ + --region "{{.EC2_REGION}}" \ + --command-id "$CMD_ID" \ + --instance-id "$INSTANCE_ID" \ + --query "StatusDetails" --output text 2>/dev/null) + echo -e "{{.ERROR}} SSM command failed (status: $STATUS, details: $DETAILS)" >&2 + if [ "$DETAILS" = "Undeliverable" ]; then + echo -e "{{.ERROR}} SSM could not deliver the command to $INSTANCE_ID (PingStatus likely ConnectionLost)." >&2 + echo -e "{{.ERROR}} Recovery: reboot the instance ('aws ec2 reboot-instances --instance-ids $INSTANCE_ID')." >&2 + fi + if [ -n "$OUTPUT" ]; then + echo "$OUTPUT" >&2 + fi + aws ssm get-command-invocation \ + --profile "{{.EC2_PROFILE}}" \ + --region "{{.EC2_REGION}}" \ + --command-id "$CMD_ID" \ + --instance-id "$INSTANCE_ID" \ + --query "StandardErrorContent" --output text >&2 + return 1 fi - aws ssm get-command-invocation \ - --profile "{{.EC2_PROFILE}}" \ - --region "{{.EC2_REGION}}" \ - --command-id "$CMD_ID" \ - --instance-id "$INSTANCE_ID" \ - --query "StandardErrorContent" --output text >&2 - exit 1 - fi - # Step 2: Extract report content and save locally - mkdir -p "{{.OUTPUT_DIR}}/red" + printf '%s' "$OUTPUT" + } + + REPORT_ARGS="" + {{if ne .OPERATION_ID ""}}REPORT_ARGS="$REPORT_ARGS {{.OPERATION_ID}}"{{end}} + {{if eq .LATEST "true"}}REPORT_ARGS="$REPORT_ARGS --latest"{{end}} + {{if eq .REGENERATE "true"}}REPORT_ARGS="$REPORT_ARGS --regenerate"{{end}} + REPORT_CMD="set +e" + REPORT_CMD="$REPORT_CMD; OUTPUT=\$(RUST_LOG=error ares ops report${REPORT_ARGS} --output-dir /tmp/reports 2>&1)" + REPORT_CMD="$REPORT_CMD; RC=\$?; set -e; echo \"\$OUTPUT\"; if [ \$RC -ne 0 ]; then exit \$RC; fi" + REPORT_CMD="$REPORT_CMD; REPORT_PATH=\$(printf '%s\n' \"\$OUTPUT\" | sed -n 's/^Report saved to \([^ ]*\.md\).*/\1/p' | tail -1)" + REPORT_CMD="$REPORT_CMD; echo '===REPORT_META==='" + REPORT_CMD="$REPORT_CMD; if [ -z \"\$REPORT_PATH\" ]; then echo 'ERROR: could not parse report path from ares output' >&2; exit 1; fi" + REPORT_CMD="$REPORT_CMD; if [ ! -f \"\$REPORT_PATH\" ]; then echo \"ERROR: report file not found: \$REPORT_PATH\" >&2; exit 1; fi" + REPORT_CMD="$REPORT_CMD; echo \"FILE:red/\$(basename \"\$REPORT_PATH\")\"" + REPORT_CMD="$REPORT_CMD; echo \"PATH:\$REPORT_PATH\"" + REPORT_CMD="$REPORT_CMD; echo \"BYTES:\$(wc -c < \"\$REPORT_PATH\" | tr -d ' ')\"" + REPORT_CMD="$REPORT_CMD; echo \"SHA256:\$(sha256sum \"\$REPORT_PATH\" | awk '{print \$1}')\"" + + echo -e "{{.INFO}} Generating report on EC2..." - # Parse the output — everything after ===REPORT_FILES=== contains FILE:name + content - BEFORE_MARKER=$(echo "$OUTPUT" | sed '/===REPORT_FILES===/,$d') - AFTER_MARKER=$(echo "$OUTPUT" | sed -n '/===REPORT_FILES===/,$p' | tail -n +2) + OUTPUT=$(run_ssm_cmd "$REPORT_CMD" 120) + BEFORE_MARKER=$(echo "$OUTPUT" | sed '/===REPORT_META===/,$d') + META=$(echo "$OUTPUT" | sed -n '/===REPORT_META===/,$p' | tail -n +2) - # Show the CLI output (report generation messages) echo "$BEFORE_MARKER" - # Extract filename and content, save to local files - FILENAME="" - OUTFILE="" - while IFS= read -r line; do - if [[ "$line" == FILE:* ]]; then - if [ -n "$FILENAME" ] && [ -f "$OUTFILE" ]; then - echo -e "{{.SUCCESS}} Report saved: $OUTFILE" - fi - FILENAME="${line#FILE:}" - OUTFILE="{{.OUTPUT_DIR}}/$FILENAME" - : > "$OUTFILE" - elif [ -n "$OUTFILE" ]; then - echo "$line" >> "$OUTFILE" + FILENAME=$(echo "$META" | sed -n 's/^FILE://p' | head -1) + REMOTE_PATH=$(echo "$META" | sed -n 's/^PATH://p' | head -1) + REMOTE_BYTES=$(echo "$META" | sed -n 's/^BYTES://p' | head -1) + REMOTE_SHA=$(echo "$META" | sed -n 's/^SHA256://p' | head -1) + + if [ -z "$FILENAME" ] || [ -z "$REMOTE_PATH" ] || [ -z "$REMOTE_BYTES" ] || [ -z "$REMOTE_SHA" ]; then + echo -e "{{.ERROR}} Could not parse report metadata from EC2 output" >&2 + exit 1 + fi + + mkdir -p "{{.OUTPUT_DIR}}/red" + OUTFILE="{{.OUTPUT_DIR}}/$FILENAME" + TMP_OUT="${OUTFILE}.tmp.$$" + : > "$TMP_OUT" + trap "rm -f '$TMP_OUT'" EXIT + + CHUNK_SIZE=12000 + CHUNKS=$(( (REMOTE_BYTES + CHUNK_SIZE - 1) / CHUNK_SIZE )) + echo -e "{{.INFO}} Fetching $REMOTE_BYTES bytes from EC2 in $CHUNKS chunks..." + + CHUNK=0 + while [ $CHUNK -lt $CHUNKS ]; do + CHUNK_CMD="dd if=\"$REMOTE_PATH\" bs=$CHUNK_SIZE skip=$CHUNK count=1 status=none | base64 | tr -d '\n'" + CHUNK_B64=$(run_ssm_cmd "$CHUNK_CMD" 30) + if ! printf '%s' "$CHUNK_B64" | base64 -d >> "$TMP_OUT"; then + echo -e "{{.ERROR}} Failed to decode report chunk $((CHUNK + 1))/$CHUNKS" >&2 + exit 1 fi - done <<< "$AFTER_MARKER" + CHUNK=$((CHUNK + 1)) + done - if [ -n "$FILENAME" ] && [ -f "$OUTFILE" ]; then - echo -e "{{.SUCCESS}} Report saved: $OUTFILE" + LOCAL_BYTES=$(wc -c < "$TMP_OUT" | tr -d ' ') + if [ "$LOCAL_BYTES" != "$REMOTE_BYTES" ]; then + echo -e "{{.ERROR}} Report size mismatch: remote=$REMOTE_BYTES local=$LOCAL_BYTES" >&2 + exit 1 fi + if command -v sha256sum >/dev/null 2>&1; then + LOCAL_SHA=$(sha256sum "$TMP_OUT" | awk '{print $1}') + else + LOCAL_SHA=$(shasum -a 256 "$TMP_OUT" | awk '{print $1}') + fi + if [ "$LOCAL_SHA" != "$REMOTE_SHA" ]; then + echo -e "{{.ERROR}} Report checksum mismatch: remote=$REMOTE_SHA local=$LOCAL_SHA" >&2 + exit 1 + fi + + mv "$TMP_OUT" "$OUTFILE" + echo -e "{{.SUCCESS}} Report saved: $OUTFILE" + ops: desc: "List all operations on EC2 (usage: task ec2:ops [EC2_NAME=ares-tools] [LATEST=true])" silent: true @@ -1146,7 +1190,7 @@ tasks: LLM_MODEL: '{{.LLM_MODEL | default ""}}' FLUSH_REDIS: '{{.FLUSH_REDIS | default "true"}}' OPERATION_ID: '{{.OPERATION_ID | default ""}}' - WAIT: '{{.WAIT | default "false"}}' + WAIT: '{{.WAIT | default "true"}}' POLL_INTERVAL: '{{.POLL_INTERVAL | default "30"}}' MAX_WAIT: '{{.MAX_WAIT | default "7200"}}' OUTPUT_DIR: '{{.OUTPUT_DIR | default "./reports"}}' diff --git a/.taskfiles/red/Taskfile.yaml b/.taskfiles/red/Taskfile.yaml index 97228365..a4eb7f4b 100644 --- a/.taskfiles/red/Taskfile.yaml +++ b/.taskfiles/red/Taskfile.yaml @@ -19,11 +19,15 @@ tasks: # =========================================================================== multi: - desc: "Run multi-agent red team operation (usage: task red:multi [TARGET=dreadgoad] [DOMAIN=contoso.local] [TARGET_ENV=staging] [IPS=10.1.10.10,10.1.10.11])" + desc: "Run multi-agent red team operation (usage: task red:multi [TARGET=dreadgoad] [DOMAIN=contoso.local] [TARGET_ENV=staging] [IPS=10.1.10.10,10.1.10.11] [FOLLOW=true])" silent: true vars: OPERATION_ID: '{{.OPERATION_ID | default ""}}' RESUME: '{{.RESUME | default "false"}}' + FOLLOW: '{{.FOLLOW | default "true"}}' + POLL_INTERVAL: '{{.POLL_INTERVAL | default "30"}}' + MAX_WAIT: '{{.MAX_WAIT | default "7200"}}' + OUTPUT_DIR: '{{.OUTPUT_DIR | default "./reports"}}' TARGET_ENV: '{{.TARGET_ENV | default "staging"}}' IPS: '{{.IPS | default ""}}' OPERATION_ID_COMPUTED: @@ -101,90 +105,55 @@ tasks: 2>&1 | tee -a "{{.LOGFILE}}" silent: false - # Follow logs filtered to this operation only (auto-exits on completion) - cmd: | + if [ "{{.FOLLOW}}" != "true" ]; then + echo "" + echo "Operation {{.OPERATION_ID_COMPUTED}} submitted without local wait." + echo "Auto-fetch later with: task red:multi:watch OPERATION_ID={{.OPERATION_ID_COMPUTED}} ONCE=true OUTPUT_DIR={{.OUTPUT_DIR}}" + exit 0 + fi + echo "" - echo "Following orchestrator logs (auto-exits on completion, Ctrl+C to stop early)..." + echo "Watching operation status (poll={{.POLL_INTERVAL}}s, max_wait={{.MAX_WAIT}}s)" echo "Operation ID: {{.OPERATION_ID_COMPUTED}}" + echo "Report will be fetched to {{.OUTPUT_DIR}}/red/ when the op reaches a terminal state" echo "" - sleep 2 - - DONE_MARKER="/tmp/ares_done_$$" - rm -f "$DONE_MARKER" - - while IFS= read -r line; do - echo "$line" | tee -a "{{.LOGFILE}}" - if echo "$line" | grep -q "completed successfully"; then - echo "" - echo "Operation completed - fetching report..." - touch "$DONE_MARKER" - pkill -f "kubectl logs -f.*{{.K8S_NAMESPACE}}.*ares-orchestrator" 2>/dev/null || true - break + START=$(date +%s) + while true; do + ELAPSED=$(( $(date +%s) - START )) + if [ $ELAPSED -gt {{.MAX_WAIT}} ]; then + echo "ERROR: Max wait ({{.MAX_WAIT}}s) exceeded; report was not auto-fetched" + echo "Fetch manually: task red:multi:report OPERATION_ID={{.OPERATION_ID_COMPUTED}} OUTPUT_DIR={{.OUTPUT_DIR}}" + exit 1 fi - done < <(timeout 1800 stdbuf -oL kubectl logs -f --since=10s -n {{.K8S_NAMESPACE}} deploy/ares-orchestrator 2>&1 | \ - stdbuf -oL grep --line-buffered "{{.OPERATION_ID_COMPUTED}}" || true) - - pkill -f "kubectl logs -f.*{{.K8S_NAMESPACE}}.*ares-orchestrator" 2>/dev/null || true - - if [ ! -f "$DONE_MARKER" ]; then - echo "" - echo "Operation {{.OPERATION_ID_COMPUTED}} is queued/running but not yet complete." - echo "Monitor: task red:multi:loot LATEST=true WATCH=5" - echo "Report: task red:multi:report OPERATION_ID={{.OPERATION_ID_COMPUTED}}" - touch "/tmp/ares_skip_report_$$" - fi - rm -f "$DONE_MARKER" - silent: false - - # Auto-fetch report from Redis when operation completes - - cmd: | - SKIP_MARKER="/tmp/ares_skip_report_$$" - if [ -f "$SKIP_MARKER" ]; then - rm -f "$SKIP_MARKER" - exit 0 - fi - mkdir -p ./reports - echo "Fetching report from Redis..." + STATUS_OUT=$({{.ARES_CLI}} --k8s {{.K8S_NAMESPACE}} ops status "{{.OPERATION_ID_COMPUTED}}" 2>&1 || true) + STATUS=$(echo "$STATUS_OUT" | grep -E '^Status: ' | head -1 | awk '{print $2}') - lsof -ti:16379 | xargs kill 2>/dev/null || true - sleep 1 - kubectl port-forward -n {{.K8S_NAMESPACE}} svc/redis 16379:6379 2>/dev/null & - PF_PID=$! + if [ -z "$STATUS" ]; then + echo "[${ELAPSED}s] no status yet (waiting for op to register)" | tee -a "{{.LOGFILE}}" + else + echo "[${ELAPSED}s] status=$STATUS" | tee -a "{{.LOGFILE}}" + case "$STATUS" in + completed|stopped) + echo "" + echo "Operation reached terminal state: $STATUS" + echo "Fetching report..." + task red:multi:report OPERATION_ID={{.OPERATION_ID_COMPUTED}} OUTPUT_DIR={{.OUTPUT_DIR}} + exit 0 + ;; + failed|cancelled) + echo "Operation reached terminal state: $STATUS; fetching best-effort report" + task red:multi:report OPERATION_ID={{.OPERATION_ID_COMPUTED}} OUTPUT_DIR={{.OUTPUT_DIR}} || true + exit 1 + ;; + esac + fi - i=0 - while [ $i -lt 15 ]; do - if nc -z localhost 16379 2>/dev/null; then break; fi - sleep 1; i=$((i + 1)) + sleep {{.POLL_INTERVAL}} done - - if ! nc -z localhost 16379 2>/dev/null; then - echo "ERROR: kubectl port-forward failed" - echo "Fetch manually: task red:multi:report OPERATION_ID={{.OPERATION_ID_COMPUTED}}" - kill $PF_PID 2>/dev/null || true - exit 0 - fi - trap "kill $PF_PID 2>/dev/null || true" EXIT - - REDIS_PASS=$(kubectl get secret redis-secret -n {{.K8S_NAMESPACE}} -o jsonpath='{.data.password}' 2>/dev/null | base64 -d || echo "") - if [ -n "$REDIS_PASS" ]; then - export ARES_REDIS_URL="redis://:${REDIS_PASS}@localhost:16379" - else - export ARES_REDIS_URL="redis://localhost:16379" - fi - - if {{.ARES_CLI}} ops report "{{.OPERATION_ID_COMPUTED}}" --output-dir ./reports 2>&1; then - echo "" - echo "Report saved: ./reports/red/{{.OPERATION_ID_COMPUTED}}.md" - else - echo "" - echo "ERROR: Report fetch failed" - echo "Fetch manually: task red:multi:report OPERATION_ID={{.OPERATION_ID_COMPUTED}}" - fi - kill $PF_PID 2>/dev/null || true silent: false - ignore_error: true # =========================================================================== # K8s CLI wrappers — run ares on the orchestrator pod @@ -332,61 +301,81 @@ tasks: fi multi:watch: - desc: "Watch for completed operations and auto-fetch reports (usage: task red:multi:watch [POLL_INTERVAL=30] [ONCE=true])" + desc: "Poll red operation status and auto-fetch reports (usage: task red:multi:watch [OPERATION_ID=op-xxx] [LATEST=true] [POLL_INTERVAL=30] [ONCE=true])" silent: true vars: + OPERATION_ID: '{{.OPERATION_ID | default ""}}' + LATEST: '{{.LATEST | default "true"}}' POLL_INTERVAL: '{{.POLL_INTERVAL | default "30"}}' + MAX_WAIT: '{{.MAX_WAIT | default "7200"}}' OUTPUT_DIR: '{{.OUTPUT_DIR | default "./reports"}}' ONCE: '{{.ONCE | default "false"}}' - REDIS_PASSWORD: - sh: kubectl get secret redis-secret -n {{.K8S_NAMESPACE}} -o jsonpath='{.data.password}' 2>/dev/null | base64 -d || echo "" + preconditions: + - sh: test -n "{{.OPERATION_ID}}" || test "{{.LATEST}}" = "true" + msg: "Either OPERATION_ID or LATEST=true is required" cmds: - cmd: mkdir -p "{{.OUTPUT_DIR}}" silent: true - | - # Start port-forward in background (use 16379 locally to avoid conflicts with local Redis) - # Kill any stale port-forward first - lsof -ti:16379 | xargs kill 2>/dev/null || true - sleep 1 - - kubectl port-forward -n {{.K8S_NAMESPACE}} svc/redis 16379:6379 2>/dev/null & - PF_PID=$! + OP_ARG="" + LATEST_FLAG="" + if [ -n "{{.OPERATION_ID}}" ]; then + OP_ARG="{{.OPERATION_ID}}" + else + LATEST_FLAG="--latest" + fi - cleanup() { - kill $PF_PID 2>/dev/null || true - exit 0 - } - trap cleanup INT TERM EXIT + START=$(date +%s) + FETCHED_OPS="" + echo "Watching red operation status (poll={{.POLL_INTERVAL}}s, max_wait={{.MAX_WAIT}}s)" + echo "Reports will be fetched to {{.OUTPUT_DIR}}/red/" - # Wait for port to be connectable (up to 15 seconds) - i=0 - while [ $i -lt 15 ]; do - if nc -z localhost 16379 2>/dev/null; then - break + while true; do + ELAPSED=$(( $(date +%s) - START )) + if [ $ELAPSED -gt {{.MAX_WAIT}} ]; then + echo "ERROR: Max wait ({{.MAX_WAIT}}s) exceeded" + exit 1 fi - sleep 1 - i=$((i + 1)) - done - if ! nc -z localhost 16379 2>/dev/null; then - echo "ERROR: kubectl port-forward failed (port not connectable after 15s)" - exit 1 - fi - - # Build Redis URL with password if available - if [ -n "{{.REDIS_PASSWORD}}" ]; then - export ARES_REDIS_URL="redis://:{{.REDIS_PASSWORD}}@localhost:16379" - else - export ARES_REDIS_URL="redis://localhost:16379" - fi + STATUS_OUT=$({{.ARES_CLI}} --k8s {{.K8S_NAMESPACE}} ops status $OP_ARG $LATEST_FLAG 2>&1 || true) + STATUS=$(echo "$STATUS_OUT" | grep -E '^Status: ' | head -1 | awk '{print $2}') + RESOLVED_OP=$(echo "$STATUS_OUT" | grep -E '^Operation: ' | head -1 | awk '{print $2}') - ONCE_FLAG="" - if [ "{{.ONCE}}" = "true" ]; then - ONCE_FLAG="--once" - fi + if [ -z "$STATUS" ]; then + echo "[${ELAPSED}s] no status yet (waiting for op to register)" + else + echo "[${ELAPSED}s] op=${RESOLVED_OP:-?} status=$STATUS" + case "$STATUS" in + completed|stopped) + if [ -z "$RESOLVED_OP" ]; then + echo "ERROR: Could not resolve operation ID; cannot fetch report" + exit 1 + fi + + if echo " $FETCHED_OPS " | grep -q " $RESOLVED_OP "; then + echo "Report already fetched for $RESOLVED_OP" + else + echo "Fetching report for $RESOLVED_OP..." + task red:multi:report OPERATION_ID="$RESOLVED_OP" OUTPUT_DIR="{{.OUTPUT_DIR}}" + FETCHED_OPS="$FETCHED_OPS $RESOLVED_OP" + fi + + if [ "{{.ONCE}}" = "true" ]; then + exit 0 + fi + ;; + failed|cancelled) + if [ -n "$RESOLVED_OP" ]; then + echo "Operation reached terminal state: $STATUS; fetching best-effort report" + task red:multi:report OPERATION_ID="$RESOLVED_OP" OUTPUT_DIR="{{.OUTPUT_DIR}}" || true + fi + exit 1 + ;; + esac + fi - # Run watch locally - fetches reports from Redis to local ./reports - {{.ARES_CLI}} ops watch --poll-interval {{.POLL_INTERVAL}} --output-dir "{{.OUTPUT_DIR}}" $ONCE_FLAG + sleep {{.POLL_INTERVAL}} + done multi:list: desc: "List multi-agent operations and their state" diff --git a/ares-cli/src/cli/ops.rs b/ares-cli/src/cli/ops.rs index 90a61c97..779c0dea 100644 --- a/ares-cli/src/cli/ops.rs +++ b/ares-cli/src/cli/ops.rs @@ -401,7 +401,7 @@ pub(crate) enum OpsCommands { /// Poll interval in seconds for --follow mode #[arg(long, default_value = "5")] follow_interval: u64, - /// Auto-fetch report when operation completes (requires --follow) + /// Auto-fetch report when operation completes; implied by --follow #[arg(long)] auto_report: bool, /// Output directory for auto-report diff --git a/ares-cli/src/ops/mod.rs b/ares-cli/src/ops/mod.rs index 13d925d5..3946dab5 100644 --- a/ares-cli/src/ops/mod.rs +++ b/ares-cli/src/ops/mod.rs @@ -254,10 +254,11 @@ pub(crate) async fn run_ops(cmd: OpsCommands, redis_url: Option) -> Resu pin_active, ) .await?; - if follow { + let should_wait_for_report = follow || auto_report; + if should_wait_for_report { submit::follow_operation(redis_url.clone(), &op_id, follow_interval).await?; } - if auto_report { + if should_wait_for_report { report::ops_report(redis_url, Some(op_id), false, false, report_dir).await?; } Ok(()) diff --git a/ares-cli/src/orchestrator/automation/credential_access.rs b/ares-cli/src/orchestrator/automation/credential_access.rs index 430ae160..a6e6f261 100644 --- a/ares-cli/src/orchestrator/automation/credential_access.rs +++ b/ares-cli/src/orchestrator/automation/credential_access.rs @@ -288,6 +288,9 @@ pub(crate) fn select_low_hanging_work( .filter(|c| !state.is_principal_quarantined(&c.username, &c.domain)) .filter_map(|cred| { let cred_domain = cred.domain.to_lowercase(); + if state.is_domain_dominated(&cred_domain) { + return None; + } let dedup = low_hanging_dedup_key(&cred_domain, &cred.username); if state.is_processed(DEDUP_LOW_HANGING, &dedup) { return None; @@ -1314,6 +1317,18 @@ mod tests { assert_eq!(work[0].1, "192.168.58.99"); } + #[test] + fn select_low_hanging_skips_dominated_domain() { + let mut s = StateInner::new("op".into()); + s.credentials + .push(make_cred("alice", "Pw!", "contoso.local")); + s.domain_controllers + .insert("contoso.local".into(), "192.168.58.10".into()); + s.dominated_domains.insert("contoso.local".into()); + + assert!(select_low_hanging_work(&s, 10).is_empty()); + } + // --- select_credential_secretsdump_work ---------------------------- #[test] diff --git a/ares-cli/src/orchestrator/automation/credential_expansion.rs b/ares-cli/src/orchestrator/automation/credential_expansion.rs index 96805b12..f0abb979 100644 --- a/ares-cli/src/orchestrator/automation/credential_expansion.rs +++ b/ares-cli/src/orchestrator/automation/credential_expansion.rs @@ -132,6 +132,9 @@ pub(crate) fn select_credential_expansion_work( return None; } let cred_domain = resolve_cred_domain(state, &cred.domain); + if state.is_domain_dominated(&cred_domain) { + return None; + } let targets = find_lateral_targets_for_cred_domain(state, &cred_domain); if targets.is_empty() { return None; @@ -185,7 +188,8 @@ pub(crate) fn build_pth_credential( /// Snapshot the next batch of hash-expansion work items. /// /// Filters `state.hashes` for non-`krbtgt`, non-machine NTLM hashes, with -/// at least one non-owned target host, capping at `max_items`. +/// at least one same-forest non-owned target host or DC, capping at +/// `max_items`. pub(crate) fn select_hash_expansion_work( state: &StateInner, max_items: usize, @@ -204,19 +208,21 @@ pub(crate) fn select_hash_expansion_work( if state.is_processed(DEDUP_HASH_LATERAL, &dedup) { return None; } - let targets: Vec = state - .hosts - .iter() - .filter(|h| !h.owned) - .map(|h| h.ip.clone()) - .collect(); - if targets.is_empty() { + let hash_domain = resolve_cred_domain(state, &hash.domain); + if state.is_domain_dominated(&hash_domain) { + return None; + } + let targets = find_lateral_targets_for_cred_domain(state, &hash_domain); + let dc_ips = find_pth_dc_ips_for_hash(state, &hash_domain); + if targets.is_empty() && dc_ips.is_empty() { return None; } Some(HashExpansionWork { dedup_key: dedup, hash: hash.clone(), + resolved_domain: hash_domain, targets, + dc_ips, }) }) .take(max_items) @@ -370,16 +376,18 @@ pub async fn auto_credential_expansion( }; for item in hash_work { - let mut dc_sd_dispatched = false; + let mut any_dispatched = false; // Build a credential-like object for pass-the-hash - let pth_cred = build_pth_credential(&item.hash); + let mut pth_cred = build_pth_credential(&item.hash); + pth_cred.domain = item.resolved_domain.clone(); for target_ip in item.targets.iter().take(3) { if let Ok(Some(task_id)) = dispatcher .request_lateral(target_ip, &pth_cred, "pth_smbclient") .await { + any_dispatched = true; debug!( task_id = %task_id, target = %target_ip, @@ -400,19 +408,14 @@ pub async fn auto_credential_expansion( // path was missing the gate, dispatching foreign-forest creds // against unrelated DCs. { - let dc_ips: Vec = { - let state = dispatcher.state.read().await; - find_pth_dc_ips_for_hash(&state, &item.hash.domain) - }; - if !dispatcher.is_technique_allowed("secretsdump") { // Strategy excludes secretsdump — skip hash-based expansion too. } else { - for dc_ip in dc_ips { + for dc_ip in &item.dc_ips { let sd_dedup = format!( "{}:{}:{}", dc_ip, - item.hash.domain.to_lowercase(), + &item.resolved_domain, item.hash.username.to_lowercase() ); let already = { @@ -422,10 +425,10 @@ pub async fn auto_credential_expansion( if !already { let priority = dispatcher.effective_priority("secretsdump"); if let Ok(Some(task_id)) = dispatcher - .request_secretsdump(&dc_ip, &pth_cred, priority) + .request_secretsdump(dc_ip, &pth_cred, priority) .await { - dc_sd_dispatched = true; + any_dispatched = true; debug!( task_id = %task_id, dc = %dc_ip, @@ -447,9 +450,10 @@ pub async fn auto_credential_expansion( } // end else (secretsdump allowed for hash expansion) } - // Only mark as fully processed once DC secretsdump has been dispatched. - // PTH lateral alone is not sufficient — the critical path is hash→DC→krbtgt. - if dc_sd_dispatched { + // Mark once at least one viable same-forest task was dispatched. + // Per-DC secretsdump dedup remains separate, so a deferred dump can + // still be attempted by other paths without re-spraying the hash. + if any_dispatched { dispatcher .state .write() @@ -588,7 +592,9 @@ pub(crate) struct ExpansionWork { pub(crate) struct HashExpansionWork { pub dedup_key: String, pub hash: ares_core::models::Hash, + pub resolved_domain: String, pub targets: Vec, + pub dc_ips: Vec, } #[cfg(test)] @@ -1456,9 +1462,64 @@ mod tests { let work = select_hash_expansion_work(&s, 10); assert_eq!(work.len(), 1); assert_eq!(work[0].hash.username, "alice"); + assert_eq!(work[0].resolved_domain, "contoso.local"); + assert_eq!(work[0].targets, vec!["192.168.58.10"]); + } + + #[test] + fn select_hash_work_resolves_netbios_domain_for_dispatch() { + let mut s = StateInner::new("op".into()); + s.netbios_to_fqdn + .insert("north".into(), "north.sevenkingdoms.local".into()); + s.hosts.push(make_host( + "winterfell.north.sevenkingdoms.local", + "192.168.58.10", + )); + s.hashes.push(make_ntlm_hash("alice", "aaaaaaaa", "NORTH")); + + let work = select_hash_expansion_work(&s, 10); + assert_eq!(work.len(), 1); + assert_eq!(work[0].resolved_domain, "north.sevenkingdoms.local"); assert_eq!(work[0].targets, vec!["192.168.58.10"]); } + #[test] + fn select_hash_work_skips_cross_forest_targets() { + let mut s = StateInner::new("op".into()); + s.hosts + .push(make_host("srv01.fabrikam.local", "192.168.58.40")); + s.hashes + .push(make_ntlm_hash("alice", "aaaaaaaa", "contoso.local")); + + assert!(select_hash_expansion_work(&s, 10).is_empty()); + } + + #[test] + fn select_hash_work_allows_same_domain_dc_without_host_target() { + let mut s = StateInner::new("op".into()); + s.domain_controllers + .insert("contoso.local".into(), "192.168.58.10".into()); + s.hashes + .push(make_ntlm_hash("alice", "aaaaaaaa", "contoso.local")); + + let work = select_hash_expansion_work(&s, 10); + assert_eq!(work.len(), 1); + assert!(work[0].targets.is_empty()); + assert_eq!(work[0].dc_ips, vec!["192.168.58.10"]); + } + + #[test] + fn select_hash_work_skips_dominated_domain() { + let mut s = StateInner::new("op".into()); + s.hosts + .push(make_host("dc01.contoso.local", "192.168.58.10")); + s.hashes + .push(make_ntlm_hash("alice", "aaaaaaaa", "contoso.local")); + s.dominated_domains.insert("contoso.local".into()); + + assert!(select_hash_expansion_work(&s, 10).is_empty()); + } + #[test] fn select_hash_work_excludes_owned_hosts() { let mut s = StateInner::new("op".into()); diff --git a/ares-cli/src/orchestrator/automation/gpp_sysvol.rs b/ares-cli/src/orchestrator/automation/gpp_sysvol.rs index 4608a1fc..bbc9274d 100644 --- a/ares-cli/src/orchestrator/automation/gpp_sysvol.rs +++ b/ares-cli/src/orchestrator/automation/gpp_sysvol.rs @@ -57,6 +57,10 @@ fn collect_gpp_sysvol_work(state: &StateInner) -> Vec { let mut items = Vec::new(); for (domain, dc_ip) in &state.all_domains_with_dcs() { + if state.is_domain_dominated(domain) { + continue; + } + let dedup_key = format!("gpp:{}", domain.to_lowercase()); if state.is_processed(DEDUP_GPP_SYSVOL, &dedup_key) { continue; @@ -298,6 +302,19 @@ mod tests { assert!(work.is_empty()); } + #[test] + fn collect_skips_dominated_domain() { + let mut state = StateInner::new("test".into()); + state + .domain_controllers + .insert("contoso.local".into(), "192.168.58.10".into()); + state.credentials.push(make_cred("admin", "contoso.local")); + state.dominated_domains.insert("contoso.local".into()); + + let work = collect_gpp_sysvol_work(&state); + assert!(work.is_empty()); + } + #[test] fn collect_skips_unrelated_cross_forest_credential() { let mut state = StateInner::new("test".into()); diff --git a/ares-cli/src/orchestrator/automation/pth_spray.rs b/ares-cli/src/orchestrator/automation/pth_spray.rs index cb894176..687a5c5c 100644 --- a/ares-cli/src/orchestrator/automation/pth_spray.rs +++ b/ares-cli/src/orchestrator/automation/pth_spray.rs @@ -15,6 +15,9 @@ use serde_json::json; use tokio::sync::watch; use tracing::{debug, info, warn}; +use super::credential_expansion::{ + domain_is_same_or_relative, resolve_cred_domain, resolve_host_domain, +}; use crate::orchestrator::dispatcher::Dispatcher; use crate::orchestrator::state::*; @@ -106,6 +109,8 @@ fn collect_pth_work(state: &StateInner) -> Option> { h.hash_type.to_lowercase().contains("ntlm") && !h.hash_value.is_empty() && h.hash_value.len() == 32 + && h.username.to_lowercase() != "krbtgt" + && !h.username.ends_with('$') }) .collect(); @@ -129,9 +134,26 @@ fn collect_pth_work(state: &StateInner) -> Option> { if !has_smb { continue; } + let host_domain = resolve_host_domain(state, host); + if !host_domain.is_empty() && state.is_domain_dominated(&host_domain) { + continue; + } // Try each unique NTLM hash against this host for hash in &ntlm_hashes { + let hash_domain = resolve_cred_domain(state, &hash.domain); + let domain = if !hash_domain.is_empty() { + hash_domain.clone() + } else { + host_domain.clone() + }; + if domain.is_empty() { + continue; + } + if !hash_domain.is_empty() && !domain_is_same_or_relative(&host_domain, &hash_domain) { + continue; + } + let dedup_key = format!( "pth:{}:{}:{}", host.ip, @@ -142,16 +164,6 @@ fn collect_pth_work(state: &StateInner) -> Option> { continue; } - // Infer domain from hash or host - let domain = if !hash.domain.is_empty() { - hash.domain.clone() - } else { - host.hostname - .find('.') - .map(|i| host.hostname[i + 1..].to_string()) - .unwrap_or_default() - }; - items.push(PthWork { dedup_key, target_ip: host.ip.clone(), @@ -589,8 +601,7 @@ mod tests { false, )); let work = collect_pth_work(&state).unwrap(); - assert_eq!(work.len(), 1); - assert_eq!(work[0].domain, ""); + assert!(work.is_empty()); } #[test] @@ -617,6 +628,61 @@ mod tests { assert_eq!(work.len(), 4); } + #[test] + fn collect_skips_cross_forest_hash_host_pairs() { + let mut state = StateInner::new("test".into()); + state.hashes.push(make_ntlm_hash( + "admin", + "aad3b435b51404eeaad3b435b51404ee", // pragma: allowlist secret + "sevenkingdoms.local", + )); + state + .hosts + .push(make_smb_host("10.1.2.254", "braavos.essos.local", false)); + + let work = collect_pth_work(&state).unwrap(); + assert!(work.is_empty()); + } + + #[test] + fn collect_filters_machine_and_krbtgt_hashes() { + let mut state = StateInner::new("test".into()); + state.hashes.push(make_ntlm_hash( + "WINTERFELL$", + "aad3b435b51404eeaad3b435b51404ee", // pragma: allowlist secret + "north.sevenkingdoms.local", + )); + state.hashes.push(make_ntlm_hash( + "krbtgt", + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", // pragma: allowlist secret + "north.sevenkingdoms.local", + )); + state.hosts.push(make_smb_host( + "192.168.58.11", + "srv01.north.sevenkingdoms.local", + false, + )); + + assert!(collect_pth_work(&state).is_none()); + } + + #[test] + fn collect_skips_dominated_target_domain() { + let mut state = StateInner::new("test".into()); + state.hashes.push(make_ntlm_hash( + "admin", + "aad3b435b51404eeaad3b435b51404ee", // pragma: allowlist secret + "contoso.local", + )); + state + .hosts + .push(make_smb_host("192.168.58.10", "srv01.contoso.local", false)); + state.dominated_domains.insert("contoso.local".into()); + + let work = collect_pth_work(&state).unwrap(); + assert!(work.is_empty()); + } + #[test] fn collect_dedup_key_lowercases_username() { let mut state = StateInner::new("test".into()); @@ -759,7 +825,7 @@ mod tests { } #[test] - fn collect_hash_domain_preferred_over_hostname_domain() { + fn collect_hash_domain_must_match_hostname_domain() { let mut state = StateInner::new("test".into()); state.hashes.push(make_ntlm_hash( "admin", @@ -772,8 +838,7 @@ mod tests { false, )); let work = collect_pth_work(&state).unwrap(); - // Hash domain takes priority over hostname domain - assert_eq!(work[0].domain, "contoso.local"); + assert!(work.is_empty()); } #[test] diff --git a/ares-cli/src/orchestrator/automation/shares.rs b/ares-cli/src/orchestrator/automation/shares.rs index 5ea12c38..4d734f69 100644 --- a/ares-cli/src/orchestrator/automation/shares.rs +++ b/ares-cli/src/orchestrator/automation/shares.rs @@ -9,6 +9,32 @@ use tracing::{debug, warn}; use crate::orchestrator::dispatcher::Dispatcher; use crate::orchestrator::state::*; +fn resolve_share_host_domain(state: &StateInner, share_host: &str) -> String { + if let Some(host) = state + .hosts + .iter() + .find(|h| h.ip == share_host || h.hostname.eq_ignore_ascii_case(share_host)) + { + if let Some((_, domain)) = host.hostname.to_lowercase().split_once('.') { + return domain.to_string(); + } + } + + if let Some((domain, _)) = state + .domain_controllers + .iter() + .find(|(_, ip)| ip.as_str() == share_host) + { + return domain.to_lowercase(); + } + + share_host + .to_lowercase() + .split_once('.') + .map(|(_, domain)| domain.to_string()) + .unwrap_or_default() +} + /// Select share-spider work items. /// /// Picks the first non-delegation, non-quarantined credential (or any cred @@ -44,6 +70,10 @@ pub(crate) fn select_share_spider_work( let perms = s.permissions.to_uppercase(); perms.contains("READ") && !s.name.to_uppercase().ends_with('$') }) + .filter(|s| { + let domain = resolve_share_host_domain(state, &s.host); + domain.is_empty() || !state.is_domain_dominated(&domain) + }) .filter_map(|s| { let dedup = format!("{}:{}:{}:{}", s.host, s.name, cred.username, cred.domain); if state.is_processed(DEDUP_SPIDERED_SHARES, &dedup) { @@ -128,6 +158,18 @@ mod tests { } } + fn make_host(hostname: &str, ip: &str) -> ares_core::models::Host { + ares_core::models::Host { + ip: ip.to_string(), + hostname: hostname.to_string(), + os: String::new(), + roles: Vec::new(), + services: Vec::new(), + is_dc: false, + owned: false, + } + } + #[test] fn select_share_spider_empty_without_credentials() { let mut s = StateInner::new("op".into()); @@ -184,6 +226,19 @@ mod tests { assert!(select_share_spider_work(&s, 3).is_empty()); } + #[test] + fn select_share_spider_skips_dominated_share_domain() { + let mut s = StateInner::new("op".into()); + s.credentials + .push(make_cred("alice", "Pw", "contoso.local")); + s.hosts + .push(make_host("dc01.contoso.local", "192.168.58.10")); + s.shares.push(make_share("192.168.58.10", "Shared", "READ")); + s.dominated_domains.insert("contoso.local".into()); + + assert!(select_share_spider_work(&s, 3).is_empty()); + } + #[test] fn select_share_spider_caps_at_max_items() { let mut s = StateInner::new("op".into()); diff --git a/ares-cli/src/orchestrator/automation/stall_detection.rs b/ares-cli/src/orchestrator/automation/stall_detection.rs index 3caa0d29..1549b13e 100644 --- a/ares-cli/src/orchestrator/automation/stall_detection.rs +++ b/ares-cli/src/orchestrator/automation/stall_detection.rs @@ -94,6 +94,7 @@ pub(crate) fn select_stall_spray_work( state .domain_controllers .iter() + .filter(|(domain, _)| !state.is_domain_dominated(domain)) .filter(|(domain, _)| !delegation_domains.contains(&domain.to_lowercase())) .filter(|(domain, _)| { let key = stall_spray_dedup_key(domain, recovery_attempts); @@ -115,6 +116,9 @@ pub(crate) fn select_stall_lhf_work( .filter(|c| !c.domain.is_empty() && !c.password.is_empty()) .filter_map(|cred| { let cred_domain = cred.domain.to_lowercase(); + if state.is_domain_dominated(&cred_domain) { + return None; + } let key = stall_lhf_dedup_key(&cred_domain, &cred.username, recovery_attempts); if state.is_processed(DEDUP_EXPANSION_CREDS, &key) { return None; @@ -485,6 +489,16 @@ mod tests { assert_eq!(select_stall_spray_work(&s, 1).len(), 1); } + #[test] + fn select_stall_spray_skips_dominated_domain() { + let mut s = StateInner::new("op".into()); + s.domain_controllers + .insert("contoso.local".into(), "192.168.58.10".into()); + s.dominated_domains.insert("contoso.local".into()); + + assert!(select_stall_spray_work(&s, 0).is_empty()); + } + #[test] fn select_stall_lhf_empty_state() { let s = StateInner::new("op".into()); @@ -514,6 +528,18 @@ mod tests { assert!(select_stall_lhf_work(&s, 0, 5).is_empty()); } + #[test] + fn select_stall_lhf_skips_dominated_domain() { + let mut s = StateInner::new("op".into()); + s.credentials + .push(make_cred("alice", "Pw", "contoso.local")); + s.domain_controllers + .insert("contoso.local".into(), "192.168.58.10".into()); + s.dominated_domains.insert("contoso.local".into()); + + assert!(select_stall_lhf_work(&s, 0, 5).is_empty()); + } + #[test] fn select_stall_lhf_caps_at_max_items() { let mut s = StateInner::new("op".into()); diff --git a/ares-cli/src/orchestrator/mod.rs b/ares-cli/src/orchestrator/mod.rs index 78396707..e07ddc4d 100644 --- a/ares-cli/src/orchestrator/mod.rs +++ b/ares-cli/src/orchestrator/mod.rs @@ -508,6 +508,7 @@ async fn run_inner() -> Result<()> { registry.clone(), tracker.clone(), dispatcher.credential_inflight.clone(), + shared_state.clone(), config.clone(), shutdown_rx.clone(), ); diff --git a/ares-cli/src/orchestrator/monitoring.rs b/ares-cli/src/orchestrator/monitoring.rs index 08c93193..9bd9c9b3 100644 --- a/ares-cli/src/orchestrator/monitoring.rs +++ b/ares-cli/src/orchestrator/monitoring.rs @@ -15,6 +15,7 @@ use tracing::{debug, info, warn}; use crate::orchestrator::config::OrchestratorConfig; use crate::orchestrator::dispatcher::CredentialInflight; use crate::orchestrator::routing::ActiveTaskTracker; +use crate::orchestrator::state::SharedState; use crate::orchestrator::task_queue::TaskQueue; /// Live state for a registered agent. @@ -195,6 +196,7 @@ pub fn spawn_heartbeat_monitor( registry: AgentRegistry, tracker: ActiveTaskTracker, credential_inflight: CredentialInflight, + state: SharedState, config: Arc, mut shutdown: watch::Receiver, ) -> tokio::task::JoinHandle<()> { @@ -230,7 +232,7 @@ pub fn spawn_heartbeat_monitor( // Clean up stale tasks (salvage any pending results first) if let Err(e) = - cleanup_stale_tasks(&tracker, &queue, &credential_inflight, &config).await + cleanup_stale_tasks(&tracker, &queue, &credential_inflight, &state, &config).await { warn!(err = %e, "Stale task cleanup failed"); } @@ -287,6 +289,7 @@ async fn cleanup_stale_tasks( tracker: &ActiveTaskTracker, queue: &TaskQueue, credential_inflight: &CredentialInflight, + state: &SharedState, config: &OrchestratorConfig, ) -> Result<()> { let llm_count = tracker.llm_task_count().await; @@ -332,6 +335,33 @@ async fn cleanup_stale_tasks( credential_inflight.release(key).await; } } + + let age_secs = task.submitted_at.elapsed().as_secs(); + let reason = format!("stale task evicted after {age_secs}s without a result"); + + if let Err(e) = queue.set_task_status(&task.task_id, "failed").await { + warn!( + task_id = %task.task_id, + err = %e, + "Failed to mark stale task status as failed" + ); + } + + let result = ares_core::models::TaskResult { + task_id: task.task_id.clone(), + success: false, + result: None, + error: Some(reason), + worker_pod: None, + completed_at: Utc::now(), + }; + if let Err(e) = state.complete_task(queue, &task.task_id, result).await { + warn!( + task_id = %task.task_id, + err = %e, + "Failed to move stale task from pending to completed" + ); + } } if !stale.is_empty() { diff --git a/ares-cli/src/orchestrator/state/inner.rs b/ares-cli/src/orchestrator/state/inner.rs index 5cb67a83..c268de0e 100644 --- a/ares-cli/src/orchestrator/state/inner.rs +++ b/ares-cli/src/orchestrator/state/inner.rs @@ -579,6 +579,30 @@ impl StateInner { d } + /// Return true when this exact domain is already dominated. + /// + /// This intentionally avoids forest-root inference: a child-domain krbtgt + /// should not suppress work in an undominated parent domain. NetBIOS names + /// are resolved through `netbios_to_fqdn` when available. + pub fn is_domain_dominated(&self, domain: &str) -> bool { + let raw = domain.to_lowercase(); + if raw.is_empty() { + return false; + } + let normalized = if raw.contains('.') { + raw + } else { + self.netbios_to_fqdn + .get(&raw) + .or_else(|| self.netbios_to_fqdn.get(&domain.to_uppercase())) + .map(|fqdn| fqdn.to_lowercase()) + .unwrap_or(raw) + }; + self.dominated_domains + .iter() + .any(|d| d.eq_ignore_ascii_case(&normalized)) + } + /// Check if a dedup key exists in the named set. pub fn is_processed(&self, set_name: &str, key: &str) -> bool { self.dedup @@ -1080,6 +1104,22 @@ mod tests { assert!(state.all_forests_dominated()); } + #[test] + fn is_domain_dominated_exact_and_netbios_only() { + let mut state = StateInner::new("op-1".into()); + state + .netbios_to_fqdn + .insert("north".into(), "north.sevenkingdoms.local".into()); + state + .dominated_domains + .insert("north.sevenkingdoms.local".into()); + + assert!(state.is_domain_dominated("north.sevenkingdoms.local")); + assert!(state.is_domain_dominated("NORTH")); + assert!(!state.is_domain_dominated("sevenkingdoms.local")); + assert!(!state.is_domain_dominated("")); + } + #[test] fn user_quarantine_basic() { let mut state = StateInner::new("op-1".into());