diff --git a/README.md b/README.md
index b2d83c08..e4123c44 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,15 @@
# Ares - Autonomous Security Operations Agent
+
+
+
+[](https://github.com/dreadnode/python-template/actions/workflows/pre-commit.yaml)
+[](https://github.com/dreadnode/python-template/actions/workflows/renovate.yaml)
+[](https://opensource.org/licenses/Apache-2.0)
+
+
+
+
[](https://github.com/dreadnode/ares/actions/workflows/tests.yaml)
[](https://github.com/dreadnode/ares/actions/workflows/coverage-badge.yaml)
[](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml)
diff --git a/src/ares/core/models.py b/src/ares/core/models.py
index 68817f08..d352688b 100644
--- a/src/ares/core/models.py
+++ b/src/ares/core/models.py
@@ -1,10 +1,63 @@
-"""Data models for Ares SOC Investigation Agent."""
+"""Data models for Ares SOC Investigation Agent.
+
+This module provides structured data models for SOC investigations and red team operations,
+built on rigging's Model class for automatic serialization and LLM output parsing.
+
+Example usage for LLM output parsing:
+ >>> from ares.core.models import Evidence, parse, parse_set
+ >>> # Parse a single Evidence from LLM response text
+ >>> evidence, _ = parse(llm_response, Evidence)
+ >>> # Parse multiple Evidence items
+ >>> items = [e for e, _ in parse_set(llm_response, Evidence)]
+
+"""
+
+from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum, IntEnum
from typing import Any
+from pydantic import Field, computed_field
+from rigging import Model
+from rigging.model import element, wrapped
+
+# Re-export rigging parsing utilities for convenient access
+from rigging.parsing import (
+ parse,
+ parse_many,
+ parse_set,
+ try_parse,
+ try_parse_many,
+ try_parse_set,
+)
+
+__all__ = [
+ "Credential",
+ "Evidence",
+ "Hash",
+ "Host",
+ "InvestigationStage",
+ "InvestigationState",
+ "InvestigativeQuestion",
+ "Model",
+ "PyramidLevel",
+ "QuestionSource",
+ "QuestionState",
+ "RedTeamState",
+ "Share",
+ "Target",
+ "TimelineEvent",
+ "User",
+ "parse",
+ "parse_many",
+ "parse_set",
+ "try_parse",
+ "try_parse_many",
+ "try_parse_set",
+]
+
class PyramidLevel(IntEnum):
"""Levels of the Pyramid of Pain.
@@ -57,8 +110,7 @@ class InvestigationStage(Enum):
SYNTHESIS = "synthesis" # Generate report
-@dataclass
-class Evidence:
+class Evidence(Model):
"""A piece of evidence discovered during investigation.
Attributes:
@@ -82,30 +134,18 @@ class Evidence:
source: str
timestamp: datetime | None
pyramid_level: PyramidLevel
- mitre_techniques: list[str] = field(default_factory=list)
+ mitre_techniques: list[str] = wrapped("mitre-techniques", element(tag="technique", default=[]))
confidence: float = 0.5
- metadata: dict[str, Any] = field(default_factory=dict)
+ metadata: dict[str, str] = Field(default_factory=dict)
source_query_id: str | None = None
validated: bool = False
- def to_dict(self) -> dict:
- return {
- "id": self.id,
- "type": self.type,
- "value": self.value,
- "source": self.source,
- "timestamp": self.timestamp.isoformat() if self.timestamp else None,
- "pyramid_level": self.pyramid_level.value,
- "mitre_techniques": self.mitre_techniques,
- "confidence": self.confidence,
- "metadata": self.metadata,
- "source_query_id": self.source_query_id,
- "validated": self.validated,
- }
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary for storage (backward compatible)."""
+ return self.model_dump(mode="json")
-@dataclass
-class TimelineEvent:
+class TimelineEvent(Model):
"""An event in the investigation timeline.
Attributes:
@@ -121,25 +161,17 @@ class TimelineEvent:
id: str
timestamp: datetime
description: str
- evidence_ids: list[str] = field(default_factory=list)
- mitre_techniques: list[str] = field(default_factory=list)
+ evidence_ids: list[str] = wrapped("evidence-ids", element(tag="evidence-id", default=[]))
+ mitre_techniques: list[str] = wrapped("mitre-techniques", element(tag="technique", default=[]))
confidence: float = 0.5
source: str = "investigation"
- def to_dict(self) -> dict:
- return {
- "id": self.id,
- "timestamp": self.timestamp.isoformat(),
- "description": self.description,
- "evidence_ids": self.evidence_ids,
- "mitre_techniques": self.mitre_techniques,
- "confidence": self.confidence,
- "source": self.source,
- }
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary for storage (backward compatible)."""
+ return self.model_dump(mode="json")
-@dataclass
-class InvestigativeQuestion:
+class InvestigativeQuestion(Model):
"""A question that drives the investigation forward.
Generated by the MITRE Navigator and Pyramid Climber engines.
@@ -185,19 +217,23 @@ class InvestigativeQuestion:
urgency_score: float = 0.0
state: QuestionState = QuestionState.PENDING
- created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
answered_at: datetime | None = None
- generated_from_evidence_ids: list[str] = field(default_factory=list)
+ generated_from_evidence_ids: list[str] = wrapped(
+ "generated-from-evidence-ids", element(tag="evidence-id", default=[])
+ )
generated_from_question_id: str | None = None
- answer_evidence_ids: list[str] = field(default_factory=list)
+ answer_evidence_ids: list[str] = wrapped(
+ "answer-evidence-ids", element(tag="evidence-id", default=[])
+ )
answer_summary: str | None = None
+ @computed_field # type: ignore[prop-decorator]
@property
def priority_score(self) -> float:
- """
- Composite priority score.
+ """Composite priority score.
Weights:
- Pyramid elevation: 3x (we want TTPs, not IOCs)
@@ -212,14 +248,16 @@ def priority_score(self) -> float:
+ (self.urgency_score * 1.0)
)
- def can_parallelize_with(self, other: "InvestigativeQuestion") -> bool:
+ def can_parallelize_with(self, other: InvestigativeQuestion) -> bool:
"""Check if this question can run in parallel with another."""
# Questions in a reasoning chain should be sequential
if self.generated_from_question_id == other.id:
return False
return other.generated_from_question_id != self.id
- def to_dict(self) -> dict:
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary for storage (backward compatible)."""
+ # Use custom format to match original API
return {
"id": self.id,
"question": self.text,
@@ -341,8 +379,7 @@ def to_summary(self) -> dict:
# Red Team Models
-@dataclass
-class Target:
+class Target(Model):
"""Primary target information."""
ip: str
@@ -350,19 +387,17 @@ class Target:
domain: str = ""
-@dataclass
-class Host:
+class Host(Model):
"""Discovered host information."""
ip: str
hostname: str = ""
os: str = ""
- roles: list[str] = field(default_factory=list)
- services: list[str] = field(default_factory=list)
+ roles: list[str] = wrapped("roles", element(tag="role", default=[]))
+ services: list[str] = wrapped("services", element(tag="service", default=[]))
-@dataclass
-class User:
+class User(Model):
"""Discovered user account."""
username: str
@@ -371,8 +406,7 @@ class User:
is_admin: bool = False
-@dataclass
-class Credential:
+class Credential(Model):
"""Discovered credential."""
username: str
@@ -382,8 +416,7 @@ class Credential:
is_admin: bool = False
-@dataclass
-class Hash:
+class Hash(Model):
"""Discovered password hash."""
username: str
@@ -393,8 +426,7 @@ class Hash:
cracked_password: str = ""
-@dataclass
-class Share:
+class Share(Model):
"""Discovered SMB share."""
host: str
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 00000000..2a34842b
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,397 @@
+"""Tests for rigging Model integration in ares.core.models."""
+
+from datetime import datetime, timezone
+
+import pytest
+
+
+class TestEvidenceModel:
+ """Tests for Evidence rigging Model."""
+
+ def test_evidence_creation(self) -> None:
+ """Test creating Evidence with required fields."""
+ from ares.core.models import Evidence, PyramidLevel
+
+ evidence = Evidence(
+ id="test-001",
+ type="ip",
+ value="192.168.1.100",
+ source="loki_query",
+ timestamp=datetime.now(timezone.utc),
+ pyramid_level=PyramidLevel.IP_ADDRESSES,
+ )
+
+ assert evidence.id == "test-001"
+ assert evidence.type == "ip"
+ assert evidence.value == "192.168.1.100"
+ assert evidence.pyramid_level == PyramidLevel.IP_ADDRESSES
+ assert evidence.confidence == 0.5 # default
+ assert evidence.validated is False # default
+
+ def test_evidence_with_optional_fields(self) -> None:
+ """Test creating Evidence with optional fields."""
+ from ares.core.models import Evidence, PyramidLevel
+
+ evidence = Evidence(
+ id="test-002",
+ type="domain",
+ value="malicious.example.com",
+ source="dns_query",
+ timestamp=None,
+ pyramid_level=PyramidLevel.DOMAIN_NAMES,
+ mitre_techniques=["T1071", "T1568"],
+ confidence=0.9,
+ metadata={"resolver": "8.8.8.8"},
+ source_query_id="query-123",
+ validated=True,
+ )
+
+ assert evidence.mitre_techniques == ["T1071", "T1568"]
+ assert evidence.confidence == 0.9
+ assert evidence.metadata == {"resolver": "8.8.8.8"}
+ assert evidence.validated is True
+
+ def test_evidence_to_dict(self) -> None:
+ """Test Evidence to_dict serialization."""
+ from ares.core.models import Evidence, PyramidLevel
+
+ ts = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc)
+ evidence = Evidence(
+ id="test-003",
+ type="hash",
+ value="abc123def456", # pragma: allowlist secret
+ source="file_scan",
+ timestamp=ts,
+ pyramid_level=PyramidLevel.HASH_VALUES,
+ mitre_techniques=["T1027"],
+ )
+
+ data = evidence.to_dict()
+
+ assert data["id"] == "test-003"
+ assert data["type"] == "hash"
+ assert data["value"] == "abc123def456" # pragma: allowlist secret
+ assert data["pyramid_level"] == 1 # HASH_VALUES = 1
+ assert data["mitre_techniques"] == ["T1027"]
+ assert data["timestamp"] == "2024-01-15T10:30:00Z"
+
+ def test_evidence_model_dump(self) -> None:
+ """Test Evidence model_dump method."""
+ from ares.core.models import Evidence, PyramidLevel
+
+ evidence = Evidence(
+ id="test-004",
+ type="ip",
+ value="10.0.0.1",
+ source="firewall",
+ timestamp=None,
+ pyramid_level=PyramidLevel.IP_ADDRESSES,
+ )
+
+ data = evidence.model_dump()
+
+ assert data["id"] == "test-004"
+ assert data["pyramid_level"] == PyramidLevel.IP_ADDRESSES
+ assert data["timestamp"] is None
+
+ def test_evidence_model_validate(self) -> None:
+ """Test creating Evidence from dict using model_validate."""
+ from ares.core.models import Evidence, PyramidLevel
+
+ data = {
+ "id": "test-005",
+ "type": "process",
+ "value": "malware.exe",
+ "source": "edr",
+ "timestamp": None,
+ "pyramid_level": 5, # TOOLS
+ "confidence": 0.8,
+ }
+
+ evidence = Evidence.model_validate(data)
+
+ assert evidence.id == "test-005"
+ assert evidence.pyramid_level == PyramidLevel.TOOLS
+ assert evidence.confidence == 0.8
+
+
+class TestTimelineEventModel:
+ """Tests for TimelineEvent rigging Model."""
+
+ def test_timeline_event_creation(self) -> None:
+ """Test creating TimelineEvent."""
+ from ares.core.models import TimelineEvent
+
+ ts = datetime.now(timezone.utc)
+ event = TimelineEvent(
+ id="event-001",
+ timestamp=ts,
+ description="Suspicious outbound connection detected",
+ )
+
+ assert event.id == "event-001"
+ assert event.timestamp == ts
+ assert event.confidence == 0.5 # default
+ assert event.source == "investigation" # default
+
+ def test_timeline_event_to_dict(self) -> None:
+ """Test TimelineEvent to_dict serialization."""
+ from ares.core.models import TimelineEvent
+
+ ts = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
+ event = TimelineEvent(
+ id="event-002",
+ timestamp=ts,
+ description="Lateral movement attempt",
+ evidence_ids=["ev-001", "ev-002"],
+ mitre_techniques=["T1021"],
+ confidence=0.85,
+ )
+
+ data = event.to_dict()
+
+ assert data["id"] == "event-002"
+ assert data["evidence_ids"] == ["ev-001", "ev-002"]
+ assert data["mitre_techniques"] == ["T1021"]
+ assert data["confidence"] == 0.85
+
+
+class TestInvestigativeQuestionModel:
+ """Tests for InvestigativeQuestion rigging Model."""
+
+ def test_investigative_question_creation(self) -> None:
+ """Test creating InvestigativeQuestion."""
+ from ares.core.models import InvestigativeQuestion, QuestionSource
+
+ question = InvestigativeQuestion(
+ id="q-001",
+ text="What process initiated the connection?",
+ source=QuestionSource.MITRE_NAVIGATOR,
+ rationale="Need to identify source process for T1071",
+ target_insight="Process identification",
+ )
+
+ assert question.id == "q-001"
+ assert question.source == QuestionSource.MITRE_NAVIGATOR
+ assert question.priority_score == 0.0 # all scores default to 0
+
+ def test_investigative_question_priority_score(self) -> None:
+ """Test InvestigativeQuestion priority_score computation."""
+ from ares.core.models import InvestigativeQuestion, QuestionSource
+
+ question = InvestigativeQuestion(
+ id="q-002",
+ text="What TTP was used?",
+ source=QuestionSource.PYRAMID_CLIMBER,
+ rationale="Climb to TTP level",
+ target_insight="TTP identification",
+ pyramid_elevation_score=0.8, # 3x weight
+ mitre_coverage_score=0.5, # 2x weight
+ confidence_impact_score=0.6, # 2x weight
+ urgency_score=0.4, # 1x weight
+ )
+
+ # Expected: (0.8 * 3) + (0.5 * 2) + (0.6 * 2) + (0.4 * 1) = 2.4 + 1.0 + 1.2 + 0.4 = 5.0
+ assert question.priority_score == pytest.approx(5.0)
+
+ def test_investigative_question_to_dict(self) -> None:
+ """Test InvestigativeQuestion to_dict serialization."""
+ from ares.core.models import InvestigativeQuestion, QuestionSource, QuestionState
+
+ question = InvestigativeQuestion(
+ id="q-003",
+ text="Which hosts are affected?",
+ source=QuestionSource.LATERAL_EXPANSION,
+ rationale="Determine scope",
+ target_insight="Host enumeration",
+ state=QuestionState.PENDING,
+ )
+
+ data = question.to_dict()
+
+ assert data["id"] == "q-003"
+ assert data["question"] == "Which hosts are affected?" # Note: 'question' not 'text'
+ assert data["source"] == "lateral"
+ assert data["state"] == "pending"
+ assert "priority_score" in data
+
+ def test_can_parallelize_with(self) -> None:
+ """Test question parallelization check."""
+ from ares.core.models import InvestigativeQuestion, QuestionSource
+
+ q1 = InvestigativeQuestion(
+ id="q-parent",
+ text="Parent question",
+ source=QuestionSource.INITIAL_TRIAGE,
+ rationale="Start",
+ target_insight="Initial",
+ )
+
+ q2 = InvestigativeQuestion(
+ id="q-child",
+ text="Child question",
+ source=QuestionSource.MITRE_NAVIGATOR,
+ rationale="Follow-up",
+ target_insight="Detail",
+ generated_from_question_id="q-parent",
+ )
+
+ q3 = InvestigativeQuestion(
+ id="q-independent",
+ text="Independent question",
+ source=QuestionSource.PYRAMID_CLIMBER,
+ rationale="Separate",
+ target_insight="Other",
+ )
+
+ # Child depends on parent - cannot parallelize
+ assert not q2.can_parallelize_with(q1)
+ assert not q1.can_parallelize_with(q2)
+
+ # Independent questions can parallelize
+ assert q1.can_parallelize_with(q3)
+ assert q3.can_parallelize_with(q1)
+
+
+class TestRedTeamModels:
+ """Tests for Red Team rigging Models."""
+
+ def test_target_model(self) -> None:
+ """Test Target model."""
+ from ares.core.models import Target
+
+ target = Target(ip="10.0.0.50", hostname="dc01", domain="corp.local")
+
+ assert target.ip == "10.0.0.50"
+ assert target.hostname == "dc01"
+ assert target.domain == "corp.local"
+
+ def test_host_model(self) -> None:
+ """Test Host model."""
+ from ares.core.models import Host
+
+ host = Host(
+ ip="10.0.0.100",
+ hostname="web01",
+ os="Windows Server 2019",
+ roles=["web", "app"],
+ services=["http", "https", "rdp"],
+ )
+
+ assert host.ip == "10.0.0.100"
+ assert host.roles == ["web", "app"]
+ assert host.services == ["http", "https", "rdp"]
+
+ def test_credential_model(self) -> None:
+ """Test Credential model."""
+ from ares.core.models import Credential
+
+ cred = Credential(
+ username="admin",
+ password="P@ssw0rd", # pragma: allowlist secret
+ domain="CORP",
+ source="mimikatz",
+ is_admin=True,
+ )
+
+ assert cred.username == "admin"
+ assert cred.is_admin is True
+ assert cred.source == "mimikatz"
+
+ def test_hash_model(self) -> None:
+ """Test Hash model."""
+ from ares.core.models import Hash
+
+ h = Hash(
+ username="svc_account",
+ hash_value="aad3b435b51404eeaad3b435b51404ee:31d6cfe0d16ae931b73c59d7e0c089c0",
+ hash_type="NTLM",
+ domain="CORP",
+ )
+
+ assert h.username == "svc_account"
+ assert h.hash_type == "NTLM"
+
+
+class TestParsingUtilities:
+ """Tests for rigging parsing utilities."""
+
+ def test_parsing_imports(self) -> None:
+ """Test that parsing utilities are importable from models."""
+ from ares.core.models import (
+ Model,
+ parse,
+ parse_set,
+ try_parse,
+ )
+
+ assert Model is not None
+ assert callable(parse)
+ assert callable(parse_set)
+ assert callable(try_parse)
+
+ def test_try_parse_no_match(self) -> None:
+ """Test try_parse returns None when no match found."""
+ from ares.core.models import Evidence, try_parse
+
+ result = try_parse("This text has no Evidence XML", Evidence)
+ assert result is None
+
+ def test_model_to_xml(self) -> None:
+ """Test that models can be serialized to XML."""
+ from ares.core.models import Target
+
+ target = Target(ip="192.168.1.1", hostname="test-host")
+ xml = target.to_xml()
+
+ # Verify XML structure exists (pydantic-xml may use attributes for simple models)
+ assert "" in xml or "/>" in xml
+
+
+class TestModelValidation:
+ """Tests for pydantic validation in rigging Models."""
+
+ def test_evidence_validation_error_missing_field(self) -> None:
+ """Test that missing required fields raise validation error."""
+ from pydantic import ValidationError
+
+ from ares.core.models import Evidence
+
+ with pytest.raises(ValidationError):
+ Evidence(
+ id="test",
+ # missing type, value, source, pyramid_level
+ )
+
+ def test_evidence_validation_error_wrong_type(self) -> None:
+ """Test that wrong field types raise validation error."""
+ from pydantic import ValidationError
+
+ from ares.core.models import Evidence
+
+ with pytest.raises(ValidationError):
+ Evidence(
+ id="test",
+ type="ip",
+ value="192.168.1.1",
+ source="test",
+ timestamp=None,
+ pyramid_level="not-an-int", # should be PyramidLevel
+ )
+
+ def test_confidence_accepts_float(self) -> None:
+ """Test that confidence accepts float values."""
+ from ares.core.models import Evidence, PyramidLevel
+
+ evidence = Evidence(
+ id="test",
+ type="ip",
+ value="1.2.3.4",
+ source="test",
+ timestamp=None,
+ pyramid_level=PyramidLevel.IP_ADDRESSES,
+ confidence=0.95,
+ )
+
+ assert evidence.confidence == 0.95