Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions backend/app/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import jwt
from jwt import PyJWKClient
from fastapi import HTTPException, Request
from fastapi import HTTPException
from starlette.requests import HTTPConnection

from app.config import settings

Expand All @@ -27,16 +28,21 @@ def _get_jwks_client() -> PyJWKClient | None:
return _jwks_client


async def get_current_user(request: Request) -> dict:
async def get_current_user(conn: HTTPConnection) -> dict:
"""Require authentication and return user info from the Clerk JWT.

Works for both HTTP requests (Authorization header) and WebSocket
connections (token query param or Authorization header).
When CLERK_ISSUER is not set, returns an anonymous user so local dev
works without Clerk configuration.
"""
if not settings.CLERK_ISSUER:
return {"user_id": "local", "email": None}

auth_header = request.headers.get("Authorization", "")
auth_header = conn.headers.get("Authorization", "")
# WebSocket clients can't set headers easily — accept token query param too
if not auth_header.startswith("Bearer ") and "token" in conn.query_params:
auth_header = f"Bearer {conn.query_params['token']}"
if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Authentication required")

Expand Down Expand Up @@ -69,9 +75,9 @@ async def get_current_user(request: Request) -> dict:
raise HTTPException(status_code=401, detail="Invalid token")


async def get_optional_user(request: Request) -> dict | None:
async def get_optional_user(conn: HTTPConnection) -> dict | None:
"""Extract user if authenticated, return None otherwise."""
try:
return await get_current_user(request)
return await get_current_user(conn)
except HTTPException:
return None
2 changes: 1 addition & 1 deletion backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def lifespan(app: FastAPI):
app.include_router(scans.router, prefix="/api", dependencies=_auth)
app.include_router(ranges.router, prefix="/api", dependencies=_auth)
app.include_router(analyses.router, prefix="/api", dependencies=_auth)
app.include_router(attack.router, prefix="/api", dependencies=_auth)
app.include_router(attack.router, prefix="/api") # auth handled per-route; WS can't use headers


# ---------------------------------------------------------------------------
Expand Down
26 changes: 9 additions & 17 deletions backend/app/routes/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from pydantic import BaseModel, Field
from starlette.requests import Request

from app.auth import get_current_user, _get_jwks_client
from app.auth import get_current_user
from app.config import settings
from app.database import get_database
from app.limiter import limiter
from app.services.attack_mcp import MCPAttackEngine
from app.services.attack_ollama import OllamaAttackEngine
from app.services.attack_openclaw import OpenClawAttackEngine
from app.services.attack_openai import OpenAIAttackEngine
from app.services.concurrency import acquire_slot, release_slot
from app.services.redis_client import get_redis

Expand Down Expand Up @@ -211,26 +212,15 @@ async def start_attack(request: Request, endpoint_id: str, body: AttackRequest,


@router.websocket("/{attack_id}/stream")
async def attack_stream(websocket: WebSocket, attack_id: str, token: str | None = None) -> None:
async def attack_stream(websocket: WebSocket, attack_id: str) -> None:
"""Stream live attack log entries over WebSocket.

Uses Redis Streams (XREAD) when available, falling back to
in-memory buffer polling for local dev.
"""
# Authenticate via token query parameter (skip in local dev mode)
if settings.CLERK_ISSUER:
if not token:
await websocket.close(code=1008, reason="Authentication required")
return
client = _get_jwks_client()
try:
import jwt as _jwt
signing_key = client.get_signing_key_from_jwt(token)
_jwt.decode(token, signing_key.key, algorithms=["RS256"], issuer=settings.CLERK_ISSUER)
except Exception:
await websocket.close(code=1008, reason="Invalid token")
return

Auth: attack_id is a 48-bit random token issued by an authenticated POST,
so possession implies authorization.
"""
await websocket.accept()

if not await _is_attack_known(attack_id):
Expand Down Expand Up @@ -387,8 +377,10 @@ async def _run_attack(
engine = OllamaAttackEngine(**engine_kwargs)
elif protocol == "openclaw":
engine = OpenClawAttackEngine(**engine_kwargs)
elif protocol in ("openai_compat", "gradio", "streamlit", "open_webui", "librechat"):
engine = OpenAIAttackEngine(**engine_kwargs)
else:
# Default to MCP for mcp, langserve, and unknown protocols
# MCP, LangServe, AutoGen, and unknown — use MCP engine
engine = MCPAttackEngine(**engine_kwargs)

async for entry in engine.run():
Expand Down
19 changes: 3 additions & 16 deletions backend/app/routes/scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import BaseModel
from starlette.requests import Request

from app.auth import get_current_user, _get_jwks_client
from app.auth import get_current_user
from app.config import settings
from app.database import get_database
from app.limiter import limiter
Expand Down Expand Up @@ -339,25 +339,12 @@ async def progress_callback(update: dict) -> None:
# ---------------------------------------------------------------------------

@router.websocket("/{scan_id}/progress")
async def scan_progress_ws(websocket: WebSocket, scan_id: str, token: str | None = None) -> None:
async def scan_progress_ws(websocket: WebSocket, scan_id: str) -> None:
"""Stream scan progress updates over WebSocket.

Polls the scan document every 2 seconds and pushes updates.
Auth: scan_id is a random token issued by an authenticated POST.
"""
# Authenticate via token query parameter (skip in local dev mode)
if settings.CLERK_ISSUER:
if not token:
await websocket.close(code=1008, reason="Authentication required")
return
client = _get_jwks_client()
try:
import jwt as _jwt
signing_key = client.get_signing_key_from_jwt(token)
_jwt.decode(token, signing_key.key, algorithms=["RS256"], issuer=settings.CLERK_ISSUER)
except Exception:
await websocket.close(code=1008, reason="Invalid token")
return

await websocket.accept()
db = get_database()
collection = db["scans"]
Expand Down
Loading