diff --git a/README.md b/README.md index b2d83c08..e4123c44 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,15 @@ # Ares - Autonomous Security Operations Agent + +
+ +[![Pre-Commit](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml) +[![Renovate](https://github.com/dreadnode/ares/actions/workflows/renovate.yaml/badge.svg)](https://github.com/dreadnode/ares/actions/workflows/renovate.yaml) +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +
+ + [![Tests](https://github.com/dreadnode/ares/actions/workflows/tests.yaml/badge.svg)](https://github.com/dreadnode/ares/actions/workflows/tests.yaml) [![Coverage](https://raw.githubusercontent.com/dreadnode/ares/main/.github/badges/coverage.svg)](https://github.com/dreadnode/ares/actions/workflows/coverage-badge.yaml) [![Pre-Commit](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/dreadnode/ares/actions/workflows/pre-commit.yaml) diff --git a/src/ares/core/models.py b/src/ares/core/models.py index 68817f08..d352688b 100644 --- a/src/ares/core/models.py +++ b/src/ares/core/models.py @@ -1,10 +1,63 @@ -"""Data models for Ares SOC Investigation Agent.""" +"""Data models for Ares SOC Investigation Agent. + +This module provides structured data models for SOC investigations and red team operations, +built on rigging's Model class for automatic serialization and LLM output parsing. + +Example usage for LLM output parsing: + >>> from ares.core.models import Evidence, parse, parse_set + >>> # Parse a single Evidence from LLM response text + >>> evidence, _ = parse(llm_response, Evidence) + >>> # Parse multiple Evidence items + >>> items = [e for e, _ in parse_set(llm_response, Evidence)] + +""" + +from __future__ import annotations from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum, IntEnum from typing import Any +from pydantic import Field, computed_field +from rigging import Model +from rigging.model import element, wrapped + +# Re-export rigging parsing utilities for convenient access +from rigging.parsing import ( + parse, + parse_many, + parse_set, + try_parse, + try_parse_many, + try_parse_set, +) + +__all__ = [ + "Credential", + "Evidence", + "Hash", + "Host", + "InvestigationStage", + "InvestigationState", + "InvestigativeQuestion", + "Model", + "PyramidLevel", + "QuestionSource", + "QuestionState", + "RedTeamState", + "Share", + "Target", + "TimelineEvent", + 
"User", + "parse", + "parse_many", + "parse_set", + "try_parse", + "try_parse_many", + "try_parse_set", +] + class PyramidLevel(IntEnum): """Levels of the Pyramid of Pain. @@ -57,8 +110,7 @@ class InvestigationStage(Enum): SYNTHESIS = "synthesis" # Generate report -@dataclass -class Evidence: +class Evidence(Model): """A piece of evidence discovered during investigation. Attributes: @@ -82,30 +134,18 @@ class Evidence: source: str timestamp: datetime | None pyramid_level: PyramidLevel - mitre_techniques: list[str] = field(default_factory=list) + mitre_techniques: list[str] = wrapped("mitre-techniques", element(tag="technique", default=[])) confidence: float = 0.5 - metadata: dict[str, Any] = field(default_factory=dict) + metadata: dict[str, str] = Field(default_factory=dict) source_query_id: str | None = None validated: bool = False - def to_dict(self) -> dict: - return { - "id": self.id, - "type": self.type, - "value": self.value, - "source": self.source, - "timestamp": self.timestamp.isoformat() if self.timestamp else None, - "pyramid_level": self.pyramid_level.value, - "mitre_techniques": self.mitre_techniques, - "confidence": self.confidence, - "metadata": self.metadata, - "source_query_id": self.source_query_id, - "validated": self.validated, - } + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for storage (backward compatible).""" + return self.model_dump(mode="json") -@dataclass -class TimelineEvent: +class TimelineEvent(Model): """An event in the investigation timeline. 
Attributes: @@ -121,25 +161,17 @@ class TimelineEvent: id: str timestamp: datetime description: str - evidence_ids: list[str] = field(default_factory=list) - mitre_techniques: list[str] = field(default_factory=list) + evidence_ids: list[str] = wrapped("evidence-ids", element(tag="evidence-id", default=[])) + mitre_techniques: list[str] = wrapped("mitre-techniques", element(tag="technique", default=[])) confidence: float = 0.5 source: str = "investigation" - def to_dict(self) -> dict: - return { - "id": self.id, - "timestamp": self.timestamp.isoformat(), - "description": self.description, - "evidence_ids": self.evidence_ids, - "mitre_techniques": self.mitre_techniques, - "confidence": self.confidence, - "source": self.source, - } + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for storage (backward compatible).""" + return self.model_dump(mode="json") -@dataclass -class InvestigativeQuestion: +class InvestigativeQuestion(Model): """A question that drives the investigation forward. Generated by the MITRE Navigator and Pyramid Climber engines. @@ -185,19 +217,23 @@ class InvestigativeQuestion: urgency_score: float = 0.0 state: QuestionState = QuestionState.PENDING - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) answered_at: datetime | None = None - generated_from_evidence_ids: list[str] = field(default_factory=list) + generated_from_evidence_ids: list[str] = wrapped( + "generated-from-evidence-ids", element(tag="evidence-id", default=[]) + ) generated_from_question_id: str | None = None - answer_evidence_ids: list[str] = field(default_factory=list) + answer_evidence_ids: list[str] = wrapped( + "answer-evidence-ids", element(tag="evidence-id", default=[]) + ) answer_summary: str | None = None + @computed_field # type: ignore[prop-decorator] @property def priority_score(self) -> float: - """ - Composite priority score. 
+ """Composite priority score. Weights: - Pyramid elevation: 3x (we want TTPs, not IOCs) @@ -212,14 +248,16 @@ def priority_score(self) -> float: + (self.urgency_score * 1.0) ) - def can_parallelize_with(self, other: "InvestigativeQuestion") -> bool: + def can_parallelize_with(self, other: InvestigativeQuestion) -> bool: """Check if this question can run in parallel with another.""" # Questions in a reasoning chain should be sequential if self.generated_from_question_id == other.id: return False return other.generated_from_question_id != self.id - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for storage (backward compatible).""" + # Use custom format to match original API return { "id": self.id, "question": self.text, @@ -341,8 +379,7 @@ def to_summary(self) -> dict: # Red Team Models -@dataclass -class Target: +class Target(Model): """Primary target information.""" ip: str @@ -350,19 +387,17 @@ class Target: domain: str = "" -@dataclass -class Host: +class Host(Model): """Discovered host information.""" ip: str hostname: str = "" os: str = "" - roles: list[str] = field(default_factory=list) - services: list[str] = field(default_factory=list) + roles: list[str] = wrapped("roles", element(tag="role", default=[])) + services: list[str] = wrapped("services", element(tag="service", default=[])) -@dataclass -class User: +class User(Model): """Discovered user account.""" username: str @@ -371,8 +406,7 @@ class User: is_admin: bool = False -@dataclass -class Credential: +class Credential(Model): """Discovered credential.""" username: str @@ -382,8 +416,7 @@ class Credential: is_admin: bool = False -@dataclass -class Hash: +class Hash(Model): """Discovered password hash.""" username: str @@ -393,8 +426,7 @@ class Hash: cracked_password: str = "" -@dataclass -class Share: +class Share(Model): """Discovered SMB share.""" host: str diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 
00000000..2a34842b --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,397 @@ +"""Tests for rigging Model integration in ares.core.models.""" + +from datetime import datetime, timezone + +import pytest + + +class TestEvidenceModel: + """Tests for Evidence rigging Model.""" + + def test_evidence_creation(self) -> None: + """Test creating Evidence with required fields.""" + from ares.core.models import Evidence, PyramidLevel + + evidence = Evidence( + id="test-001", + type="ip", + value="192.168.1.100", + source="loki_query", + timestamp=datetime.now(timezone.utc), + pyramid_level=PyramidLevel.IP_ADDRESSES, + ) + + assert evidence.id == "test-001" + assert evidence.type == "ip" + assert evidence.value == "192.168.1.100" + assert evidence.pyramid_level == PyramidLevel.IP_ADDRESSES + assert evidence.confidence == 0.5 # default + assert evidence.validated is False # default + + def test_evidence_with_optional_fields(self) -> None: + """Test creating Evidence with optional fields.""" + from ares.core.models import Evidence, PyramidLevel + + evidence = Evidence( + id="test-002", + type="domain", + value="malicious.example.com", + source="dns_query", + timestamp=None, + pyramid_level=PyramidLevel.DOMAIN_NAMES, + mitre_techniques=["T1071", "T1568"], + confidence=0.9, + metadata={"resolver": "8.8.8.8"}, + source_query_id="query-123", + validated=True, + ) + + assert evidence.mitre_techniques == ["T1071", "T1568"] + assert evidence.confidence == 0.9 + assert evidence.metadata == {"resolver": "8.8.8.8"} + assert evidence.validated is True + + def test_evidence_to_dict(self) -> None: + """Test Evidence to_dict serialization.""" + from ares.core.models import Evidence, PyramidLevel + + ts = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc) + evidence = Evidence( + id="test-003", + type="hash", + value="abc123def456", # pragma: allowlist secret + source="file_scan", + timestamp=ts, + pyramid_level=PyramidLevel.HASH_VALUES, + mitre_techniques=["T1027"], + ) + + data = 
evidence.to_dict() + + assert data["id"] == "test-003" + assert data["type"] == "hash" + assert data["value"] == "abc123def456" # pragma: allowlist secret + assert data["pyramid_level"] == 1 # HASH_VALUES = 1 + assert data["mitre_techniques"] == ["T1027"] + assert data["timestamp"] == "2024-01-15T10:30:00Z" + + def test_evidence_model_dump(self) -> None: + """Test Evidence model_dump method.""" + from ares.core.models import Evidence, PyramidLevel + + evidence = Evidence( + id="test-004", + type="ip", + value="10.0.0.1", + source="firewall", + timestamp=None, + pyramid_level=PyramidLevel.IP_ADDRESSES, + ) + + data = evidence.model_dump() + + assert data["id"] == "test-004" + assert data["pyramid_level"] == PyramidLevel.IP_ADDRESSES + assert data["timestamp"] is None + + def test_evidence_model_validate(self) -> None: + """Test creating Evidence from dict using model_validate.""" + from ares.core.models import Evidence, PyramidLevel + + data = { + "id": "test-005", + "type": "process", + "value": "malware.exe", + "source": "edr", + "timestamp": None, + "pyramid_level": 5, # TOOLS + "confidence": 0.8, + } + + evidence = Evidence.model_validate(data) + + assert evidence.id == "test-005" + assert evidence.pyramid_level == PyramidLevel.TOOLS + assert evidence.confidence == 0.8 + + +class TestTimelineEventModel: + """Tests for TimelineEvent rigging Model.""" + + def test_timeline_event_creation(self) -> None: + """Test creating TimelineEvent.""" + from ares.core.models import TimelineEvent + + ts = datetime.now(timezone.utc) + event = TimelineEvent( + id="event-001", + timestamp=ts, + description="Suspicious outbound connection detected", + ) + + assert event.id == "event-001" + assert event.timestamp == ts + assert event.confidence == 0.5 # default + assert event.source == "investigation" # default + + def test_timeline_event_to_dict(self) -> None: + """Test TimelineEvent to_dict serialization.""" + from ares.core.models import TimelineEvent + + ts = datetime(2024, 1, 
15, 12, 0, 0, tzinfo=timezone.utc) + event = TimelineEvent( + id="event-002", + timestamp=ts, + description="Lateral movement attempt", + evidence_ids=["ev-001", "ev-002"], + mitre_techniques=["T1021"], + confidence=0.85, + ) + + data = event.to_dict() + + assert data["id"] == "event-002" + assert data["evidence_ids"] == ["ev-001", "ev-002"] + assert data["mitre_techniques"] == ["T1021"] + assert data["confidence"] == 0.85 + + +class TestInvestigativeQuestionModel: + """Tests for InvestigativeQuestion rigging Model.""" + + def test_investigative_question_creation(self) -> None: + """Test creating InvestigativeQuestion.""" + from ares.core.models import InvestigativeQuestion, QuestionSource + + question = InvestigativeQuestion( + id="q-001", + text="What process initiated the connection?", + source=QuestionSource.MITRE_NAVIGATOR, + rationale="Need to identify source process for T1071", + target_insight="Process identification", + ) + + assert question.id == "q-001" + assert question.source == QuestionSource.MITRE_NAVIGATOR + assert question.priority_score == 0.0 # all scores default to 0 + + def test_investigative_question_priority_score(self) -> None: + """Test InvestigativeQuestion priority_score computation.""" + from ares.core.models import InvestigativeQuestion, QuestionSource + + question = InvestigativeQuestion( + id="q-002", + text="What TTP was used?", + source=QuestionSource.PYRAMID_CLIMBER, + rationale="Climb to TTP level", + target_insight="TTP identification", + pyramid_elevation_score=0.8, # 3x weight + mitre_coverage_score=0.5, # 2x weight + confidence_impact_score=0.6, # 2x weight + urgency_score=0.4, # 1x weight + ) + + # Expected: (0.8 * 3) + (0.5 * 2) + (0.6 * 2) + (0.4 * 1) = 2.4 + 1.0 + 1.2 + 0.4 = 5.0 + assert question.priority_score == pytest.approx(5.0) + + def test_investigative_question_to_dict(self) -> None: + """Test InvestigativeQuestion to_dict serialization.""" + from ares.core.models import InvestigativeQuestion, QuestionSource, 
QuestionState + + question = InvestigativeQuestion( + id="q-003", + text="Which hosts are affected?", + source=QuestionSource.LATERAL_EXPANSION, + rationale="Determine scope", + target_insight="Host enumeration", + state=QuestionState.PENDING, + ) + + data = question.to_dict() + + assert data["id"] == "q-003" + assert data["question"] == "Which hosts are affected?" # Note: 'question' not 'text' + assert data["source"] == "lateral" + assert data["state"] == "pending" + assert "priority_score" in data + + def test_can_parallelize_with(self) -> None: + """Test question parallelization check.""" + from ares.core.models import InvestigativeQuestion, QuestionSource + + q1 = InvestigativeQuestion( + id="q-parent", + text="Parent question", + source=QuestionSource.INITIAL_TRIAGE, + rationale="Start", + target_insight="Initial", + ) + + q2 = InvestigativeQuestion( + id="q-child", + text="Child question", + source=QuestionSource.MITRE_NAVIGATOR, + rationale="Follow-up", + target_insight="Detail", + generated_from_question_id="q-parent", + ) + + q3 = InvestigativeQuestion( + id="q-independent", + text="Independent question", + source=QuestionSource.PYRAMID_CLIMBER, + rationale="Separate", + target_insight="Other", + ) + + # Child depends on parent - cannot parallelize + assert not q2.can_parallelize_with(q1) + assert not q1.can_parallelize_with(q2) + + # Independent questions can parallelize + assert q1.can_parallelize_with(q3) + assert q3.can_parallelize_with(q1) + + +class TestRedTeamModels: + """Tests for Red Team rigging Models.""" + + def test_target_model(self) -> None: + """Test Target model.""" + from ares.core.models import Target + + target = Target(ip="10.0.0.50", hostname="dc01", domain="corp.local") + + assert target.ip == "10.0.0.50" + assert target.hostname == "dc01" + assert target.domain == "corp.local" + + def test_host_model(self) -> None: + """Test Host model.""" + from ares.core.models import Host + + host = Host( + ip="10.0.0.100", + hostname="web01", + 
os="Windows Server 2019", + roles=["web", "app"], + services=["http", "https", "rdp"], + ) + + assert host.ip == "10.0.0.100" + assert host.roles == ["web", "app"] + assert host.services == ["http", "https", "rdp"] + + def test_credential_model(self) -> None: + """Test Credential model.""" + from ares.core.models import Credential + + cred = Credential( + username="admin", + password="P@ssw0rd", # pragma: allowlist secret + domain="CORP", + source="mimikatz", + is_admin=True, + ) + + assert cred.username == "admin" + assert cred.is_admin is True + assert cred.source == "mimikatz" + + def test_hash_model(self) -> None: + """Test Hash model.""" + from ares.core.models import Hash + + h = Hash( + username="svc_account", + hash_value="aad3b435b51404eeaad3b435b51404ee:31d6cfe0d16ae931b73c59d7e0c089c0", + hash_type="NTLM", + domain="CORP", + ) + + assert h.username == "svc_account" + assert h.hash_type == "NTLM" + + +class TestParsingUtilities: + """Tests for rigging parsing utilities.""" + + def test_parsing_imports(self) -> None: + """Test that parsing utilities are importable from models.""" + from ares.core.models import ( + Model, + parse, + parse_set, + try_parse, + ) + + assert Model is not None + assert callable(parse) + assert callable(parse_set) + assert callable(try_parse) + + def test_try_parse_no_match(self) -> None: + """Test try_parse returns None when no match found.""" + from ares.core.models import Evidence, try_parse + + result = try_parse("This text has no Evidence XML", Evidence) + assert result is None + + def test_model_to_xml(self) -> None: + """Test that models can be serialized to XML.""" + from ares.core.models import Target + + target = Target(ip="192.168.1.1", hostname="test-host") + xml = target.to_xml() + + # Verify XML structure exists (pydantic-xml may use attributes for simple models) + assert "<target>" in xml or "/>" in xml + + +class TestModelValidation: + """Tests for pydantic validation in rigging Models.""" + + def 
test_evidence_validation_error_missing_field(self) -> None: + """Test that missing required fields raise validation error.""" + from pydantic import ValidationError + + from ares.core.models import Evidence + + with pytest.raises(ValidationError): + Evidence( + id="test", + # missing type, value, source, pyramid_level + ) + + def test_evidence_validation_error_wrong_type(self) -> None: + """Test that wrong field types raise validation error.""" + from pydantic import ValidationError + + from ares.core.models import Evidence + + with pytest.raises(ValidationError): + Evidence( + id="test", + type="ip", + value="192.168.1.1", + source="test", + timestamp=None, + pyramid_level="not-an-int", # should be PyramidLevel + ) + + def test_confidence_accepts_float(self) -> None: + """Test that confidence accepts float values.""" + from ares.core.models import Evidence, PyramidLevel + + evidence = Evidence( + id="test", + type="ip", + value="1.2.3.4", + source="test", + timestamp=None, + pyramid_level=PyramidLevel.IP_ADDRESSES, + confidence=0.95, + ) + + assert evidence.confidence == 0.95