diff --git a/README.md b/README.md index 05064b7..8868d34 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ``` ▄▀█ █ ▀▄▀ - █▀█ █ █ █ v1.0.1 + █▀█ █ █ █ v1.1.0 AI Security Testing Framework ``` @@ -85,8 +85,16 @@ aix chain https://api.target.com/chat -k sk-xxx -P full_compromise # Use with Burp Suite request file aix inject -r request.txt -p "messages[0].content" +# Target a WebSocket endpoint +aix inject ws://api.target.com/ws -k sk-xxx +aix scan wss://api.target.com/ws -k sk-xxx + # Generate HTML report aix db --export report.html + +# View sessions and conversations +aix db --sessions +aix db --conversations ``` --- @@ -414,6 +422,36 @@ The `-p` parameter specifies the JSON path to the injection point. Examples: --- +## WebSocket Support + +AIX supports WebSocket endpoints (`ws://` and `wss://`) natively. Use them exactly like HTTP targets: + +```bash +aix recon ws://api.target.com/chat +aix inject wss://api.target.com/chat -k sk-xxx +aix scan wss://api.target.com/chat -k sk-xxx +``` + +### Chat ID Tracking + +For stateful APIs that return a session or chat ID in the response, AIX can extract and reuse it automatically across requests: + +| Option | Description | +|--------|-------------| +| `--chat-id-path` | Dot-path to extract chat ID from response JSON (e.g., `data.chat_id`) | +| `--chat-id-param` | Request parameter to inject the captured chat ID into | +| `--new-chat` | Force a new conversation for each payload (ignore existing chat ID) | +| `--reuse-chat` | Reuse the same chat ID for all payloads in this run | + +```bash +# Extract chat_id from response and send it back in subsequent requests +aix inject https://api.target.com/chat --chat-id-path data.chat_id --chat-id-param chat_id +``` + +> **Note:** HTTP proxy is not supported for WebSocket connections. SSL verification is disabled for `wss://` (same as other connectors, for use with Burp/ZAP). 
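
The dot-path lookup used by `--chat-id-path` (and `--response-path`) can be sketched in a few lines of Python. The `resolve` helper below is hypothetical — it is not part of the AIX API — and only mirrors the documented behavior: dict keys are looked up by name, and numeric segments index into lists.

```python
from typing import Any

def resolve(data: Any, path: str) -> Any:
    """Walk a dot-separated path; numeric segments index into lists."""
    for key in path.split("."):
        if isinstance(data, list):
            # Numeric segment -> list index, if in range
            if key.isdigit() and int(key) < len(data):
                data = data[int(key)]
            else:
                return None
        elif isinstance(data, dict):
            data = data.get(key)
        else:
            return None
    return data

# Hypothetical API response shape for illustration only
response = {"data": {"chat_id": "abc123", "messages": [{"content": "hi"}]}}
print(resolve(response, "data.chat_id"))             # → abc123
print(resolve(response, "data.messages.0.content"))  # → hi
```

So `--chat-id-path data.chat_id` against a response like the one above would capture `abc123` and replay it via the parameter named by `--chat-id-param`.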
+ +--- + ## Database & Reporting ```bash @@ -431,8 +469,24 @@ aix db --export report.html # Clear database aix db --clear + +# --- Sessions --- +# List all sessions (grouped by target) +aix db --sessions + +# Show results for a specific session +aix db --session + +# --- Conversations --- +# List all recorded conversations (multi-turn) +aix db --conversations + +# Show full transcript for a specific conversation +aix db --conversation ``` +All scan runs are automatically grouped into **sessions** by target. Multi-turn attack transcripts are stored as **conversations** and linked to both their session and individual findings. + --- ## AI-Powered Features diff --git a/TODO.md b/TODO.md index 20c530e..e2ee4fe 100644 --- a/TODO.md +++ b/TODO.md @@ -259,7 +259,7 @@ steps: |-------|--------|--------|----------|--------| | Phase 1: Advanced Attacks | Very High | Medium | **P0** | Multi-Turn ✅ | | Phase 2: Adaptive Testing | Very High | High | **P0** | Planned | -| Phase 3: Attack Chaining | High | Medium | **P1** | Core ✅ | +| Phase 3: Attack Chaining | High | Medium | **P1** | Core ✅, WebSocket ✅ | | Phase 4: Enterprise/CI | High | Medium | **P1** | Planned | | Phase 5: Blue Team | Medium | Medium | **P2** | Planned | | Phase 6: Platform | Medium | High | **P2** | Planned | @@ -277,13 +277,35 @@ steps: --- -*Last Updated: February 8, 2026* +*Last Updated: February 20, 2026* --- ## Recent Changes -### v1.3.0 - AI Context & OWASP Integration +### v1.1.0 - WebSocket Support & Sessions +- Added **WebSocket Connector** (`ws://` / `wss://` targets): + - Full attack module support for WebSocket endpoints + - Configurable JSON message template and response extraction path + - SSL verification disabled for `wss://` (Burp/ZAP compatible) + - Extra headers support for the HTTP upgrade handshake +- Added **Chat ID Tracking**: + - `--chat-id-path`: extract chat/session ID from response via dot-path + - `--chat-id-param`: inject captured ID back into subsequent requests + - 
`--new-chat` / `--reuse-chat` flags to control conversation continuity + - `{chat_id}` URL placeholder substitution +- Added **Sessions** to the database: + - Scans are automatically grouped into sessions per target + - `sessions` table with status, notes, and modules-run tracking + - `aix db --sessions` to list sessions; `aix db --session ` for results + - `get_or_create_session()` auto-creates a session at the start of each run +- Added **Conversations** to the database: + - Multi-turn transcripts stored as conversations linked to sessions + - `conversations` table with full turn-by-turn transcript (JSON) + - `aix db --conversations` to list; `aix db --conversation ` for transcript +- DB migrations: `session_id` and `conversation_id` columns added to `results` + +### v1.0.2 - AI Context & OWASP Integration - Added **AI Context Gathering** feature: - Probes target to detect purpose, domain, personality, restrictions - New fields: `purpose`, `domain`, `expected_inputs`, `personality` @@ -305,7 +327,7 @@ steps: - Preserved successful attempt reason in `scan_payload()` - Prevents failure reasons from overwriting success reasons with `--verify-attempts` -### v1.2.0 - Attack Chain Module +### v1.0.1 - Attack Chain Module - Added `aix chain` command for executing YAML-defined attack playbooks - Implemented ChainExecutor for orchestrating multi-step attack workflows - Implemented ChainContext for state management and variable interpolation @@ -326,7 +348,7 @@ steps: - Added conditional branching with `on_success`, `on_fail`, and `conditions` - Added variable storage and interpolation across steps -### v1.1.0 - Multi-Turn Attack Module +### v1.0.0 - Multi-Turn Attack Module - Added `aix multiturn` command with 8 attack categories - Implemented ConversationManager for stateful attacks - Implemented TurnEvaluator for response analysis diff --git a/aix/__init__.py b/aix/__init__.py index e521355..8676971 100644 --- a/aix/__init__.py +++ b/aix/__init__.py @@ -10,7 +10,7 @@ 
aix jailbreak https://chat.company.com """ -__version__ = "1.0.1" +__version__ = "1.1.0" __author__ = "AIX Team" __license__ = "MIT" diff --git a/aix/cli.py b/aix/cli.py index 66b019d..6b978b1 100644 --- a/aix/cli.py +++ b/aix/cli.py @@ -37,6 +37,41 @@ console = Console() + +def _get_session_id(target: str) -> str: + """Get or create a session for the target. Returns session_id.""" + db = AIXDatabase() + session_id = db.get_or_create_session(target) + db.close() + return session_id + + +def _update_session_module(session_id: str, module: str) -> None: + """Register a module run in the session.""" + db = AIXDatabase() + db.update_session_modules(session_id, module) + db.close() + + +def _resolve_chat_id_flags(new_chat: bool, reuse_chat: bool, is_multiturn: bool = False) -> bool: + """Resolve --new-chat / --reuse-chat flags into a single new_chat bool. + + Defaults: single-turn scans → new_chat=True, multiturn → new_chat=False. + Explicit flags override defaults. + """ + if new_chat and reuse_chat: + console.print( + "[yellow][!] Both --new-chat and --reuse-chat specified; using --new-chat[/yellow]" + ) + return True + if new_chat: + return True + if reuse_chat: + return False + # Default based on scan type + return not is_multiturn + + BANNER = f""" [bold cyan] ▄▀█ █ ▀▄▀[/bold cyan] [bold cyan] █▀█ █ █ █[/bold cyan] [dim]v{__version__}[/dim] @@ -89,6 +124,13 @@ def validate_input(target, request, param): console.print(f"[red][-][/red] Error parsing request file: {e}") raise click.Abort() + # Hint: suggest --response-path for WebSocket targets + if target and target.startswith(("ws://", "wss://")) and not parsed_request: + console.print( + "[cyan][*][/cyan] WebSocket target detected. " + "Use [bold]-rp[/bold] (--response-path) to extract a specific JSON field from responses." 
+ ) + return target, parsed_request @@ -161,6 +203,10 @@ def recon_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, key=None, profile=None, ): @@ -196,6 +242,9 @@ def recon_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "recon") + recon.run( target, output=output, @@ -221,6 +270,10 @@ def recon_cmd( risk=risk, show_response=show_response, verify_attempts=verify_attempts, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -275,6 +328,10 @@ def inject_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, ): """ Inject - Prompt injection attacks @@ -309,6 +366,9 @@ def inject_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "inject") + inject.run( target=target, api_key=key, @@ -339,6 +399,10 @@ def inject_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -388,6 +452,10 @@ def jailbreak_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, ): """ Jailbreak - Bypass AI restrictions @@ -418,6 +486,9 @@ def jailbreak_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "jailbreak") + jailbreak.run( target=target, api_key=key, @@ -446,6 +517,10 @@ def jailbreak_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -493,6 +568,10 @@ def extract_cmd( risk, show_response, 
verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, ): """ Extract - System prompt extraction @@ -524,6 +603,9 @@ def extract_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "extract") + extract.run( target=target, api_key=key, @@ -551,6 +633,10 @@ def extract_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -598,6 +684,10 @@ def leak_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, ): """ Leak - Training data extraction @@ -629,6 +719,9 @@ def leak_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "leak") + leak.run( target=target, api_key=key, @@ -656,6 +749,10 @@ def leak_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -699,6 +796,10 @@ def exfil_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, refresh_url=None, refresh_regex=None, refresh_param=None, @@ -736,6 +837,9 @@ def exfil_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "exfil") + exfil.run( target=target, api_key=key, @@ -756,6 +860,10 @@ def exfil_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -803,6 +911,10 @@ def agent_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, ): """ 
Agent - AI agent exploitation @@ -834,6 +946,9 @@ def agent_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "agent") + agent.run( target=target, api_key=key, @@ -861,6 +976,10 @@ def agent_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -908,6 +1027,10 @@ def dos_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, ): """ DoS - Denial of Service testing @@ -937,6 +1060,9 @@ def dos_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "dos") + dos.run( target=target, api_key=key, @@ -962,6 +1088,10 @@ def dos_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -1011,6 +1141,10 @@ def fuzz_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, ): """ Fuzz - Fuzzing and edge cases @@ -1040,6 +1174,9 @@ def fuzz_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "fuzz") + fuzz.run( target=target, api_key=key, @@ -1067,6 +1204,10 @@ def fuzz_cmd( risk=risk, show_response=show_response, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -1114,6 +1255,10 @@ def memory_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, ): """ Memory - Memory and context manipulation attacks @@ -1147,6 +1292,9 @@ def memory_cmd( "enable_context": not no_context, } + session_id = 
_get_session_id(target) + _update_session_module(session_id, "memory") + memory.run( target=target, api_key=key, @@ -1174,6 +1322,10 @@ def memory_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, ) @@ -1245,6 +1397,10 @@ def rag_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, canary, category, ): @@ -1303,6 +1459,9 @@ def rag_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "rag") + rag.run( target=target, api_key=key, @@ -1330,6 +1489,10 @@ def rag_cmd( show_response=show_response, verify_attempts=verify_attempts, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat), + session_id=session_id, canary=canary, category=category, ) @@ -1400,6 +1563,10 @@ def multiturn_cmd( risk, show_response, verify_attempts, + chat_id_path, + chat_id_param, + new_chat, + reuse_chat, category, max_turns, turn_delay, @@ -1450,6 +1617,9 @@ def multiturn_cmd( "enable_context": not no_context, } + session_id = _get_session_id(target) + _update_session_module(session_id, "multiturn") + multiturn.run( target=target, api_key=key, @@ -1480,6 +1650,10 @@ def multiturn_cmd( max_turns=max_turns, turn_delay=turn_delay, generate=generate, + chat_id_path=chat_id_path, + chat_id_param=chat_id_param, + new_chat=_resolve_chat_id_flags(new_chat, reuse_chat, is_multiturn=True), + session_id=session_id, ) @@ -1494,7 +1668,11 @@ def multiturn_cmd( @click.option("--clear", is_flag=True, help="Clear all results") @click.option("--target", "-t", help="Filter by target") @click.option("--module", "-m", help="Filter by module") -def db(export, clear, target, module): +@click.option("--sessions", is_flag=True, help="List all sessions") 
+@click.option("--session", "session_id", help="Show results for a specific session ID") +@click.option("--conversations", is_flag=True, help="List all conversations") +@click.option("--conversation", "conversation_id", help="Show full conversation transcript") +def db(export, clear, target, module, sessions, session_id, conversations, conversation_id): """ Database - View and manage results @@ -1504,6 +1682,11 @@ def db(export, clear, target, module): aix db --export report.html aix db --target company.com aix db --clear + aix db --sessions + aix db --session + aix db --session --export report.html + aix db --conversations + aix db --conversation """ print_banner() @@ -1515,11 +1698,36 @@ def db(export, clear, target, module): console.print("[green][+][/green] Database cleared") return + if conversations: + convs = db.list_conversations(target=target) + db.display_conversations(convs) + return + + if conversation_id: + db.display_conversation_transcript(conversation_id) + return + + if sessions: + db.display_sessions() + return + if export: - db.export_html(export, target=target, module=module) + db.export_html(export, target=target, module=module, session_id=session_id) console.print(f"[green][+][/green] Report exported: {export}") return + if session_id: + # Show results for specific session + results = db.get_session_results(session_id) + session = db.get_session(session_id) + if session: + console.print(f"[cyan]Session:[/cyan] {session.get('name', 'N/A')} ({session_id[:8]})") + console.print(f"[cyan]Target:[/cyan] {session.get('target', 'N/A')}") + console.print(f"[cyan]Modules:[/cyan] {', '.join(session.get('modules_run', []))}") + console.print() + db.display_results(results) + return + # Show results results = db.get_results(target=target, module=module) db.display_results(results) @@ -1854,7 +2062,16 @@ def scan( _set_proxy_env(proxy) target, parsed_request = validate_input(target, request, param) + # Always create a fresh session for full scan + scan_db = 
AIXDatabase() + session_id = scan_db.create_session( + target=target, + name=f"Full Scan - {target[:30]}", + ) + scan_db.close() + console.print("[bold cyan][*][/bold cyan] Starting comprehensive scan...") + console.print(f"[dim]Session: {session_id[:8]}[/dim]") console.print() # Run all modules @@ -1871,6 +2088,7 @@ def scan( ] for name, module in modules_to_run: + _update_session_module(session_id, name) console.print(f"[bold cyan][*][/bold cyan] Running {name} module...") try: module.run( @@ -1903,13 +2121,21 @@ def scan( risk=risk, show_response=show_response, verify_attempts=verify_attempts, + session_id=session_id, ) except Exception as e: console.print(f"[red][-][/red] {name} failed: {e}") console.print() + # End session + end_db = AIXDatabase() + end_db.end_session(session_id, status="completed") + end_db.close() + console.print("[bold green][+][/bold green] Scan complete!") - console.print("[dim]Run 'aix db --export report.html' to generate report[/dim]") + console.print( + f"[dim]Session: {session_id[:8]} | Run 'aix db --session {session_id[:8]} --export report.html' to generate report[/dim]" + ) if __name__ == "__main__": diff --git a/aix/core/connector.py b/aix/core/connector.py index 5c5a20d..bc6cd27 100644 --- a/aix/core/connector.py +++ b/aix/core/connector.py @@ -56,6 +56,11 @@ def __init__(self, url: str, profile=None, console=None, **kwargs): _global_console = Console() self.console = console or _global_console + # Chat ID tracking + self.chat_id_path = kwargs.get("chat_id_path") + self.chat_id_param = kwargs.get("chat_id_param") + self._current_chat_id: str | None = None + def _parse_cookies(self, cookies: str | None) -> dict[str, str]: """Parse cookie string into dictionary""" if not cookies: @@ -80,6 +85,39 @@ def _parse_headers(self, headers: str | None) -> dict[str, str]: header_dict[key.strip()] = value.strip() return header_dict + def _navigate_json_path(self, data: Any, path: str) -> Any: + """Navigate a dot-separated JSON path (e.g., 
'data.chat_id').""" + result = data + for key in path.split("."): + if isinstance(result, list): + if key.isdigit() and int(key) < len(result): + result = result[int(key)] + else: + return None + elif isinstance(result, dict): + result = result.get(key) + else: + return None + if result is None: + return None + return result + + def _capture_chat_id(self, response_data: dict) -> None: + """Extract chat_id from response if chat_id_path is configured.""" + if not self.chat_id_path: + return + chat_id = self._navigate_json_path(response_data, self.chat_id_path) + if chat_id: + self._current_chat_id = str(chat_id) + + def reset_chat_id(self) -> None: + """Clear captured chat ID (for fresh conversation).""" + self._current_chat_id = None + + @property + def current_chat_id(self) -> str | None: + return self._current_chat_id + @abstractmethod async def connect(self) -> None: """Establish connection to target""" @@ -292,6 +330,10 @@ def _build_payload(self, message: str) -> dict[str, Any]: if key not in payload: payload[key] = value + # Inject captured chat ID if available + if self._current_chat_id and self.chat_id_param and "{chat_id}" not in self.url: + payload[self.chat_id_param] = self._current_chat_id + return payload def _extract_response(self, data: dict[str, Any]) -> str: @@ -501,6 +543,7 @@ async def send_with_messages(self, messages: list[dict]) -> str: response = await self.client.post(url, json=body, headers=headers) response.raise_for_status() data = response.json() + self._capture_chat_id(data) return self._extract_response(data) except httpx.HTTPStatusError as e: @@ -535,6 +578,10 @@ async def send(self, payload: str) -> str: else: url = self.url.rstrip("/") + endpoint + # Support {chat_id} URL placeholder + if self._current_chat_id and "{chat_id}" in url: + url = url.replace("{chat_id}", self._current_chat_id) + # Use profile URL if available if self.profile and self.profile.endpoint: url = self.profile.url.rstrip("/") + self.profile.endpoint @@ -584,6 
+631,7 @@ async def send(self, payload: str) -> str: response.raise_for_status() data = response.json() + self._capture_chat_id(data) return self._extract_response(data) except httpx.HTTPStatusError as e: @@ -629,35 +677,169 @@ async def close(self) -> None: class WebSocketConnector(Connector): """ WebSocket connector for real-time chat interfaces. + + Sends JSON messages over WebSocket and extracts responses + using configurable JSON path or regex. Supports cookies + and custom headers via the HTTP upgrade handshake. """ def __init__(self, url: str, profile=None, **kwargs): - super().__init__(url, profile, **kwargs) + console = kwargs.pop("console", None) + super().__init__(url, profile, console=console, **kwargs) self.ws = None - self.message_format = kwargs.get("message_format", lambda m: json.dumps({"message": m})) - self.response_parser = kwargs.get( - "response_parser", lambda r: json.loads(r).get("response", r) - ) + self.injection_param = kwargs.get("injection_param", "message") + self.response_path = kwargs.get("response_path") + self.response_regex = kwargs.get("response_regex") + self.timeout = kwargs.get("timeout", 30) + self.verbose = kwargs.get("verbose", 0) + self.cookies = kwargs.get("cookies") + self.headers = kwargs.get("headers") + self.proxy = kwargs.get("proxy") + + def _build_extra_headers(self) -> dict[str, str]: + """Build extra headers for the WebSocket HTTP upgrade handshake.""" + extra = {} + + # Cookies -> Cookie header + cookie_dict = self._parse_cookies(self.cookies) + if cookie_dict: + extra["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookie_dict.items()) + + # Custom headers + header_dict = self._parse_headers(self.headers) + if header_dict: + extra.update(header_dict) + + return extra + + def _build_message(self, payload: str) -> str: + """Build JSON message from payload.""" + msg = {self.injection_param: payload} + # Inject captured chat ID if available + if self._current_chat_id and self.chat_id_param: + msg[self.chat_id_param] = 
self._current_chat_id + return json.dumps(msg) + + def _extract_response(self, raw: str) -> str: + """Parse JSON response and extract text using response_path or fallback keys.""" + try: + data = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return self._apply_regex(raw) + + extracted = "" + + if self.response_path: + extracted = str(self._navigate_path(data, self.response_path)) + else: + # Try common chatbot response keys + for key in ("content", "response", "text", "message", "answer", "reply"): + if isinstance(data, dict) and key in data: + extracted = str(data[key]) + break + + if not extracted: + extracted = json.dumps(data) if isinstance(data, dict) else str(data) + + return self._apply_regex(extracted) + + def _navigate_path(self, data: Any, path: str) -> Any: + """Navigate dot-separated path in JSON data (e.g. 'data.content').""" + result = data + for key in path.split("."): + if isinstance(result, list): + if key.isdigit() and int(key) < len(result): + result = result[int(key)] + else: + return "" + elif isinstance(result, dict): + result = result.get(key, "") + else: + return str(result) + if not result: + break + return result async def connect(self) -> None: - """Connect to WebSocket""" - import websockets + """Connect to WebSocket endpoint.""" + try: + import websockets + except ImportError: + raise ImportError( + "websockets library required for WebSocket connections. 
" + "Install with: pip install aix-framework[full]" + ) + + if self.proxy: + self.console.print( + "[yellow][!][/yellow] HTTP proxy is not supported for WebSocket connections" + ) + + extra_headers = self._build_extra_headers() - self.ws = await websockets.connect(self.url) + # Disable SSL verification for wss:// (same as other connectors for Burp/ZAP) + import ssl + + ssl_context = None + if self.url.startswith("wss://"): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + self.ws = await websockets.connect( + self.url, + additional_headers=extra_headers, + ssl=ssl_context, + open_timeout=self.timeout, + close_timeout=self.timeout, + ) async def send(self, payload: str) -> str: - """Send message through WebSocket""" + """Send message and wait for response.""" if not self.ws: await self.connect() - await self.ws.send(self.message_format(payload)) - response = await self.ws.recv() - return self.response_parser(response) + message = self._build_message(payload) + + if self.verbose >= 3: + self.console.print(f"[cyan]WS-CONN[/cyan] [*] Sending: {message[:200]}") + + try: + await self.ws.send(message) + raw = await asyncio.wait_for(self.ws.recv(), timeout=self.timeout) + except asyncio.TimeoutError: + raise ConnectionError(f"WebSocket recv() timed out after {self.timeout}s") + except Exception as e: + # Auto-reconnect once on connection closed + if "close" in str(e).lower() or "closed" in str(e).lower(): + self.ws = None + await self.connect() + await self.ws.send(message) + raw = await asyncio.wait_for(self.ws.recv(), timeout=self.timeout) + else: + raise ConnectionError(f"WebSocket error: {e!s}") + + if self.verbose >= 3: + self.console.print(f"[cyan]WS-CONN[/cyan] [*] Received: {raw[:200]}") + + # Capture chat ID from WebSocket response + try: + ws_data = json.loads(raw) + if isinstance(ws_data, dict): + self._capture_chat_id(ws_data) + except (json.JSONDecodeError, TypeError): + 
pass + + return self._extract_response(raw) async def close(self) -> None: - """Close WebSocket""" + """Close WebSocket connection.""" if self.ws: - await self.ws.close() + try: + await self.ws.close() + except Exception: + pass + self.ws = None class InterceptConnector(Connector): @@ -897,6 +1079,7 @@ async def send(self, payload: str) -> str: # Try to parse JSON response try: data = response.json() + self._capture_chat_id(data) return self._extract_response(data) except json.JSONDecodeError: return self._apply_regex(response.text) diff --git a/aix/core/conversation.py b/aix/core/conversation.py index 66aa5b8..73a6083 100644 --- a/aix/core/conversation.py +++ b/aix/core/conversation.py @@ -10,6 +10,7 @@ import asyncio import re +import uuid from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING @@ -66,6 +67,7 @@ class ConversationState: retries: int = 0 max_retries: int = 2 branch_taken: str | None = None + conversation_id: str = "" @dataclass @@ -82,6 +84,7 @@ class SequenceResult: matched_indicators: list[str] variables_extracted: dict error: str | None = None + conversation_id: str = "" class ConversationManager: @@ -120,7 +123,7 @@ def __init__( def reset(self): """Reset conversation state for new sequence.""" - self.state = ConversationState() + self.state = ConversationState(conversation_id=str(uuid.uuid4())) def _build_messages(self) -> list[dict]: """ @@ -409,11 +412,23 @@ async def execute_sequence(self, sequence: dict) -> SequenceResult: console.print(f"[red]Error at turn {turn_num}: {e}[/red]") break - # Copy extracted variables + # Copy extracted variables and conversation ID result.variables_extracted = dict(self.state.variables) + result.conversation_id = self.state.conversation_id return result + def get_transcript_as_dicts(self) -> list[dict]: + """Get conversation history as a list of serializable dicts for DB storage.""" + return [ + { + "role": turn.role, + "content": turn.content, + "turn_number": 
turn.turn_number, + } + for turn in self.state.history + ] + def get_conversation_transcript(self) -> str: """ Get a formatted transcript of the conversation. diff --git a/aix/core/reporting/base.py b/aix/core/reporting/base.py index ba1bb55..284809a 100644 --- a/aix/core/reporting/base.py +++ b/aix/core/reporting/base.py @@ -20,6 +20,89 @@ console = Console() +# Severity weights for risk score calculation +SEVERITY_WEIGHTS = { + "critical": 10, + "high": 7, + "medium": 4, + "low": 1, + "info": 0, +} + +# OWASP LLM Top 10 remediation recommendations +OWASP_REMEDIATION = { + "LLM01": { + "title": "Prompt Injection", + "recommendation": "Implement input validation, use delimiter-based prompt structures, apply least-privilege principles for LLM actions, and consider prompt isolation techniques.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM02": { + "title": "Insecure Output Handling", + "recommendation": "Treat all LLM output as untrusted. 
Apply output encoding, validate and sanitize responses before rendering, and never execute LLM output directly.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM03": { + "title": "Training Data Poisoning", + "recommendation": "Vet training data sources, implement data sanitization pipelines, use anomaly detection on training data, and maintain data provenance records.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM04": { + "title": "Model Denial of Service", + "recommendation": "Implement rate limiting, set token/response size limits, use input length validation, and deploy resource monitoring with auto-scaling.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM05": { + "title": "Supply Chain Vulnerabilities", + "recommendation": "Audit third-party model sources, verify model integrity, maintain software bill of materials (SBOM), and use signed model artifacts.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM06": { + "title": "Sensitive Information Disclosure", + "recommendation": "Implement output filtering for PII/secrets, use data classification, apply redaction rules, and restrict training data scope.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM07": { + "title": "Insecure Plugin Design", + "recommendation": "Apply strict input validation for plugins, enforce least-privilege access, require user confirmation for sensitive actions, and sandbox plugin execution.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM08": { + "title": "Excessive Agency", + "recommendation": "Limit LLM permissions and tool access, implement human-in-the-loop for sensitive operations, log 
all agent actions, and enforce scope boundaries.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM09": { + "title": "Overreliance", + "recommendation": "Implement output verification mechanisms, add confidence scoring, require human review for critical decisions, and clearly communicate AI limitations to users.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, + "LLM10": { + "title": "Model Theft", + "recommendation": "Implement access controls, use watermarking, monitor for extraction attempts, rate-limit API access, and restrict model output verbosity.", + "references": [ + "https://owasp.org/www-project-top-10-for-large-language-model-applications/" + ], + }, +} + class Severity(Enum): """Severity levels for findings""" @@ -61,6 +144,19 @@ def to_dict(self) -> dict[str, Any]: } +@dataclass +class ScanMetadata: + """Metadata about the scan session for reports.""" + + session_id: str | None = None + session_name: str | None = None + target: str = "" + start_time: datetime | None = None + end_time: datetime | None = None + modules_run: list[str] = field(default_factory=list) + risk_score: float = 0.0 + + class Reporter: """ Handles output formatting and report generation. 
@@ -70,6 +166,7 @@ def __init__(self): self.findings: list[Finding] = [] self.start_time: datetime | None = None self.end_time: datetime | None = None + self.metadata: ScanMetadata | None = None def start(self) -> None: """Mark scan start time""" @@ -83,6 +180,104 @@ def add_finding(self, finding: Finding) -> None: """Add a finding""" self.findings.append(finding) + def calculate_risk_score(self) -> float: + """Calculate overall risk score (0-10 scale).""" + total_weight = sum(SEVERITY_WEIGHTS.get(f.severity.value, 0) for f in self.findings) + return min(10.0, total_weight / 5.0) + + def get_risk_level(self, score: float) -> str: + """Classify risk level from score.""" + if score >= 8: + return "Critical" + elif score >= 5: + return "High" + elif score >= 2: + return "Medium" + else: + return "Low" + + def get_owasp_coverage(self) -> dict[str, dict[str, Any]]: + """Build OWASP LLM Top 10 coverage map.""" + all_categories = [f"LLM{i:02d}" for i in range(1, 11)] + coverage: dict[str, dict[str, Any]] = {} + + for cat_id in all_categories: + coverage[cat_id] = { + "title": OWASP_REMEDIATION.get(cat_id, {}).get("title", "Unknown"), + "tested": False, + "findings_count": 0, + "max_severity": "info", + } + + severity_rank = {"critical": 0, "high": 1, "medium": 2, "low": 3, "info": 4} + + for finding in self.findings: + if finding.owasp: + for cat in finding.owasp: + cat_id = cat.id if hasattr(cat, "id") else str(cat) + if cat_id in coverage: + coverage[cat_id]["tested"] = True + coverage[cat_id]["findings_count"] += 1 + current_max = coverage[cat_id]["max_severity"] + if severity_rank.get(finding.severity.value, 4) < severity_rank.get( + current_max, 4 + ): + coverage[cat_id]["max_severity"] = finding.severity.value + + return coverage + + def generate_executive_summary(self) -> str: + """Generate executive summary text.""" + if not self.findings: + return "No vulnerabilities were identified during this assessment. 
The target appears to have adequate security controls in place for the tested attack vectors." + + risk_score = self.calculate_risk_score() + risk_level = self.get_risk_level(risk_score) + + counts = dict.fromkeys(Severity, 0) + for f in self.findings: + counts[f.severity] += 1 + + total = len(self.findings) + parts = [] + + parts.append( + f"This assessment identified {total} {'vulnerabilities' if total != 1 else 'vulnerability'} " + f"with an overall risk score of {risk_score:.1f}/10 ({risk_level})." + ) + + severity_parts = [] + if counts[Severity.CRITICAL]: + severity_parts.append(f"{counts[Severity.CRITICAL]} critical") + if counts[Severity.HIGH]: + severity_parts.append(f"{counts[Severity.HIGH]} high") + if counts[Severity.MEDIUM]: + severity_parts.append(f"{counts[Severity.MEDIUM]} medium") + if counts[Severity.LOW]: + severity_parts.append(f"{counts[Severity.LOW]} low") + + if severity_parts: + parts.append(f"Breakdown: {', '.join(severity_parts)} severity findings.") + + if risk_score >= 8: + parts.append( + "Immediate remediation is strongly recommended. The target is highly vulnerable to AI-specific attacks." + ) + elif risk_score >= 5: + parts.append( + "Significant vulnerabilities were found. Prioritize remediation of high and critical findings." + ) + elif risk_score >= 2: + parts.append( + "Moderate risk detected. Review and address findings based on severity and business impact." + ) + else: + parts.append( + "Low risk detected. Minor findings should be reviewed as part of regular security maintenance." 
+ ) + + return " ".join(parts) + def print_finding(self, finding: Finding) -> None: """Print a finding to console""" severity_colors = { @@ -144,20 +339,32 @@ def print_summary(self) -> None: console.print(table) def export_json(self, filepath: str) -> None: - """Export findings to JSON""" - data = { + """Export findings to JSON with metadata, OWASP coverage, and executive summary""" + risk_score = self.calculate_risk_score() + + data: dict[str, Any] = { "scan_info": { "start_time": self.start_time.isoformat() if self.start_time else None, "end_time": self.end_time.isoformat() if self.end_time else None, "total_findings": len(self.findings), + "risk_score": risk_score, + "risk_level": self.get_risk_level(risk_score), }, + "executive_summary": self.generate_executive_summary(), + "owasp_coverage": self.get_owasp_coverage(), "findings": [f.to_dict() for f in self.findings], } + if self.metadata: + data["scan_info"]["session_id"] = self.metadata.session_id + data["scan_info"]["session_name"] = self.metadata.session_name + data["scan_info"]["target"] = self.metadata.target + data["scan_info"]["modules_run"] = self.metadata.modules_run + Path(filepath).write_text(json.dumps(data, indent=2)) def export_html(self, filepath: str) -> None: - """Export findings to HTML report""" + """Export findings to enhanced HTML report""" # Count findings by severity counts = dict.fromkeys(Severity, 0) @@ -185,11 +392,110 @@ def export_html(self, filepath: str) -> None: for target in findings_by_target: findings_by_target[target].sort(key=lambda f: severity_order.get(f.severity, 99)) + # Risk score and executive summary + risk_score = self.calculate_risk_score() + risk_level = self.get_risk_level(risk_score) + executive_summary = self.generate_executive_summary() + owasp_coverage = self.get_owasp_coverage() + + # Build metadata header HTML + metadata_html = "" + if self.metadata: + meta = self.metadata + duration = "" + if meta.start_time and meta.end_time: + delta = meta.end_time - 
meta.start_time
+                minutes = int(delta.total_seconds() // 60)
+                seconds = int(delta.total_seconds() % 60)
+                duration = f"{minutes}m {seconds}s"
+            modules_str = ", ".join(meta.modules_run) if meta.modules_run else "N/A"
+            metadata_html = f"""
+            <div class="metadata">
+                <span class="metadata-item"><strong>Target:</strong> {self._escape_html(meta.target)}</span>
+                <span class="metadata-item"><strong>Duration:</strong> {duration}</span>
+                <span class="metadata-item"><strong>Modules:</strong> {modules_str}</span>
+            </div>
+            """
+
+        # Build executive summary HTML
+        risk_color = {
+            "Critical": "#ff4757",
+            "High": "#ffa502",
+            "Medium": "#3742fa",
+            "Low": "#888",
+        }.get(risk_level, "#888")
+        risk_pct = min(100, risk_score * 10)
+        exec_summary_html = f"""
+        <div class="executive-summary">
+            <h2>Executive Summary</h2>
+            <div class="risk-gauge">
+                <div class="risk-label">Risk Score: <strong>{risk_score:.1f}/10 ({risk_level})</strong></div>
+                <div class="risk-bar-bg">
+                    <div class="risk-bar-fill" style="width: {risk_pct}%; background: {risk_color};"></div>
+                </div>
+            </div>
+            <div class="exec-text">{self._escape_html(executive_summary)}</div>
+            <div class="key-stats">
+                <span class="key-stat">Total Findings: <strong>{len(self.findings)}</strong></span>
+                <span class="key-stat">Targets: <strong>{len(findings_by_target)}</strong></span>
+            </div>
+        </div>
+        """
+
+        # Build severity chart HTML (CSS horizontal bars)
+        total_findings = max(len(self.findings), 1)
+        chart_html = """
+        <div class="severity-chart">
+            <h2>Severity Distribution</h2>
+        """
+        chart_items = [
+            ("Critical", counts[Severity.CRITICAL], "#ff4757"),
+            ("High", counts[Severity.HIGH], "#ffa502"),
+            ("Medium", counts[Severity.MEDIUM], "#3742fa"),
+            ("Low", counts[Severity.LOW], "#888"),
+            ("Info", counts[Severity.INFO], "#555"),
+        ]
+        for label, count, color in chart_items:
+            pct = (count / total_findings * 100) if count > 0 else 0
+            chart_html += f"""
+            <div class="chart-row">
+                <span class="chart-label">{label}</span>
+                <div class="chart-bar-bg">
+                    <div class="chart-bar-fill" style="width: {pct}%; background: {color};"></div>
+                </div>
+                <span class="chart-count">{count}</span>
+            </div>
+            """
+        chart_html += "</div>"
+
+        # Build OWASP coverage grid
+        owasp_html = """
+        <div class="owasp-coverage">
+            <h2>OWASP LLM Top 10 Coverage</h2>
+            <div class="owasp-grid">
+        """
+        for cat_id, info in owasp_coverage.items():
+            if info["findings_count"] > 0:
+                card_class = "owasp-card vulnerable"
+                status_icon = "✗"
+                status_text = (
+                    f"{info['findings_count']} finding{'s' if info['findings_count'] > 1 else ''}"
+                )
+            elif info["tested"]:
+                card_class = "owasp-card clean"
+                status_icon = "✓"
+                status_text = "Clean"
+            else:
+                card_class = "owasp-card not-tested"
+                status_icon = "—"
+                status_text = "Not tested"
+            owasp_html += f"""
+            <div class="{card_class}">
+                <div class="owasp-card-id">{cat_id}</div>
+                <div class="owasp-card-title">{self._escape_html(info['title'])}</div>
+                <div class="owasp-card-status">{status_icon} {status_text}</div>
+            </div>
+            """
+        owasp_html += "</div></div>"
+
+        # Generate findings HTML
+        findings_html = ""
+        for target, target_findings in findings_by_target.items():
-            findings_html += f'<div class="target-group"><h3>{target}</h3>'
+            findings_html += f'<div class="target-group"><h3>{self._escape_html(target)}</h3>'
             for finding in target_findings:
                 severity_class = finding.severity.value
@@ -207,12 +513,12 @@ def export_html(self, filepath: str) -> None:
                 findings_html += f"""
                 <div class="finding">
                     <div class="finding-header">
                         <span class="severity-badge {severity_class}">{finding.severity.value.upper()}</span>
-                        <span class="finding-title">{finding.title}</span>
-                        <span class="technique-badge">{finding.technique}</span>
+                        <span class="finding-title">{self._escape_html(finding.title)}</span>
+                        <span class="technique-badge">{self._escape_html(finding.technique)}</span>
                     </div>
                     <div class="finding-body">
                         {owasp_badges}
-                        {f'<div class="finding-field reason"><strong>Reason:</strong> {finding.reason}</div>' if finding.reason else ''}
+                        {f'<div class="finding-field reason"><strong>Reason:</strong> {self._escape_html(finding.reason)}</div>' if finding.reason else ''}
                         <details>
                             <summary>Payload & Response</summary>
@@ -226,12 +532,39 @@ def export_html(self, filepath: str) -> None:
                         </details>
-                        {f'<div class="finding-field"><strong>Details:</strong> {finding.details}</div>' if finding.details else ''}
+                        {f'<div class="finding-field"><strong>Details:</strong> {self._escape_html(finding.details)}</div>' if finding.details else ''}
                     </div>
                 </div>
                 """
             findings_html += "</div>
" + # Build remediation section + remediation_html = "" + affected_categories = { + cat_id for cat_id, info in owasp_coverage.items() if info["findings_count"] > 0 + } + if affected_categories: + remediation_html = """

Remediation Recommendations

""" + for cat_id in sorted(affected_categories): + rec = OWASP_REMEDIATION.get(cat_id, {}) + if rec: + remediation_html += f""" +
+

{cat_id}: {self._escape_html(rec.get('title', ''))}

+

{self._escape_html(rec.get('recommendation', ''))}

+
""" + remediation_html += "
" + + # Build footer + from aix import __version__ + + footer_html = f""" +
+ Generated by AIX v{__version__} - AI eXploit Framework
+ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} +
+ """ + html = f""" @@ -245,46 +578,108 @@ def export_html(self, filepath: str) -> None: padding: 0; box-sizing: border-box; }} - + body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background: #0a0a0f; color: #e0e0e0; line-height: 1.6; }} - + .container {{ max-width: 1200px; margin: 0 auto; padding: 2rem; }} - + header {{ text-align: center; padding: 3rem 0; border-bottom: 1px solid #2a2a3a; margin-bottom: 2rem; }} - + .logo {{ font-family: 'Courier New', monospace; font-size: 2rem; color: #00d4ff; margin-bottom: 0.5rem; }} - + .subtitle {{ color: #888; font-size: 1.1rem; }} - + + .metadata {{ + display: flex; + flex-wrap: wrap; + gap: 1.5rem; + margin-top: 1rem; + justify-content: center; + }} + .metadata-item {{ + color: #aaa; + font-size: 0.9rem; + }} + .metadata-item strong {{ + color: #00d4ff; + }} + + .executive-summary {{ + background: #1a1a2a; + border: 1px solid #2a2a3a; + border-radius: 8px; + padding: 2rem; + margin-bottom: 2rem; + }} + .executive-summary h2 {{ + color: #00d4ff; + margin-bottom: 1rem; + border-bottom: 1px solid #2a2a3a; + padding-bottom: 0.5rem; + }} + .risk-gauge {{ + margin-bottom: 1rem; + }} + .risk-label {{ + font-size: 1.1rem; + margin-bottom: 0.5rem; + }} + .risk-bar-bg {{ + background: #2a2a3a; + border-radius: 4px; + height: 12px; + overflow: hidden; + }} + .risk-bar-fill {{ + height: 100%; + border-radius: 4px; + transition: width 0.3s; + }} + .exec-text {{ + color: #ccc; + margin: 1rem 0; + }} + .key-stats {{ + display: flex; + gap: 2rem; + }} + .key-stat {{ + color: #aaa; + font-size: 0.9rem; + }} + .key-stat strong {{ + color: #00d4ff; + }} + .stats {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem; margin-bottom: 2rem; }} - + .stat-card {{ background: #1a1a2a; border-radius: 8px; @@ -292,39 +687,130 @@ def export_html(self, filepath: str) -> None: text-align: center; border: 1px solid #2a2a3a; }} - + .stat-value {{ font-size: 
2.5rem; font-weight: bold; }} - + .stat-label {{ color: #888; text-transform: uppercase; font-size: 0.8rem; letter-spacing: 1px; }} - + .stat-card.critical .stat-value {{ color: #ff4757; }} .stat-card.high .stat-value {{ color: #ffa502; }} .stat-card.medium .stat-value {{ color: #3742fa; }} .stat-card.low .stat-value {{ color: #888; }} - + + .severity-chart {{ + background: #1a1a2a; + border: 1px solid #2a2a3a; + border-radius: 8px; + padding: 2rem; + margin-bottom: 2rem; + }} + .severity-chart h2 {{ + color: #00d4ff; + margin-bottom: 1rem; + border-bottom: 1px solid #2a2a3a; + padding-bottom: 0.5rem; + }} + .chart-row {{ + display: flex; + align-items: center; + gap: 1rem; + margin-bottom: 0.5rem; + }} + .chart-label {{ + width: 70px; + text-align: right; + color: #aaa; + font-size: 0.85rem; + }} + .chart-bar-bg {{ + flex: 1; + background: #2a2a3a; + border-radius: 4px; + height: 20px; + overflow: hidden; + }} + .chart-bar-fill {{ + height: 100%; + border-radius: 4px; + }} + .chart-count {{ + width: 30px; + text-align: right; + font-weight: bold; + color: #e0e0e0; + }} + + .owasp-coverage {{ + margin-bottom: 2rem; + }} + .owasp-coverage h2 {{ + color: #00d4ff; + margin-bottom: 1rem; + border-bottom: 1px solid #2a2a3a; + padding-bottom: 0.5rem; + }} + .owasp-grid {{ + display: grid; + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); + gap: 0.75rem; + }} + .owasp-card {{ + background: #1a1a2a; + border: 1px solid #2a2a3a; + border-radius: 6px; + padding: 1rem; + }} + .owasp-card.vulnerable {{ + border-color: #ff4757; + }} + .owasp-card.clean {{ + border-color: #2ed573; + }} + .owasp-card.not-tested {{ + border-color: #444; + opacity: 0.6; + }} + .owasp-card-id {{ + font-family: monospace; + font-weight: bold; + color: #00d4ff; + font-size: 0.85rem; + }} + .owasp-card-title {{ + font-size: 0.8rem; + color: #ccc; + margin: 0.25rem 0; + }} + .owasp-card-status {{ + font-size: 0.75rem; + color: #888; + }} + .owasp-card.vulnerable .owasp-card-status {{ color: 
#ff4757; }} + .owasp-card.clean .owasp-card-status {{ color: #2ed573; }} + .findings {{ margin-top: 2rem; }} - + .findings h2 {{ margin-bottom: 1rem; color: #00d4ff; border-bottom: 1px solid #2a2a3a; padding-bottom: 0.5rem; }} - + .target-group {{ margin-bottom: 2rem; }} - + .target-group h3 {{ color: #7bed9f; margin-bottom: 1rem; @@ -332,7 +818,7 @@ def export_html(self, filepath: str) -> None: border-left: 3px solid #7bed9f; padding-left: 1rem; }} - + .finding {{ background: #1a1a2a; border-radius: 8px; @@ -340,7 +826,7 @@ def export_html(self, filepath: str) -> None: border: 1px solid #2a2a3a; overflow: hidden; }} - + .finding-header {{ padding: 1rem 1.5rem; background: #252535; @@ -349,7 +835,7 @@ def export_html(self, filepath: str) -> None: gap: 1rem; flex-wrap: wrap; }} - + .severity-badge {{ padding: 0.25rem 0.75rem; border-radius: 4px; @@ -359,7 +845,7 @@ def export_html(self, filepath: str) -> None: min-width: 80px; text-align: center; }} - + .technique-badge {{ background: #2a2a3a; color: #aaa; @@ -392,37 +878,37 @@ def export_html(self, filepath: str) -> None: .severity-badge.high {{ background: #ffa502; color: black; }} .severity-badge.medium {{ background: #3742fa; color: white; }} .severity-badge.low {{ background: #555; color: white; }} - + .finding-title {{ font-weight: 600; }} - + .finding-body {{ padding: 1.5rem; }} - + .finding-field {{ margin-bottom: 1rem; }} - + .finding-field.reason {{ background: #2a2a3a; padding: 0.75rem; border-radius: 4px; border-left: 3px solid #00d4ff; }} - + .finding-field strong {{ color: #00d4ff; }} - - details summmary {{ + + details summary {{ cursor: pointer; color: #888; margin-bottom: 1rem; outline: none; }} - + details summary:hover {{ color: #fff; }} @@ -430,7 +916,7 @@ def export_html(self, filepath: str) -> None: details[open] summary {{ margin-bottom: 1rem; }} - + pre {{ background: #0a0a0f; padding: 1rem; @@ -447,22 +933,48 @@ def export_html(self, filepath: str) -> None: height: 8px; }} 
pre::-webkit-scrollbar-track {{ - background: #0a0a0f; + background: #0a0a0f; }} pre::-webkit-scrollbar-thumb {{ - background: #2a2a3a; + background: #2a2a3a; border-radius: 4px; }} pre::-webkit-scrollbar-thumb:hover {{ - background: #00d4ff; + background: #00d4ff; }} - + code {{ font-family: 'Courier New', monospace; font-size: 0.9rem; color: #7bed9f; }} - + + .remediation {{ + background: #1a1a2a; + border: 1px solid #2a2a3a; + border-radius: 8px; + padding: 2rem; + margin-top: 2rem; + }} + .remediation h2 {{ + color: #00d4ff; + margin-bottom: 1rem; + border-bottom: 1px solid #2a2a3a; + padding-bottom: 0.5rem; + }} + .remediation-item {{ + margin-bottom: 1.5rem; + }} + .remediation-item h3 {{ + color: #ffa502; + font-size: 1rem; + margin-bottom: 0.5rem; + }} + .remediation-item p {{ + color: #ccc; + font-size: 0.9rem; + }} + footer {{ text-align: center; padding: 2rem; @@ -475,10 +987,13 @@ def export_html(self, filepath: str) -> None:
-
+
                 <div class="subtitle">AI Security Testing Report</div>
+                {metadata_html}
             </header>
-
+
+            {exec_summary_html}
+
             <div class="stats">
                 <div class="stat-card critical">
                     <div class="stat-value">{counts[Severity.CRITICAL]}</div>
@@ -497,16 +1012,19 @@ def export_html(self, filepath: str) -> None:
                     <div class="stat-label">Low</div>
                 </div>
             </div>
-
+
+            {chart_html}
+
+            {owasp_html}
+
             <div class="findings">
                 <h2>Findings</h2>
                 {findings_html if findings_html else '<p>No findings to display.</p>'}
             </div>
-
-
+
+            {remediation_html}
+
+            {footer_html}
diff --git a/aix/core/scanner.py b/aix/core/scanner.py index 8a1dbe5..b933480 100644 --- a/aix/core/scanner.py +++ b/aix/core/scanner.py @@ -5,13 +5,14 @@ import asyncio import json import os +import uuid from abc import ABC from typing import Any, Optional from rich.console import Console from aix.core.ai_engine import AIEngine -from aix.core.connector import APIConnector, RequestConnector +from aix.core.connector import APIConnector, RequestConnector, WebSocketConnector from aix.core.context import TargetContext from aix.core.evaluator import LLMEvaluator from aix.core.evasion import PayloadEvasion @@ -65,6 +66,14 @@ def __init__( self.response_regex = kwargs.get("response_regex") self.response_path = kwargs.get("response_path") + # Chat ID config + self.chat_id_path = kwargs.get("chat_id_path") + self.chat_id_param = kwargs.get("chat_id_param") + self.new_chat = kwargs.get("new_chat", True) # Default: fresh per payload + + # Session config + self.session_id = kwargs.get("session_id") + # Filtering config self.level = kwargs.get("level", 1) self.risk = kwargs.get("risk", 1) @@ -230,6 +239,23 @@ def _create_connector(self): response_regex=self.response_regex, response_path=self.response_path, console=self.console, + chat_id_path=self.chat_id_path, + chat_id_param=self.chat_id_param, + ) + elif self.target.startswith(("ws://", "wss://")): + return WebSocketConnector( + self.target, + injection_param=self.injection_param, + response_path=self.response_path, + response_regex=self.response_regex, + cookies=self.cookies, + headers=self.headers, + timeout=self.timeout, + verbose=self.verbose, + console=self.console, + proxy=self.proxy, + chat_id_path=self.chat_id_path, + chat_id_param=self.chat_id_param, ) else: return APIConnector( @@ -246,6 +272,8 @@ def _create_connector(self): response_regex=self.response_regex, response_path=self.response_path, console=self.console, + chat_id_path=self.chat_id_path, + chat_id_param=self.chat_id_param, ) def _print(self, status: 
str, msg: str, tech: str = "", response: str = None): @@ -545,6 +573,13 @@ async def _run_payload_scan( ): self.stats["total"] += 1 try: + # Reset chat ID for fresh conversation per payload + if self.new_chat: + connector.reset_chat_id() + + # Generate per-payload conversation ID + conv_id = str(uuid.uuid4()) + is_vulnerable, best_resp = await self.scan_payload( connector, p["payload"], p["indicators"], p["name"] ) @@ -552,6 +587,26 @@ async def _run_payload_scan( # Let subclass modify vulnerability decision / add info is_vulnerable, extra_response = self._on_finding(p, best_resp, is_vulnerable) + # Save conversation record + target_chat_id = connector.current_chat_id + self.db.save_conversation( + target=self.target, + module=db_key, + technique=p["name"], + transcript=[ + {"role": "user", "content": p["payload"], "turn_number": 1}, + { + "role": "assistant", + "content": (best_resp or "")[:response_limit], + "turn_number": 1, + }, + ], + turn_count=1, + session_id=self.session_id, + target_chat_id=target_chat_id, + conversation_id=conv_id, + ) + if is_vulnerable: self.stats["success"] += 1 self._print("success", "", p["name"], response=best_resp) @@ -587,6 +642,8 @@ async def _run_payload_scan( p["severity"].value, reason=self.last_eval_reason, owasp=p.get("owasp", []), + session_id=self.session_id, + conversation_id=conv_id, **db_kwargs, ) else: diff --git a/aix/db/database.py b/aix/db/database.py index 0a42497..503b407 100644 --- a/aix/db/database.py +++ b/aix/db/database.py @@ -7,6 +7,7 @@ import json import sqlite3 +import uuid from datetime import datetime from pathlib import Path from typing import Any @@ -90,6 +91,56 @@ def _init_db(self) -> None: ) cursor.execute("ALTER TABLE results ADD COLUMN owasp TEXT") + # Sessions table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + name TEXT, + target TEXT NOT NULL, + status TEXT DEFAULT 'active', + notes TEXT, + modules_run TEXT, + start_time DATETIME DEFAULT 
CURRENT_TIMESTAMP, + end_time DATETIME + ) + """) + + # Conversations table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + session_id TEXT, + target TEXT NOT NULL, + module TEXT NOT NULL, + technique TEXT, + target_chat_id TEXT, + transcript TEXT, + turn_count INTEGER DEFAULT 0, + status TEXT DEFAULT 'completed', + started_at DATETIME DEFAULT CURRENT_TIMESTAMP, + finished_at DATETIME, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ) + """) + + # Migration: Add session_id column to results + try: + cursor.execute("SELECT session_id FROM results LIMIT 1") + except sqlite3.OperationalError: + console.print( + "[yellow][*] Migrating database: Adding 'session_id' column to results table[/yellow]" + ) + cursor.execute("ALTER TABLE results ADD COLUMN session_id TEXT") + + # Migration: Add conversation_id column to results + try: + cursor.execute("SELECT conversation_id FROM results LIMIT 1") + except sqlite3.OperationalError: + console.print( + "[yellow][*] Migrating database: Adding 'conversation_id' column to results table[/yellow]" + ) + cursor.execute("ALTER TABLE results ADD COLUMN conversation_id TEXT") + # Profiles table cursor.execute(""" CREATE TABLE IF NOT EXISTS profiles ( @@ -145,6 +196,8 @@ def add_result( reason: str = "", owasp: list[str] | None = None, dedup_payload: str | None = None, + session_id: str | None = None, + conversation_id: str | None = None, ) -> int: """ Add a scan result. @@ -190,10 +243,22 @@ def add_result( cursor.execute( """ UPDATE results - SET result = ?, payload = ?, response = ?, severity = ?, reason = ?, owasp = ?, timestamp = CURRENT_TIMESTAMP + SET result = ?, payload = ?, response = ?, severity = ?, reason = ?, owasp = ?, + session_id = COALESCE(?, session_id), conversation_id = COALESCE(?, conversation_id), + timestamp = CURRENT_TIMESTAMP WHERE id = ? 
""", - (result, payload, response, severity, reason, owasp_json, row_id), + ( + result, + payload, + response, + severity, + reason, + owasp_json, + session_id, + conversation_id, + row_id, + ), ) self.conn.commit() return row_id @@ -201,8 +266,8 @@ def add_result( # Insert new result cursor.execute( """ - INSERT INTO results (target, module, technique, result, payload, response, severity, reason, owasp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO results (target, module, technique, result, payload, response, severity, reason, owasp, session_id, conversation_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( target, @@ -214,6 +279,8 @@ def add_result( severity, reason, owasp_json, + session_id, + conversation_id, ), ) self.conn.commit() @@ -224,6 +291,7 @@ def get_results( target: str | None = None, module: str | None = None, result: str | None = None, + session_id: str | None = None, limit: int = 100, ) -> list[dict[str, Any]]: """Get scan results with optional filters""" @@ -244,6 +312,10 @@ def get_results( query += " AND result = ?" params.append(result) + if session_id: + query += " AND session_id = ?" + params.append(session_id) + query += " ORDER BY timestamp DESC LIMIT ?" params.append(limit) @@ -321,6 +393,304 @@ def clear(self) -> None: cursor.execute("DELETE FROM results") self.conn.commit() + # ======================================================================== + # Sessions + # ======================================================================== + + def create_session( + self, + target: str, + name: str | None = None, + notes: str | None = None, + ) -> str: + """Create a new scan session. 
Returns session_id.""" + session_id = str(uuid.uuid4()) + if not name: + name = f"Scan - {datetime.now().strftime('%Y-%m-%d %H:%M')}" + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT INTO sessions (id, name, target, status, notes, modules_run) + VALUES (?, ?, ?, 'active', ?, '[]') + """, + (session_id, name, target, notes), + ) + self.conn.commit() + return session_id + + def end_session(self, session_id: str, status: str = "completed") -> None: + """Mark a session as completed or aborted.""" + cursor = self.conn.cursor() + cursor.execute( + """ + UPDATE sessions SET status = ?, end_time = CURRENT_TIMESTAMP + WHERE id = ? + """, + (status, session_id), + ) + self.conn.commit() + + def update_session_modules(self, session_id: str, module: str) -> None: + """Append a module to the session's modules_run list.""" + cursor = self.conn.cursor() + cursor.execute("SELECT modules_run FROM sessions WHERE id = ?", (session_id,)) + row = cursor.fetchone() + if not row: + return + modules = json.loads(row[0] or "[]") + if module not in modules: + modules.append(module) + cursor.execute( + "UPDATE sessions SET modules_run = ? 
WHERE id = ?", + (json.dumps(modules), session_id), + ) + self.conn.commit() + + def get_session(self, session_id: str) -> dict[str, Any] | None: + """Get a session by ID.""" + cursor = self.conn.cursor() + cursor.execute("SELECT * FROM sessions WHERE id = ?", (session_id,)) + row = cursor.fetchone() + if row: + result = dict(row) + result["modules_run"] = json.loads(result.get("modules_run") or "[]") + return result + return None + + def list_sessions(self, limit: int = 50) -> list[dict[str, Any]]: + """List all sessions, newest first.""" + cursor = self.conn.cursor() + cursor.execute("SELECT * FROM sessions ORDER BY start_time DESC LIMIT ?", (limit,)) + rows = [] + for row in cursor.fetchall(): + r = dict(row) + r["modules_run"] = json.loads(r.get("modules_run") or "[]") + rows.append(r) + return rows + + def get_session_results(self, session_id: str) -> list[dict[str, Any]]: + """Get all results for a specific session.""" + return self.get_results(session_id=session_id, limit=10000) + + def get_or_create_session(self, target: str) -> str: + """Find active session for target or create new one. Returns session_id.""" + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT id FROM sessions + WHERE target = ? 
AND status = 'active' + AND start_time > datetime('now', '-24 hours') + ORDER BY start_time DESC LIMIT 1 + """, + (target,), + ) + row = cursor.fetchone() + if row: + return row[0] + return self.create_session(target=target) + + def display_sessions(self, sessions: list[dict[str, Any]] | None = None) -> None: + """Display sessions in a nice table.""" + if sessions is None: + sessions = self.list_sessions() + + if not sessions: + console.print("[dim]No sessions found[/dim]") + return + + table = Table(title="AIX Sessions") + table.add_column("ID", style="dim", max_width=8) + table.add_column("Name", style="cyan") + table.add_column("Target", style="green", max_width=30) + table.add_column("Status") + table.add_column("Modules", style="blue") + table.add_column("Started", style="dim") + + for s in sessions: + status_str = s["status"] + if status_str == "active": + status_str = "[green]active[/green]" + elif status_str == "completed": + status_str = "[blue]completed[/blue]" + elif status_str == "aborted": + status_str = "[red]aborted[/red]" + + target = s["target"] + if len(target) > 28: + target = target[:25] + "..." + + modules = ", ".join(s.get("modules_run", [])) + started = s.get("start_time", "")[:16] if s.get("start_time") else "" + + table.add_row( + s["id"][:8], + s.get("name", ""), + target, + status_str, + modules, + started, + ) + + console.print(table) + + # ======================================================================== + # Conversations + # ======================================================================== + + def save_conversation( + self, + target: str, + module: str, + technique: str = "", + transcript: list[dict] | None = None, + turn_count: int = 0, + session_id: str | None = None, + target_chat_id: str | None = None, + conversation_id: str | None = None, + status: str = "completed", + ) -> str: + """Save a conversation record. 
Returns conversation_id.""" + if not conversation_id: + conversation_id = str(uuid.uuid4()) + cursor = self.conn.cursor() + transcript_json = json.dumps(transcript or []) + cursor.execute( + """ + INSERT INTO conversations (id, session_id, target, module, technique, target_chat_id, transcript, turn_count, status, finished_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + """, + ( + conversation_id, + session_id, + target, + module, + technique, + target_chat_id, + transcript_json, + turn_count, + status, + ), + ) + self.conn.commit() + return conversation_id + + def get_conversation(self, conversation_id: str) -> dict[str, Any] | None: + """Get a conversation by ID.""" + cursor = self.conn.cursor() + cursor.execute("SELECT * FROM conversations WHERE id = ?", (conversation_id,)) + row = cursor.fetchone() + if row: + result = dict(row) + result["transcript"] = json.loads(result.get("transcript") or "[]") + return result + return None + + def list_conversations( + self, + session_id: str | None = None, + target: str | None = None, + limit: int = 100, + ) -> list[dict[str, Any]]: + """List conversations with optional filters.""" + cursor = self.conn.cursor() + query = "SELECT * FROM conversations WHERE 1=1" + params: list[Any] = [] + + if session_id: + query += " AND session_id = ?" + params.append(session_id) + if target: + query += " AND target LIKE ?" + params.append(f"%{target}%") + + query += " ORDER BY started_at DESC LIMIT ?" 
+        params.append(limit)
+
+        cursor.execute(query, params)
+        rows = []
+        for row in cursor.fetchall():
+            r = dict(row)
+            r["transcript"] = json.loads(r.get("transcript") or "[]")
+            rows.append(r)
+        return rows
+
+    def display_conversations(self, conversations: list[dict[str, Any]] | None = None) -> None:
+        """Display conversations in a nice table."""
+        if conversations is None:
+            conversations = self.list_conversations()
+
+        if not conversations:
+            console.print("[dim]No conversations found[/dim]")
+            return
+
+        table = Table(title="AIX Conversations")
+        table.add_column("ID", style="dim", max_width=8)
+        table.add_column("Target", style="cyan", max_width=25)
+        table.add_column("Module", style="blue")
+        table.add_column("Technique", max_width=20)
+        table.add_column("Turns", justify="right")
+        table.add_column("Chat ID", style="yellow", max_width=12)
+        table.add_column("Status")
+        table.add_column("Date", style="dim")
+
+        for c in conversations:
+            target = c["target"]
+            if len(target) > 23:
+                target = target[:20] + "..."
+
+            chat_id = c.get("target_chat_id") or ""
+            if len(chat_id) > 10:
+                chat_id = chat_id[:10] + ".."
+
+            date_str = c.get("started_at", "")[:16] if c.get("started_at") else ""
+
+            table.add_row(
+                c["id"][:8],
+                target,
+                c["module"],
+                c.get("technique", ""),
+                str(c.get("turn_count", 0)),
+                chat_id,
+                c.get("status", ""),
+                date_str,
+            )
+
+        console.print(table)
+
+    def display_conversation_transcript(self, conversation_id: str) -> None:
+        """Display a full conversation transcript."""
+        conv = self.get_conversation(conversation_id)
+        if not conv:
+            console.print(f"[red]Conversation {conversation_id} not found[/red]")
+            return
+
+        from rich.panel import Panel
+
+        console.print(
+            Panel(
+                f"[bold]Module:[/bold] {conv['module']} [bold]Technique:[/bold] {conv.get('technique', 'N/A')}\n"
+                f"[bold]Target:[/bold] {conv['target']}\n"
+                f"[bold]Target Chat ID:[/bold] {conv.get('target_chat_id') or 'N/A'}\n"
+                f"[bold]Turns:[/bold] {conv.get('turn_count', 0)} [bold]Status:[/bold] {conv.get('status', 'N/A')}",
+                title=f"[cyan]Conversation {conversation_id[:8]}[/cyan]",
+            )
+        )
+
+        transcript = conv.get("transcript", [])
+        for entry in transcript:
+            role = entry.get("role", "unknown").upper()
+            content = entry.get("content", "")
+            turn = entry.get("turn_number", "?")
+
+            if role == "USER":
+                console.print(f"\n[bold blue]Turn {turn} - USER:[/bold blue]")
+            else:
+                console.print(f"\n[bold green]Turn {turn} - ASSISTANT:[/bold green]")
+
+            # Escape Rich markup in content
+            escaped = content.replace("[", r"\[")
+            console.print(f" {escaped[:1000]}")
+
     # ========================================================================
     # Profiles
     # ========================================================================
@@ -396,15 +766,37 @@ def export_html(
         self,
         filepath: str,
         target: str | None = None,
         module: str | None = None,
+        session_id: str | None = None,
     ) -> None:
         """Export results to HTML report"""
         from aix.core.owasp import parse_owasp_list
-        from aix.core.reporting.base import Finding, Reporter, Severity
+        from aix.core.reporting.base import Finding, Reporter, ScanMetadata, Severity
 
-        results = self.get_results(target=target, module=module, limit=1000)
+        results = self.get_results(target=target, module=module, session_id=session_id, limit=1000)
 
         reporter = Reporter()
 
+        # Load session metadata if available
+        if session_id:
+            session = self.get_session(session_id)
+            if session:
+                reporter.metadata = ScanMetadata(
+                    session_id=session_id,
+                    session_name=session.get("name"),
+                    target=session.get("target", ""),
+                    start_time=(
+                        datetime.fromisoformat(session["start_time"])
+                        if session.get("start_time")
+                        else None
+                    ),
+                    end_time=(
+                        datetime.fromisoformat(session["end_time"])
+                        if session.get("end_time")
+                        else None
+                    ),
+                    modules_run=session.get("modules_run", []),
+                )
+
         for r in results:
             if r["result"] == "success":
                 severity = Severity(r.get("severity", "high"))
diff --git a/aix/modules/multiturn.py b/aix/modules/multiturn.py
index 4206fb6..50147ed 100644
--- a/aix/modules/multiturn.py
+++ b/aix/modules/multiturn.py
@@ -206,6 +206,24 @@ async def run(self, sequences: list[dict] = None):
             # Execute sequence
             result = await conv_manager.execute_sequence(seq)
 
+            # Save conversation transcript
+            target_chat_id = (
+                connector.current_chat_id
+                if hasattr(connector, "current_chat_id")
+                else None
+            )
+            self.db.save_conversation(
+                target=self.target,
+                module="multiturn",
+                technique=seq.get("name", "unnamed"),
+                transcript=conv_manager.get_transcript_as_dicts(),
+                turn_count=result.turns_executed,
+                session_id=self.session_id,
+                target_chat_id=target_chat_id,
+                conversation_id=result.conversation_id,
+                status="success" if result.success else "failed",
+            )
+
             if result.success:
                 self.stats["success"] += 1
                 category_stats[seq_category]["success"] += 1
@@ -235,6 +253,8 @@ async def run(self, sequences: list[dict] = None):
                     severity.value,
                     reason=f"Category: {seq_category}, Indicators: {result.matched_indicators}",
                     owasp=seq.get("owasp", []),
+                    session_id=self.session_id,
+                    conversation_id=result.conversation_id,
                 )
 
                 self._print(
diff --git a/aix/utils/cli.py b/aix/utils/cli.py
index 8aa69c9..507ce4f 100644
--- a/aix/utils/cli.py
+++ b/aix/utils/cli.py
@@ -125,15 +125,48 @@ def wrapper(*args, **kwargs):
         return wrapper
 
 
+def chat_id_options(func):
+    """
+    Decorator that adds target chat ID handling options.
+    """
+
+    @click.option(
+        "--chat-id-path",
+        help="JSON path to extract chat ID from response (e.g., conversation_id, data.chat_id)",
+    )
+    @click.option(
+        "--chat-id-param",
+        help="Field name to inject chat ID into request body (or use {chat_id} in URL)",
+    )
+    @click.option(
+        "--new-chat",
+        is_flag=True,
+        default=False,
+        help="Force a fresh conversation per payload (default for single-turn scans)",
+    )
+    @click.option(
+        "--reuse-chat",
+        is_flag=True,
+        default=False,
+        help="Reuse the same conversation across payloads (default for multi-turn)",
+    )
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 def standard_options(func):
     """
-    Combines all standard options: common, refresh, ai, and scan.
+    Combines all standard options: common, refresh, ai, scan, and chat_id.
     """
 
     @common_options
     @refresh_options
     @ai_options
     @scan_options
+    @chat_id_options
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index 952e7b6..af49b15 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "aix-framework"
-version = "1.0.2"
+version = "1.1.0"
 description = "AIX - AI eXploit Framework: Comprehensive security testing toolkit for AI/LLM systems"
 readme = "README.md"
 license = {text = "MIT"}
diff --git a/tests/test_aix.py b/tests/test_aix.py
index 7217f10..4990ce3 100644
--- a/tests/test_aix.py
+++ b/tests/test_aix.py
@@ -15,7 +15,7 @@ class TestVersion:
 
     def test_version_is_set(self):
         """Test version is set correctly"""
-        assert __version__ == "1.0.1"
+        assert __version__ == "1.1.0"
 
     def test_version_format(self):
         """Test version follows semver format"""
diff --git a/tests/test_connector.py b/tests/test_connector.py
index a3969d2..703acf5 100644
--- a/tests/test_connector.py
+++ b/tests/test_connector.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from aix.core.connector import APIConnector, RequestConnector
+from aix.core.connector import APIConnector, RequestConnector, WebSocketConnector
 from aix.core.request_parser import ParsedRequest
 
 
@@ -234,3 +234,343 @@ def test_parse_headers_empty(self):
         result = connector._parse_headers(None)
 
         assert result == {}
+
+
+class TestWebSocketConnector:
+    """Tests for WebSocketConnector class"""
+
+    def test_init_defaults(self):
+        """Test default initialization"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        assert connector.url == "wss://target.com/chat"
+        assert connector.injection_param == "message"
+        assert connector.response_path is None
+        assert connector.timeout == 30
+        assert connector.ws is None
+
+    def test_init_custom_params(self):
+        """Test initialization with custom parameters"""
+        connector = WebSocketConnector(
+            url="wss://target.com/chat",
+            injection_param="query",
+            response_path="data.content",
+            timeout=60,
+        )
+
+        assert connector.injection_param == "query"
+        assert connector.response_path == "data.content"
+        assert connector.timeout == 60
+
+    def test_build_message(self):
+        """Test message building with default param"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        result = connector._build_message("hello")
+
+        assert result == '{"message": "hello"}'
+
+    def test_build_message_custom_param(self):
+        """Test message building with custom injection param"""
+        connector = WebSocketConnector(url="wss://target.com/chat", injection_param="query")
+
+        result = connector._build_message("test payload")
+
+        assert result == '{"query": "test payload"}'
+
+    def test_extract_response_with_path(self):
+        """Test response extraction with explicit path"""
+        connector = WebSocketConnector(url="wss://target.com/chat", response_path="content")
+
+        raw = '{"user": "Bot", "content": "Hello there!"}'
+        result = connector._extract_response(raw)
+
+        assert result == "Hello there!"
+
+    def test_extract_response_nested_path(self):
+        """Test response extraction with nested dot path"""
+        connector = WebSocketConnector(url="wss://target.com/chat", response_path="data.text")
+
+        raw = '{"data": {"text": "nested value"}}'
+        result = connector._extract_response(raw)
+
+        assert result == "nested value"
+
+    def test_extract_response_fallback(self):
+        """Test response extraction with fallback keys (no response_path)"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        raw = '{"user": "Bot", "content": "fallback hit"}'
+        result = connector._extract_response(raw)
+
+        assert result == "fallback hit"
+
+    def test_extract_response_fallback_response_key(self):
+        """Test fallback to 'response' key"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        raw = '{"response": "answer here"}'
+        result = connector._extract_response(raw)
+
+        assert result == "answer here"
+
+    def test_extract_response_no_known_keys(self):
+        """Test fallback to JSON dump when no known keys match"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        raw = '{"unknown_field": "some data"}'
+        result = connector._extract_response(raw)
+
+        assert "unknown_field" in result
+        assert "some data" in result
+
+    def test_extract_response_non_json(self):
+        """Test non-JSON response passthrough"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        result = connector._extract_response("plain text response")
+
+        assert result == "plain text response"
+
+    def test_build_extra_headers_cookies(self):
+        """Test extra headers with cookies"""
+        connector = WebSocketConnector(url="wss://target.com/chat", cookies="session=abc;token=xyz")
+
+        headers = connector._build_extra_headers()
+
+        assert "Cookie" in headers
+        assert "session=abc" in headers["Cookie"]
+        assert "token=xyz" in headers["Cookie"]
+
+    def test_build_extra_headers_custom(self):
+        """Test extra headers with custom headers"""
+        connector = WebSocketConnector(
+            url="wss://target.com/chat", headers="Authorization:Bearer tok;X-Custom:val"
+        )
+
+        headers = connector._build_extra_headers()
+
+        assert headers["Authorization"] == "Bearer tok"
+        assert headers["X-Custom"] == "val"
+
+    def test_build_extra_headers_both(self):
+        """Test extra headers with both cookies and custom headers"""
+        connector = WebSocketConnector(
+            url="wss://target.com/chat",
+            cookies="session=abc",
+            headers="X-Custom:val",
+        )
+
+        headers = connector._build_extra_headers()
+
+        assert "Cookie" in headers
+        assert headers["X-Custom"] == "val"
+
+    def test_build_extra_headers_empty(self):
+        """Test extra headers when none configured"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        headers = connector._build_extra_headers()
+
+        assert headers == {}
+
+    def test_navigate_path_nested(self):
+        """Test dot-path navigation in nested data"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        data = {"a": {"b": {"c": "deep"}}}
+        result = connector._navigate_path(data, "a.b.c")
+
+        assert result == "deep"
+
+    def test_navigate_path_missing(self):
+        """Test dot-path navigation with missing key"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        data = {"a": {"x": 1}}
+        result = connector._navigate_path(data, "a.b.c")
+
+        assert result == ""
+
+    @pytest.mark.asyncio
+    async def test_context_manager(self):
+        """Test async context manager protocol (close without connect)"""
+        connector = WebSocketConnector(url="wss://target.com/chat")
+
+        # close() should not raise even without a connection
+        await connector.close()
+        assert connector.ws is None
+
+
+class TestChatIdCapture:
+    """Tests for chat ID capture from responses"""
+
+    def test_capture_simple_path(self):
+        """Test capturing chat_id from a simple JSON path"""
+        connector = APIConnector(
+            url="https://example.com",
+            chat_id_path="conversation_id",
+        )
+
+        connector._capture_chat_id({"conversation_id": "abc123", "response": "hi"})
+        assert connector.current_chat_id == "abc123"
+
+    def test_capture_nested_path(self):
+        """Test capturing chat_id from a nested JSON path"""
+        connector = APIConnector(
+            url="https://example.com",
+            chat_id_path="data.chat_id",
+        )
+
+        connector._capture_chat_id({"data": {"chat_id": "nested-id"}, "response": "hi"})
+        assert connector.current_chat_id == "nested-id"
+
+    def test_capture_no_path_configured(self):
+        """Test that no capture happens when chat_id_path is not set"""
+        connector = APIConnector(url="https://example.com")
+
+        connector._capture_chat_id({"conversation_id": "abc123"})
+        assert connector.current_chat_id is None
+
+    def test_capture_path_missing_in_response(self):
+        """Test that missing path in response doesn't crash"""
+        connector = APIConnector(
+            url="https://example.com",
+            chat_id_path="nonexistent.path",
+        )
+
+        connector._capture_chat_id({"other_field": "value"})
+        assert connector.current_chat_id is None
+
+    def test_capture_converts_to_string(self):
+        """Test that numeric chat IDs are converted to string"""
+        connector = APIConnector(
+            url="https://example.com",
+            chat_id_path="chat_id",
+        )
+
+        connector._capture_chat_id({"chat_id": 12345})
+        assert connector.current_chat_id == "12345"
+
+    def test_capture_on_websocket_connector(self):
+        """Test chat ID capture on WebSocketConnector"""
+        connector = WebSocketConnector(
+            url="wss://example.com/chat",
+            chat_id_path="session_id",
+        )
+
+        connector._capture_chat_id({"session_id": "ws-123", "text": "hi"})
+        assert connector.current_chat_id == "ws-123"
+
+
+class TestChatIdInjection:
+    """Tests for chat ID injection into requests"""
+
+    def test_inject_into_payload_body(self):
+        """Test chat_id injection into API request body"""
+        connector = APIConnector(
+            url="https://example.com/chat",
+            chat_id_param="conversation_id",
+        )
+        connector._current_chat_id = "injected-id"
+
+        payload = connector._build_payload("hello")
+        assert payload["conversation_id"] == "injected-id"
+
+    def test_no_injection_without_chat_id(self):
+        """Test no injection when no chat_id has been captured"""
+        connector = APIConnector(
+            url="https://example.com/chat",
+            chat_id_param="conversation_id",
+        )
+
+        payload = connector._build_payload("hello")
+        assert "conversation_id" not in payload
+
+    def test_no_injection_without_param(self):
+        """Test no injection when chat_id_param is not set"""
+        connector = APIConnector(url="https://example.com/chat")
+        connector._current_chat_id = "some-id"
+
+        payload = connector._build_payload("hello")
+        assert "conversation_id" not in payload
+
+    def test_websocket_inject_into_message(self):
+        """Test chat_id injection into WebSocket message"""
+        connector = WebSocketConnector(
+            url="wss://example.com/chat",
+            chat_id_param="session_id",
+        )
+        connector._current_chat_id = "ws-inject-id"
+
+        import json
+        msg = json.loads(connector._build_message("test"))
+        assert msg["session_id"] == "ws-inject-id"
+
+    def test_no_body_injection_when_url_placeholder(self):
+        """Test that chat_id is NOT injected into body when URL has {chat_id}"""
+        connector = APIConnector(
+            url="https://example.com/chat/{chat_id}/messages",
+            chat_id_param="conversation_id",
+        )
+        connector._current_chat_id = "url-id"
+
+        payload = connector._build_payload("hello")
+        assert "conversation_id" not in payload
+
+
+class TestChatIdReset:
+    """Tests for chat ID reset"""
+
+    def test_reset_clears_chat_id(self):
+        """Test that reset_chat_id clears the captured ID"""
+        connector = APIConnector(
+            url="https://example.com",
+            chat_id_path="chat_id",
+        )
+        connector._current_chat_id = "some-id"
+
+        connector.reset_chat_id()
+        assert connector.current_chat_id is None
+
+    def test_reset_on_fresh_connector(self):
+        """Test that reset on a fresh connector doesn't crash"""
+        connector = APIConnector(url="https://example.com")
+        connector.reset_chat_id()
+        assert connector.current_chat_id is None
+
+
+class TestNavigateJsonPath:
+    """Tests for _navigate_json_path on Connector base class"""
+
+    def test_simple_path(self):
+        """Test simple key navigation"""
+        connector = APIConnector(url="https://example.com")
+        result = connector._navigate_json_path({"key": "value"}, "key")
+        assert result == "value"
+
+    def test_nested_path(self):
+        """Test nested dot-separated path"""
+        connector = APIConnector(url="https://example.com")
+        data = {"a": {"b": {"c": "deep"}}}
+        result = connector._navigate_json_path(data, "a.b.c")
+        assert result == "deep"
+
+    def test_list_index(self):
+        """Test list index navigation"""
+        connector = APIConnector(url="https://example.com")
+        data = {"items": [{"id": "first"}, {"id": "second"}]}
+        result = connector._navigate_json_path(data, "items.1.id")
+        assert result == "second"
+
+    def test_missing_key(self):
+        """Test missing key returns None"""
+        connector = APIConnector(url="https://example.com")
+        result = connector._navigate_json_path({"a": 1}, "b")
+        assert result is None
+
+    def test_list_index_out_of_range(self):
+        """Test out-of-range list index returns None"""
+        connector = APIConnector(url="https://example.com")
+        result = connector._navigate_json_path({"items": [1, 2]}, "items.5")
+        assert result is None
diff --git a/tests/test_database.py b/tests/test_database.py
index f865af6..bd46608 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -453,3 +453,317 @@ def test_very_long_payload(self):
             assert len(results[0]["payload"]) == 10000
         finally:
             db.close()
+
+
+class TestSessions:
+    """Tests for session management"""
+
+    @pytest.fixture
+    def temp_db(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_path = os.path.join(tmpdir, "test_aix.db")
+            db = AIXDatabase(db_path)
+            yield db
+            db.close()
+
+    def test_create_session(self, temp_db):
+        """Test creating a session returns a UUID"""
+        session_id = temp_db.create_session(target="https://example.com")
+        assert session_id is not None
+        assert len(session_id) == 36  # UUID length
+
+    def test_create_session_with_name(self, temp_db):
+        """Test creating a session with a custom name"""
+        session_id = temp_db.create_session(
+            target="https://example.com", name="My Test Session"
+        )
+        session = temp_db.get_session(session_id)
+        assert session is not None
+        assert session["name"] == "My Test Session"
+        assert session["target"] == "https://example.com"
+        assert session["status"] == "active"
+
+    def test_end_session(self, temp_db):
+        """Test ending a session"""
+        session_id = temp_db.create_session(target="https://example.com")
+        temp_db.end_session(session_id, status="completed")
+
+        session = temp_db.get_session(session_id)
+        assert session["status"] == "completed"
+        assert session["end_time"] is not None
+
+    def test_end_session_aborted(self, temp_db):
+        """Test ending a session as aborted"""
+        session_id = temp_db.create_session(target="https://example.com")
+        temp_db.end_session(session_id, status="aborted")
+
+        session = temp_db.get_session(session_id)
+        assert session["status"] == "aborted"
+
+    def test_update_session_modules(self, temp_db):
+        """Test appending modules to a session"""
+        session_id = temp_db.create_session(target="https://example.com")
+        temp_db.update_session_modules(session_id, "inject")
+        temp_db.update_session_modules(session_id, "jailbreak")
+
+        session = temp_db.get_session(session_id)
+        assert session["modules_run"] == ["inject", "jailbreak"]
+
+    def test_update_session_modules_no_duplicates(self, temp_db):
+        """Test that duplicate modules are not added"""
+        session_id = temp_db.create_session(target="https://example.com")
+        temp_db.update_session_modules(session_id, "inject")
+        temp_db.update_session_modules(session_id, "inject")
+
+        session = temp_db.get_session(session_id)
+        assert session["modules_run"] == ["inject"]
+
+    def test_list_sessions(self, temp_db):
+        """Test listing sessions"""
+        temp_db.create_session(target="https://target1.com")
+        temp_db.create_session(target="https://target2.com")
+
+        sessions = temp_db.list_sessions()
+        assert len(sessions) == 2
+
+    def test_get_session_not_found(self, temp_db):
+        """Test getting a non-existent session"""
+        result = temp_db.get_session("nonexistent-id")
+        assert result is None
+
+    def test_get_or_create_session_creates_new(self, temp_db):
+        """Test get_or_create_session creates a new session when none exists"""
+        session_id = temp_db.get_or_create_session("https://example.com")
+        assert session_id is not None
+        session = temp_db.get_session(session_id)
+        assert session["target"] == "https://example.com"
+        assert session["status"] == "active"
+
+    def test_get_or_create_session_reuses_existing(self, temp_db):
+        """Test get_or_create_session reuses an active session"""
+        id1 = temp_db.get_or_create_session("https://example.com")
+        id2 = temp_db.get_or_create_session("https://example.com")
+        assert id1 == id2
+
+    def test_get_or_create_session_different_targets(self, temp_db):
+        """Test get_or_create_session creates separate sessions per target"""
+        id1 = temp_db.get_or_create_session("https://target1.com")
+        id2 = temp_db.get_or_create_session("https://target2.com")
+        assert id1 != id2
+
+    def test_get_session_results(self, temp_db):
+        """Test getting results filtered by session"""
+        session_id = temp_db.create_session(target="https://example.com")
+
+        temp_db.add_result(
+            target="https://example.com",
+            module="inject",
+            technique="test",
+            result="success",
+            payload="p",
+            response="r",
+            severity="high",
+            session_id=session_id,
+        )
+        temp_db.add_result(
+            target="https://example.com",
+            module="jailbreak",
+            technique="test2",
+            result="success",
+            payload="p2",
+            response="r2",
+            severity="medium",
+        )
+
+        results = temp_db.get_session_results(session_id)
+        assert len(results) == 1
+        assert results[0]["module"] == "inject"
+
+
+class TestConversations:
+    """Tests for conversation management"""
+
+    @pytest.fixture
+    def temp_db(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_path = os.path.join(tmpdir, "test_aix.db")
+            db = AIXDatabase(db_path)
+            yield db
+            db.close()
+
+    def test_save_conversation(self, temp_db):
+        """Test saving a conversation"""
+        transcript = [
+            {"role": "user", "content": "hello", "turn_number": 1},
+            {"role": "assistant", "content": "hi!", "turn_number": 1},
+        ]
+        conv_id = temp_db.save_conversation(
+            target="https://example.com",
+            module="multiturn",
+            technique="crescendo",
+            transcript=transcript,
+            turn_count=2,
+        )
+        assert conv_id is not None
+
+        conv = temp_db.get_conversation(conv_id)
+        assert conv is not None
+        assert conv["target"] == "https://example.com"
+        assert conv["module"] == "multiturn"
+        assert conv["technique"] == "crescendo"
+        assert conv["turn_count"] == 2
+        assert len(conv["transcript"]) == 2
+
+    def test_save_conversation_with_target_chat_id(self, temp_db):
+        """Test that target_chat_id is stored"""
+        conv_id = temp_db.save_conversation(
+            target="https://example.com",
+            module="inject",
+            technique="test",
+            target_chat_id="abc-123",
+        )
+
+        conv = temp_db.get_conversation(conv_id)
+        assert conv["target_chat_id"] == "abc-123"
+
+    def test_save_conversation_with_session(self, temp_db):
+        """Test saving a conversation linked to a session"""
+        session_id = temp_db.create_session(target="https://example.com")
+        conv_id = temp_db.save_conversation(
+            target="https://example.com",
+            module="multiturn",
+            session_id=session_id,
+        )
+
+        conv = temp_db.get_conversation(conv_id)
+        assert conv["session_id"] == session_id
+
+    def test_save_conversation_custom_id(self, temp_db):
+        """Test saving a conversation with a custom ID"""
+        custom_id = "my-custom-conv-id"
+        conv_id = temp_db.save_conversation(
+            target="https://example.com",
+            module="inject",
+            conversation_id=custom_id,
+        )
+        assert conv_id == custom_id
+
+    def test_get_conversation_not_found(self, temp_db):
+        """Test getting a non-existent conversation"""
+        result = temp_db.get_conversation("nonexistent-id")
+        assert result is None
+
+    def test_list_conversations(self, temp_db):
+        """Test listing conversations"""
+        temp_db.save_conversation(target="https://t1.com", module="inject")
+        temp_db.save_conversation(target="https://t2.com", module="jailbreak")
+
+        convs = temp_db.list_conversations()
+        assert len(convs) == 2
+
+    def test_list_conversations_by_session(self, temp_db):
+        """Test filtering conversations by session"""
+        session_id = temp_db.create_session(target="https://example.com")
+        temp_db.save_conversation(
+            target="https://example.com", module="inject", session_id=session_id
+        )
+        temp_db.save_conversation(
+            target="https://example.com", module="jailbreak"
+        )
+
+        convs = temp_db.list_conversations(session_id=session_id)
+        assert len(convs) == 1
+        assert convs[0]["module"] == "inject"
+
+    def test_list_conversations_by_target(self, temp_db):
+        """Test filtering conversations by target"""
+        temp_db.save_conversation(target="https://target1.com", module="inject")
+        temp_db.save_conversation(target="https://target2.com", module="inject")
+
+        convs = temp_db.list_conversations(target="target1.com")
+        assert len(convs) == 1
+
+
+class TestResultsWithSessionAndConversation:
+    """Tests for results with session_id and conversation_id"""
+
+    @pytest.fixture
+    def temp_db(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_path = os.path.join(tmpdir, "test_aix.db")
+            db = AIXDatabase(db_path)
+            yield db
+            db.close()
+
+    def test_add_result_with_session_id(self, temp_db):
+        """Test adding a result with session_id"""
+        session_id = temp_db.create_session(target="https://example.com")
+        temp_db.add_result(
+            target="https://example.com",
+            module="inject",
+            technique="test",
+            result="success",
+            payload="p",
+            response="r",
+            severity="high",
+            session_id=session_id,
+        )
+
+        results = temp_db.get_results(session_id=session_id)
+        assert len(results) == 1
+        assert results[0]["session_id"] == session_id
+
+    def test_add_result_with_conversation_id(self, temp_db):
+        """Test adding a result with conversation_id"""
+        temp_db.add_result(
+            target="https://example.com",
+            module="inject",
+            technique="test",
+            result="success",
+            payload="p",
+            response="r",
+            severity="high",
+            conversation_id="conv-123",
+        )
+
+        results = temp_db.get_results()
+        assert results[0]["conversation_id"] == "conv-123"
+
+    def test_get_results_filtered_by_session(self, temp_db):
+        """Test filtering results by session_id"""
+        s1 = temp_db.create_session(target="https://example.com")
+        s2 = temp_db.create_session(target="https://example.com")
+
+        temp_db.add_result(
+            target="https://example.com", module="inject", technique="t1",
+            result="success", payload="p", response="r", severity="high",
+            session_id=s1,
+        )
+        temp_db.add_result(
+            target="https://example.com", module="jailbreak", technique="t2",
+            result="success", payload="p", response="r", severity="high",
+            session_id=s2,
+        )
+
+        results = temp_db.get_results(session_id=s1)
+        assert len(results) == 1
+        assert results[0]["technique"] == "t1"
+
+    def test_update_result_preserves_session_id(self, temp_db):
+        """Test that updating a result with COALESCE preserves session_id"""
+        session_id = temp_db.create_session(target="https://example.com")
+        temp_db.add_result(
+            target="https://example.com", module="inject", technique="test",
+            result="success", payload="p", response="r1", severity="high",
+            session_id=session_id,
+        )
+        # Update same result without providing session_id
+        temp_db.add_result(
+            target="https://example.com", module="inject", technique="test",
+            result="success", payload="p", response="r2", severity="critical",
+        )
+
+        results = temp_db.get_results()
+        assert len(results) == 1
+        assert results[0]["response"] == "r2"
+        assert results[0]["session_id"] == session_id
diff --git a/tests/test_reporter.py b/tests/test_reporter.py
index 14fffaf..ec5df5b 100644
--- a/tests/test_reporter.py
+++ b/tests/test_reporter.py
@@ -282,3 +282,397 @@ def test_findings_by_severity(self):
 
         # Verify all severities present
         found_severities = {f.severity for f in reporter.findings}
         assert found_severities == set(severities)
+
+
+class TestRiskScore:
+    """Tests for risk score calculation"""
+
+    def test_no_findings_score_zero(self):
+        """Test risk score is 0 with no findings"""
+        reporter = Reporter()
+        assert reporter.calculate_risk_score() == 0.0
+
+    def test_single_critical_finding(self):
+        """Test risk score with a single critical finding"""
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.CRITICAL, technique="t", payload="p", response="r")
+        )
+        # critical=10, 10/5=2.0
+        assert reporter.calculate_risk_score() == 2.0
+
+    def test_score_capped_at_ten(self):
+        """Test risk score is capped at 10.0"""
+        reporter = Reporter()
+        # 10 critical findings: 10*10=100, 100/5=20 -> capped at 10
+        for _ in range(10):
+            reporter.add_finding(
+                Finding(title="t", severity=Severity.CRITICAL, technique="t", payload="p", response="r")
+            )
+        assert reporter.calculate_risk_score() == 10.0
+
+    def test_mixed_severities(self):
+        """Test risk score with mixed severities"""
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.HIGH, technique="t", payload="p", response="r")
+        )
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.MEDIUM, technique="t", payload="p", response="r")
+        )
+        # high=7 + medium=4 = 11, 11/5 = 2.2
+        assert abs(reporter.calculate_risk_score() - 2.2) < 0.01
+
+    def test_info_findings_zero_weight(self):
+        """Test that info findings contribute 0 to risk score"""
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.INFO, technique="t", payload="p", response="r")
+        )
+        assert reporter.calculate_risk_score() == 0.0
+
+    def test_risk_level_critical(self):
+        """Test risk level classification for critical scores"""
+        reporter = Reporter()
+        assert reporter.get_risk_level(8.0) == "Critical"
+        assert reporter.get_risk_level(10.0) == "Critical"
+
+    def test_risk_level_high(self):
+        """Test risk level classification for high scores"""
+        reporter = Reporter()
+        assert reporter.get_risk_level(5.0) == "High"
+        assert reporter.get_risk_level(7.9) == "High"
+
+    def test_risk_level_medium(self):
+        """Test risk level classification for medium scores"""
+        reporter = Reporter()
+        assert reporter.get_risk_level(2.0) == "Medium"
+        assert reporter.get_risk_level(4.9) == "Medium"
+
+    def test_risk_level_low(self):
+        """Test risk level classification for low scores"""
+        reporter = Reporter()
+        assert reporter.get_risk_level(0.0) == "Low"
+        assert reporter.get_risk_level(1.9) == "Low"
+
+
+class TestOWASPCoverage:
+    """Tests for OWASP coverage mapping"""
+
+    def test_empty_coverage(self):
+        """Test OWASP coverage with no findings"""
+        reporter = Reporter()
+        coverage = reporter.get_owasp_coverage()
+        assert len(coverage) == 10
+        for cat_id, info in coverage.items():
+            assert info["tested"] is False
+            assert info["findings_count"] == 0
+
+    def test_coverage_with_findings(self):
+        """Test OWASP coverage with findings tagged to categories"""
+        from aix.core.owasp import OWASPCategory
+
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(
+                title="Injection",
+                severity=Severity.CRITICAL,
+                technique="test",
+                payload="p",
+                response="r",
+                owasp=[OWASPCategory.LLM01],
+            )
+        )
+
+        coverage = reporter.get_owasp_coverage()
+        assert coverage["LLM01"]["tested"] is True
+        assert coverage["LLM01"]["findings_count"] == 1
+        assert coverage["LLM01"]["max_severity"] == "critical"
+        # Other categories should remain untested
+        assert coverage["LLM02"]["tested"] is False
+
+    def test_coverage_multiple_findings_same_category(self):
+        """Test OWASP coverage counts multiple findings per category"""
+        from aix.core.owasp import OWASPCategory
+
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t1", severity=Severity.HIGH, technique="t", payload="p", response="r",
+                    owasp=[OWASPCategory.LLM06])
+        )
+        reporter.add_finding(
+            Finding(title="t2", severity=Severity.CRITICAL, technique="t", payload="p", response="r",
+                    owasp=[OWASPCategory.LLM06])
+        )
+
+        coverage = reporter.get_owasp_coverage()
+        assert coverage["LLM06"]["findings_count"] == 2
+        assert coverage["LLM06"]["max_severity"] == "critical"
+
+    def test_coverage_all_categories(self):
+        """Test that all 10 OWASP categories are present"""
+        reporter = Reporter()
+        coverage = reporter.get_owasp_coverage()
+        for i in range(1, 11):
+            assert f"LLM{i:02d}" in coverage
+
+
+class TestExecutiveSummary:
+    """Tests for executive summary generation"""
+
+    def test_no_findings_summary(self):
+        """Test summary with no findings"""
+        reporter = Reporter()
+        summary = reporter.generate_executive_summary()
+        assert "No vulnerabilities" in summary
+
+    def test_critical_risk_summary(self):
+        """Test summary for critical risk"""
+        reporter = Reporter()
+        for _ in range(5):
+            reporter.add_finding(
+                Finding(title="t", severity=Severity.CRITICAL, technique="t", payload="p", response="r")
+            )
+        summary = reporter.generate_executive_summary()
+        assert "Immediate remediation" in summary
+        assert "5 critical" in summary
+
+    def test_high_risk_summary(self):
+        """Test summary for high risk"""
+        reporter = Reporter()
+        for _ in range(4):
+            reporter.add_finding(
+                Finding(title="t", severity=Severity.HIGH, technique="t", payload="p", response="r")
+            )
+        summary = reporter.generate_executive_summary()
+        assert "Significant vulnerabilities" in summary
+
+    def test_low_risk_summary(self):
+        """Test summary for a single high finding, which scores in the low band"""
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.HIGH, technique="t", payload="p", response="r")
+        )
+        summary = reporter.generate_executive_summary()
+        # high=7 -> score 7/5 = 1.4, which falls in the "Low" risk band (< 2)
+        assert "risk" in summary.lower()
+
+    def test_summary_contains_count(self):
+        """Test summary contains finding count"""
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.MEDIUM, technique="t", payload="p", response="r")
+        )
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.LOW, technique="t", payload="p", response="r")
+        )
+        summary = reporter.generate_executive_summary()
+        assert "2 vulnerabilities" in summary
+
+
+class TestScanMetadata:
+    """Tests for ScanMetadata integration"""
+
+    def test_metadata_defaults(self):
+        """Test ScanMetadata default values"""
+        from aix.core.reporting.base import ScanMetadata
+
+        meta = ScanMetadata()
+        assert meta.session_id is None
+        assert meta.target == ""
+        assert meta.modules_run == []
+        assert meta.risk_score == 0.0
+
+    def test_reporter_with_metadata(self):
+        """Test reporter stores metadata"""
+        from aix.core.reporting.base import ScanMetadata
+
+        reporter = Reporter()
+        reporter.metadata = ScanMetadata(
+            session_id="test-session",
+            target="https://example.com",
+            modules_run=["inject", "jailbreak"],
+        )
+        assert reporter.metadata.session_id == "test-session"
+        assert reporter.metadata.modules_run == ["inject", "jailbreak"]
+
+    def test_json_export_with_metadata(self):
+        """Test JSON export includes metadata"""
+        from aix.core.reporting.base import ScanMetadata
+
+        reporter = Reporter()
+        reporter.metadata = ScanMetadata(
+            session_id="sid-123",
+            session_name="Test Session",
+            target="https://example.com",
+            modules_run=["inject"],
+        )
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.HIGH, technique="t", payload="p", response="r")
+        )
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+            filepath = f.name
+
+        try:
+            reporter.export_json(filepath)
+            with open(filepath) as f:
+                data = json.load(f)
+
+            assert data["scan_info"]["session_id"] == "sid-123"
+            assert data["scan_info"]["session_name"] == "Test Session"
+            assert data["scan_info"]["target"] == "https://example.com"
+            assert data["scan_info"]["modules_run"] == ["inject"]
+            assert data["scan_info"]["risk_score"] > 0
+            assert "executive_summary" in data
+            assert "owasp_coverage" in data
+        finally:
+            Path(filepath).unlink(missing_ok=True)
+
+
+class TestEnhancedHTMLExport:
+    """Tests for enhanced HTML export"""
+
+    def test_html_contains_executive_summary(self):
+        """Test HTML report contains executive summary section"""
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.HIGH, technique="t", payload="p", response="r")
+        )
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f:
+            filepath = f.name
+
+        try:
+            reporter.export_html(filepath)
+            content = Path(filepath).read_text()
+            assert "Executive Summary" in content
+            assert "Risk Score" in content
+        finally:
+            Path(filepath).unlink(missing_ok=True)
+
+    def test_html_contains_severity_chart(self):
+        """Test HTML report contains severity distribution chart"""
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.CRITICAL, technique="t", payload="p", response="r")
+        )
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f:
+            filepath = f.name
+
+        try:
+            reporter.export_html(filepath)
+            content = Path(filepath).read_text()
+            assert "Severity Distribution" in content
+            assert "chart-bar-fill" in content
+        finally:
+            Path(filepath).unlink(missing_ok=True)
+
+    def test_html_contains_owasp_grid(self):
+        """Test HTML report contains OWASP coverage grid"""
+        reporter = Reporter()
+        reporter.add_finding(
+            Finding(title="t", severity=Severity.HIGH, technique="t", payload="p", response="r")
+        )
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f:
+            filepath = f.name
+
+        try:
+            reporter.export_html(filepath)
+            content = Path(filepath).read_text()
+            assert "OWASP LLM Top 10 Coverage" in content
+            assert "owasp-card" in content
+            assert "LLM01" in content
+            assert "LLM10" in content
+        finally:
+            Path(filepath).unlink(missing_ok=True)
+
+    def test_html_contains_remediation(self):
+        """Test HTML report contains remediation when OWASP findings exist"""
+        from aix.core.owasp import OWASPCategory
+
+        reporter = Reporter()
reporter.add_finding( + Finding( + title="Injection Finding", + severity=Severity.CRITICAL, + technique="test", + payload="p", + response="r", + owasp=[OWASPCategory.LLM01], + ) + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f: + filepath = f.name + + try: + reporter.export_html(filepath) + content = Path(filepath).read_text() + assert "Remediation Recommendations" in content + assert "Prompt Injection" in content + finally: + Path(filepath).unlink(missing_ok=True) + + def test_html_no_remediation_without_owasp(self): + """Test HTML report omits remediation when no OWASP findings""" + reporter = Reporter() + reporter.add_finding( + Finding(title="t", severity=Severity.HIGH, technique="t", payload="p", response="r") + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f: + filepath = f.name + + try: + reporter.export_html(filepath) + content = Path(filepath).read_text() + assert "Remediation Recommendations" not in content + finally: + Path(filepath).unlink(missing_ok=True) + + def test_html_contains_metadata_section(self): + """Test HTML report contains metadata when provided""" + from aix.core.reporting.base import ScanMetadata + + reporter = Reporter() + reporter.metadata = ScanMetadata( + session_name="Test Session", + target="https://example.com", + modules_run=["inject", "jailbreak"], + start_time=datetime(2026, 1, 15, 10, 30), + end_time=datetime(2026, 1, 15, 10, 35), + ) + reporter.add_finding( + Finding(title="t", severity=Severity.HIGH, technique="t", payload="p", response="r") + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f: + filepath = f.name + + try: + reporter.export_html(filepath) + content = Path(filepath).read_text() + assert "Test Session" in content + assert "https://example.com" in content + assert "inject, jailbreak" in content + finally: + Path(filepath).unlink(missing_ok=True) + + def test_html_contains_aix_version(self): + """Test HTML 
footer contains AIX version""" + reporter = Reporter() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f: + filepath = f.name + + try: + reporter.export_html(filepath) + content = Path(filepath).read_text() + assert "AIX v" in content + assert "AI eXploit Framework" in content + finally: + Path(filepath).unlink(missing_ok=True) diff --git a/tests/test_scanner.py b/tests/test_scanner.py index b28ab90..1b0500d 100644 --- a/tests/test_scanner.py +++ b/tests/test_scanner.py @@ -259,6 +259,47 @@ def test_connector_inherits_timeout(self): # Connector uses scanner's timeout (default 30) assert connector.config.get("timeout") == 30 # Default timeout + def test_create_connector_ws_url(self): + """Test ws:// URL creates WebSocketConnector""" + from aix.core.connector import WebSocketConnector + from aix.modules.inject import InjectScanner + + scanner = InjectScanner(target="ws://target.com/chat") + + connector = scanner._create_connector() + + assert isinstance(connector, WebSocketConnector) + + def test_create_connector_wss_url(self): + """Test wss:// URL creates WebSocketConnector""" + from aix.core.connector import WebSocketConnector + from aix.modules.inject import InjectScanner + + scanner = InjectScanner(target="wss://target.com/chat") + + connector = scanner._create_connector() + + assert isinstance(connector, WebSocketConnector) + + def test_ws_connector_inherits_params(self): + """Test WebSocketConnector gets scanner parameters""" + from aix.core.connector import WebSocketConnector + from aix.modules.inject import InjectScanner + + scanner = InjectScanner( + target="wss://target.com/chat", + injection_param="query", + response_path="data.text", + cookies="session=abc", + ) + + connector = scanner._create_connector() + + assert isinstance(connector, WebSocketConnector) + assert connector.injection_param == "query" + assert connector.response_path == "data.text" + assert connector.cookies == "session=abc" + class TestScannerEvaluator: 
"""Tests for evaluator integration"""