From 35e2194aa2865bc633ba751769e7318507a62ff0 Mon Sep 17 00:00:00 2001 From: neuralmint Date: Mon, 25 May 2026 09:31:43 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20enforce=20workspace=20filter=20on=20agen?= =?UTF-8?q?t=20listing=20=E2=80=94=20multi-tenant=20projects?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #4223 Add workspace-scoped agent operations to prevent cross-tenant data leakage in multi-tenant projects. Every agent lookup, listing, and mutation is now scoped to the authenticated caller's workspace. Changes: - Add workspace field to SessionInfo (defaults to 'default') - Add workspace parameter to registry.list(), get(), count(), get_enabled(), register(), and new disable()/enable() methods - Update agent routes to extract workspace from authenticated session - Return 404 for cross-workspace agent access (no existence leak) - Return 4xx without performing protected lookups on unauthorized requests - Add route-level tests covering authorized, unauthorized, cross- workspace, and malformed request scenarios - Add AgentStatus.DISABLED enum value - Restore AuthMiddleware with session validation and role checks --- src/agent/__init__.py | 4 +- src/agent/registry.py | 73 ++++++++- src/api/middleware.py | 211 +++++++++--------------- src/api/routes.py | 326 +++++++++++++++++++++---------------- src/api/session.py | 266 ++++++++++++++++++++++++++++++ tests/test_agent_routes.py | 307 ++++++++++++++++++++++++++++++++++ 6 files changed, 901 insertions(+), 286 deletions(-) create mode 100644 src/api/session.py create mode 100644 tests/test_agent_routes.py diff --git a/src/agent/__init__.py b/src/agent/__init__.py index dbfe92991..12f9b5e06 100644 --- a/src/agent/__init__.py +++ b/src/agent/__init__.py @@ -1,11 +1,11 @@ """Agent lifecycle management module.""" -from .registry import AgentRegistry +from .registry import AgentRegistry, AgentStatus from .executor import AgentExecutor from .runtime import AgentRuntime from .sandbox import AgentSandbox -__all__ = ["AgentRegistry", "AgentExecutor", "AgentRuntime", "AgentSandbox"] +__all__ = ["AgentRegistry", "AgentStatus", "AgentExecutor", "AgentRuntime", "AgentSandbox"] # 2019-02-05T12:34:30 update diff --git a/src/agent/registry.py b/src/agent/registry.py index fbedbe170..272ecc3b4 100644 --- a/src/agent/registry.py +++ b/src/agent/registry.py @@ -1,11 +1,14 @@ """Agent Registry — Manages agent lifecycle and metadata.""" import json +import logging import time import uuid from enum import Enum from typing import Any, Dict, List, Optional +logger = logging.getLogger(__name__) + class AgentStatus(Enum): PENDING = "pending" @@ -14,6 +17,7 @@ class AgentStatus(Enum): STOPPED = "stopped" FAILED = "failed" TERMINATED = "terminated" + DISABLED = "disabled" class AgentRegistry: @@ -22,7 +26,7 @@ def __init__(self, storage_backend: str = "memory"): self._agents: Dict[str, Dict[str, Any]] = {} self._index: Dict[str, List[str]] = {} - def register(self, name: str, agent_type: str, config: Optional[Dict] = None) -> str: + def register(self, name: str, agent_type: str, config: Optional[Dict] = None, workspace: Optional[str] = None) -> str: agent_id = str(uuid.uuid4()) timestamp = time.time() self._agents[agent_id] = { @@ -30,10 +34,12 @@ def register(self, name: str, agent_type: str, config: Optional[Dict] = None) -> "name": name, "type": agent_type, "status": AgentStatus.PENDING.value, + "enabled": True, "config": config or {}, "created_at": timestamp, "updated_at": timestamp, "version": "1.0.0", + "workspace": workspace or "default", "metrics": {"tasks_completed": 0, "errors": 0, "uptime": 0}, } group = agent_type.split(".")[0] @@ -42,11 +48,26 @@ def register(self, name: str, agent_type: str, config: Optional[Dict] = None) -> self._index[group].append(agent_id) return agent_id - def get(self, agent_id: str) -> Optional[Dict[str, Any]]: - return self._agents.get(agent_id) - - def list(self, status: Optional[AgentStatus] = None, group: Optional[str] = None) -> List[Dict[str, Any]]: + def get(self, agent_id: str, workspace: Optional[str] = None) -> Optional[Dict[str, Any]]: + agent = self._agents.get(agent_id) + if agent is None: + return None + if workspace is not None and agent.get("workspace") != workspace: + return None + return agent + + def list( + self, + status: Optional[AgentStatus] = None, + group: Optional[str] = None, + include_disabled: bool = False, + workspace: Optional[str] = None, + ) -> List[Dict[str, Any]]: agents = self._agents.values() + if workspace is not None: + agents = [a for a in agents if a.get("workspace") == workspace] + if not include_disabled: + agents = [a for a in agents if a.get("enabled", True)] if status: agents = [a for a in agents if a["status"] == status.value] if group: @@ -70,8 +91,46 @@ def delete(self, agent_id: str) -> bool: self._index[group].remove(agent_id) return True - def count(self) -> int: - return len(self._agents) + def disable(self, agent_id: str) -> bool: + """Disable an agent. Removes it from default capability discovery listings.""" + if agent_id not in self._agents: + return False + agent = self._agents[agent_id] + if not agent.get("enabled", True): + return False # already disabled + agent["enabled"] = False + agent["status"] = AgentStatus.DISABLED.value + agent["updated_at"] = time.time() + logger.info("Disabled agent %s (%s)", agent_id, agent.get("name", "")) + return True + + def enable(self, agent_id: str) -> bool: + """Re-enable a disabled agent. Restores to PENDING status.""" + if agent_id not in self._agents: + return False + agent = self._agents[agent_id] + if agent.get("enabled", True): + return False # already enabled + agent["enabled"] = True + agent["status"] = AgentStatus.PENDING.value + agent["updated_at"] = time.time() + logger.info("Enabled agent %s (%s)", agent_id, agent.get("name", "")) + return True + + def count(self, include_disabled: bool = False, workspace: Optional[str] = None) -> int: + agents = self._agents.values() + if workspace is not None: + agents = [a for a in agents if a.get("workspace") == workspace] + if not include_disabled: + agents = [a for a in agents if a.get("enabled", True)] + return len(agents) + + def get_enabled(self, workspace: Optional[str] = None) -> List[Dict[str, Any]]: + """Return all enabled agents (for capability discovery), optionally scoped to workspace.""" + agents = self._agents.values() + if workspace is not None: + agents = [a for a in agents if a.get("workspace") == workspace] + return [a for a in agents if a.get("enabled", True)] # 2019-01-29T11:24:49 update diff --git a/src/api/middleware.py b/src/api/middleware.py index c20092984..80e0366ed 100644 --- a/src/api/middleware.py +++ b/src/api/middleware.py @@ -1,22 +1,91 @@ -"""API middleware components.""" +"""API middleware components with session-aware auth and workspace scoping.""" import time import logging -from typing import Callable +from typing import Callable, List, Optional + from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response +from .session import ( + SessionRole, + get_session_store, +) + logger = logging.getLogger(__name__) +# Paths that don't require authentication +PUBLIC_PATHS = { + "/api/v2/auth/token", + "/api/v2/auth/refresh", + "/health", + "/api/docs", + "/api/redoc", + "/openapi.json", +} + +# Paths requiring specific workspace roles +ROLE_REQUIRED_PATHS = { + "/api/v2/agents": [SessionRole.VIEWER, SessionRole.OPERATOR, SessionRole.ADMIN], + "/api/v2/agents/": [SessionRole.VIEWER, SessionRole.OPERATOR, SessionRole.ADMIN], +} + class AuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: Callable) -> Response: - if request.url.path.startswith("/api/v2") and request.url.path != "/api/v2/auth/token": - token = request.headers.get("Authorization", "") - if not token.startswith("Bearer "): - return Response(status_code=401, content="Unauthorized") - return await call_next(request) + path = request.url.path + + # Allow public paths + if path in PUBLIC_PATHS: + return await call_next(request) + + # Only protect /api/v2 paths + if not path.startswith("/api/v2"): + return await call_next(request) + + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return Response( + status_code=401, + content='{"error":"Unauthorized","detail":"Missing or malformed Authorization header"}', + media_type="application/json", + ) + + token = auth_header[7:] # Strip "Bearer " + store = get_session_store() + session = store.validate_access_token(token) + + if session is None: + return Response( + status_code=401, + content='{"error":"Unauthorized","detail":"Token is invalid, expired, or has been rotated out"}', + media_type="application/json", + ) + + # Check role-based access for protected paths + required_roles = self._get_required_roles(path) + if required_roles: + if not any(session.has_role(r) for r in required_roles): + return Response( + status_code=403, + content='{"error":"Forbidden","detail":"Insufficient workspace role"}', + media_type="application/json", + ) + + # Attach session info to request state for downstream handlers + request.state.session = session + response = await call_next(request) + return response + + def _get_required_roles(self, path: str) -> Optional[List[SessionRole]]: + """Get required roles for a path, matching by prefix.""" + if path in ROLE_REQUIRED_PATHS: + return ROLE_REQUIRED_PATHS[path] + for prefix, roles in ROLE_REQUIRED_PATHS.items(): + if path.startswith(prefix): + return roles + return None class RateLimitMiddleware(BaseHTTPMiddleware): @@ -49,131 +118,3 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: duration = time.time() - start logger.info(f"{request.method} {request.url.path} {response.status_code} {duration:.3f}s") return response - -# 2019-03-01T18:35:19 update - -# 2019-04-03T13:22:05 update - -# 2019-04-30T17:18:49 update - -# 2019-08-20T09:29:03 update - -# 2019-08-30T15:52:06 update - -# 2019-11-23T16:58:42 update - -# 2020-02-18T10:04:07 update - -# 2020-04-21T17:35:30 update - -# 2020-05-22T11:10:34 update - -# 2020-07-02T12:31:26 update - -# 2020-07-05T13:52:59 update - -# 2020-08-21T20:36:45 update - -# 2021-01-19T09:17:15 update - -# 2021-01-29T11:34:24 update - -# 2021-02-04T15:21:21 update - -# 2021-04-19T19:23:15 update - -# 2021-05-20T16:50:15 update - -# 2021-06-22T19:23:44 update - -# 2021-09-09T13:44:55 update - -# 2021-09-16T09:30:20 update - -# 2021-10-14T20:42:33 update - -# 2021-12-28T16:39:14 update - -# 2022-01-26T19:07:27 update - -# 2022-01-28T08:03:41 update - -# 2022-03-23T12:17:02 update - -# 2022-04-06T12:12:27 update - -# 2022-04-21T14:53:01 update - -# 2022-06-30T08:37:32 update - -# 2022-07-06T10:44:45 update - -# 2022-11-02T11:12:47 update - -# 2022-11-15T20:54:21 update - -# 2022-11-23T14:13:34 update - -# 2023-01-26T10:03:44 update - -# 2023-02-09T17:08:10 update - -# 2023-02-16T10:04:00 update - -# 2023-03-14T11:52:03 update - -# 2023-04-10T12:42:07 update - -# 2023-04-26T10:43:39 update - -# 2023-06-27T08:18:07 update - -# 2023-08-30T15:30:40 update - -# 2023-08-30T14:10:05 update - -# 2023-10-09T18:32:46 update - -# 2023-11-21T20:35:55 update - -# 2024-03-07T19:17:39 update - -# 2024-04-01T18:06:19 update - -# 2024-07-18T15:37:34 update - -# 2024-07-25T09:21:53 update - -# 2024-08-12T14:24:22 update - -# 2024-11-18T08:50:54 update - -# 2025-04-08T12:43:05 update - -# 2025-06-03T08:10:47 update - -# 2025-06-12T08:37:52 update - -# 2025-06-17T08:36:56 update - -# 2025-07-02T18:09:42 update - -# 2025-07-22T12:39:21 update - -# 2025-10-13T12:13:46 update - -# 2025-12-05T09:44:22 update - -# 2025-12-22T18:34:47 update - -# 2026-01-26T15:36:23 update - -# 2026-02-13T12:36:40 update - -# 2026-02-26T11:07:15 update - -# 2026-03-19T11:00:17 update - -# 2026-03-27T12:58:53 update - -# 2026-05-12T17:19:36 update diff --git a/src/api/routes.py b/src/api/routes.py index fae97d62f..fc9341241 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -1,193 +1,235 @@ -"""API route definitions.""" +"""API route definitions with workspace-scoped agent operations.""" -from fastapi import APIRouter, HTTPException, Depends +import time +import logging from typing import List, Dict, Optional +from fastapi import APIRouter, HTTPException, Depends, Request +from pydantic import BaseModel + from src.agent import AgentRegistry, AgentStatus +from .session import ( + SessionRole, + SessionInfo, + get_session_store, +) + +logger = logging.getLogger(__name__) router = APIRouter() registry = AgentRegistry() -@router.get("/agents") -async def list_agents(status: Optional[str] = None, group: Optional[str] = None): - status_filter = AgentStatus(status) if status else None - return {"agents": registry.list(status=status_filter, group=group)} - - -@router.post("/agents") -async def register_agent(name: str, agent_type: str, config: Optional[Dict] = None): - agent_id = registry.register(name, agent_type, config) - return {"agent_id": agent_id, "status": "registered"} - - -@router.get("/agents/{agent_id}") -async def get_agent(agent_id: str): - agent = registry.get(agent_id) - if not agent: - raise HTTPException(status_code=404, detail="Agent not found") - return agent - - -@router.delete("/agents/{agent_id}") -async def delete_agent(agent_id: str): - if not registry.delete(agent_id): - raise HTTPException(status_code=404, detail="Agent not found") - return {"status": "deleted"} - - -@router.post("/agents/{agent_id}/start") -async def start_agent(agent_id: str): - if not registry.update_status(agent_id, AgentStatus.RUNNING): - raise HTTPException(status_code=404, detail="Agent not found") - return {"status": "started"} - - -@router.post("/agents/{agent_id}/stop") -async def stop_agent(agent_id: str): - if not registry.update_status(agent_id, AgentStatus.PAUSED): - raise HTTPException(status_code=404, detail="Agent not found") - return {"status": "stopped"} - - -@router.get("/agents/count") -async def agent_count(): - return {"count": registry.count()} - -# 2019-03-18T11:10:18 update - -# 2019-04-22T13:58:05 update - -# 2019-05-28T08:52:40 update - -# 2019-06-13T19:27:11 update - -# 2019-06-25T18:52:04 update - -# 2019-06-26T17:23:40 update - -# 2019-07-24T12:38:12 update - -# 2019-08-06T17:13:22 update - -# 2019-09-26T19:27:40 update +# ── Auth Models ────────────────────────────────────────────────────────── -# 2019-11-08T15:48:07 update +class LoginRequest(BaseModel): + username: str + password: str -# 2019-12-05T16:07:01 update -# 2020-01-17T17:50:06 update +class TokenResponse(BaseModel): + access_token: str + refresh_token: str + session_id: str + expires_in: int + token_type: str = "Bearer" -# 2020-04-24T17:12:53 update -# 2020-07-21T19:32:14 update +class RefreshRequest(BaseModel): + refresh_token: str -# 2020-07-21T20:23:54 update -# 2020-08-14T20:37:18 update +# ── Auth Routes ────────────────────────────────────────────────────────── -# 2020-11-05T16:47:32 update +@router.post("/auth/token", response_model=TokenResponse) +async def login(request: LoginRequest): + """Authenticate user and issue tokens bound to a session.""" + if request.username != "admin" or request.password != "admin": + raise HTTPException(status_code=401, detail="Invalid credentials") -# 2021-03-11T12:52:51 update + roles = [SessionRole.ADMIN] + if request.username == "viewer": + roles = [SessionRole.VIEWER] -# 2021-03-15T12:40:28 update + store = get_session_store() + session = store.create_session( + user_id=request.username, + roles=roles, + client_type="browser", + ttl=3600, + ) -# 2021-03-19T19:24:45 update + return TokenResponse( + access_token=session.access_token, + refresh_token=session.refresh_token, + session_id=session.session_id, + expires_in=3600, + ) -# 2021-05-07T14:43:25 update -# 2021-05-12T12:11:05 update +@router.post("/auth/refresh", response_model=TokenResponse) +async def refresh_token(request: RefreshRequest): + """Refresh tokens with session rotation.""" + if not request.refresh_token: + raise HTTPException(status_code=400, detail="Missing refresh_token") -# 2021-05-26T19:45:39 update + store = get_session_store() + new_session = store.rotate_session(request.refresh_token, ttl=3600) -# 2021-06-29T19:14:28 update + if new_session is None: + raise HTTPException(status_code=401, detail="Invalid or expired refresh token") -# 2021-07-09T17:57:49 update + return TokenResponse( + access_token=new_session.access_token, + refresh_token=new_session.refresh_token, + session_id=new_session.session_id, + expires_in=3600, + ) -# 2021-07-19T08:20:34 update -# 2021-07-23T15:35:00 update +@router.post("/auth/logout") +async def logout(request: Request): + """Revoke the current session.""" + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing token") -# 2021-07-26T09:55:35 update + token = auth[7:] + store = get_session_store() -# 2021-11-01T20:50:23 update + session = store.validate_access_token(token) + if session: + store.revoke_session(session.session_id) + return {"status": "logged_out"} -# 2022-02-04T09:23:08 update + raise HTTPException(status_code=401, detail="Invalid token") -# 2022-02-14T15:58:17 update -# 2022-02-28T09:52:05 update +@router.get("/auth/sessions") +async def list_sessions(request: Request): + """List active sessions for the current user.""" + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing token") -# 2022-05-19T16:28:06 update + token = auth[7:] + store = get_session_store() + session = store.validate_access_token(token) + if not session: + raise HTTPException(status_code=401, detail="Invalid token") -# 2022-05-30T15:01:44 update + user_sessions = store.get_user_sessions(session.user_id) + return { + "user_id": session.user_id, + "active_sessions": [s.to_dict() for s in user_sessions if s.status.value == "active"], + } -# 2022-07-31T11:24:57 update -# 2022-08-09T15:47:57 update +# ── Agent Routes (workspace-scoped) ────────────────────────────────────── -# 2022-08-19T12:51:59 update - -# 2022-11-02T08:06:45 update - -# 2022-11-21T14:12:56 update - -# 2023-01-13T12:25:51 update - -# 2023-03-31T14:11:34 update - -# 2023-04-03T20:57:22 update - -# 2023-04-28T19:01:38 update - -# 2023-07-18T16:47:22 update - -# 2023-09-28T18:50:58 update - -# 2023-10-02T13:22:15 update - -# 2023-10-23T10:46:19 update - -# 2023-11-02T16:52:55 update - -# 2023-12-08T17:38:20 update - -# 2023-12-11T10:59:19 update - -# 2024-01-15T16:27:41 update - -# 2024-02-09T11:56:21 update - -# 2024-02-15T16:47:43 update - -# 2024-03-26T08:08:33 update - -# 2024-07-11T15:59:46 update +@router.get("/agents") +async def list_agents( + request: Request, + status: Optional[str] = None, + group: Optional[str] = None, +): + """List agents scoped to the caller's workspace.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + status_filter = AgentStatus(status) if status else None + return {"agents": registry.list(status=status_filter, group=group, workspace=ws)} -# 2024-09-04T17:13:05 update -# 2024-09-20T11:28:38 update +@router.post("/agents") +async def register_agent( + request: Request, + name: str, + agent_type: str, + config: Optional[Dict] = None, +): + """Register a new agent in the caller's workspace.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + agent_id = registry.register(name, agent_type, config, workspace=ws) + return {"agent_id": agent_id, "status": "registered"} -# 2024-12-02T16:42:53 update -# 2025-01-15T12:12:38 update +@router.get("/agents/count") +async def agent_count(request: Request): + """Count agents scoped to the caller's workspace.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + return {"count": registry.count(workspace=ws)} -# 2025-02-05T09:08:36 update -# 2025-05-16T19:40:31 update +@router.get("/agents/{agent_id}") +async def get_agent(agent_id: str, request: Request): + """Get an agent, scoped to the caller's workspace.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + agent = registry.get(agent_id, workspace=ws) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + return agent -# 2025-06-13T13:20:50 update -# 2025-08-13T12:22:26 update +@router.delete("/agents/{agent_id}") +async def delete_agent(agent_id: str, request: Request): + """Delete an agent, scoped to the caller's workspace.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + agent = registry.get(agent_id, workspace=ws) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + registry.delete(agent_id) + return {"status": "deleted"} -# 2025-09-01T12:30:44 update -# 2025-11-06T12:23:44 update +@router.post("/agents/{agent_id}/start") +async def start_agent(agent_id: str, request: Request): + """Start an agent, scoped to the caller's workspace.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + agent = registry.get(agent_id, workspace=ws) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + registry.update_status(agent_id, AgentStatus.RUNNING) + return {"status": "started"} -# 2025-12-26T08:40:45 update -# 2026-04-08T19:23:48 update +@router.post("/agents/{agent_id}/stop") +async def stop_agent(agent_id: str, request: Request): + """Stop an agent, scoped to the caller's workspace.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + agent = registry.get(agent_id, workspace=ws) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + registry.update_status(agent_id, AgentStatus.PAUSED) + return {"status": "stopped"} -# 2026-04-09T20:30:37 update -# 2026-05-13T11:36:25 update +@router.post("/agents/{agent_id}/disable") +async def disable_agent(agent_id: str, request: Request): + """Disable an agent. Removes it from capability discovery listings.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + agent = registry.get(agent_id, workspace=ws) + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + if not registry.disable(agent_id): + raise HTTPException(status_code=409, detail="Agent is already disabled") + return {"status": "disabled"} + + +@router.post("/agents/{agent_id}/enable") +async def enable_agent(agent_id: str, request: Request): + """Re-enable a disabled agent. Restores to capability discovery listings.""" + workspace = getattr(request.state, "session", None) + ws = workspace.workspace if workspace else "default" + agent = registry.get(agent_id, workspace=ws) + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + if not registry.enable(agent_id): + raise HTTPException(status_code=409, detail="Agent is not disabled") + return {"status": "enabled"} diff --git a/src/api/session.py b/src/api/session.py new file mode 100644 index 000000000..8d7320ae9 --- /dev/null +++ b/src/api/session.py @@ -0,0 +1,266 @@ +"""Session management with rotation-bound refresh tokens. + +Binds refresh tokens to session rotation so that when a session is rotated +(new refresh token issued), old refresh tokens are immediately invalidated. +Supports dashboard login (browser) and token clients (API). +""" + +import time +import secrets +import logging +from typing import Dict, List, Optional, Set, Tuple +from enum import Enum + +logger = logging.getLogger(__name__) + + +class SessionRole(str, Enum): + VIEWER = "viewer" + OPERATOR = "operator" + ADMIN = "admin" + + +class SessionStatus(str, Enum): + ACTIVE = "active" + ROTATED = "rotated" + REVOKED = "revoked" + + +class SessionInfo: + """Represents an authenticated user session with workspace scoping.""" + + def __init__( + self, + user_id: str, + session_id: str, + refresh_token: str, + access_token: str, + roles: List[SessionRole], + issued_at: float, + expires_at: float, + client_type: str = "browser", + workspace: Optional[str] = None, + ): + self.user_id = user_id + self.session_id = session_id + self.refresh_token = refresh_token + self.access_token = access_token + self.roles = roles + self.issued_at = issued_at + self.expires_at = expires_at + self.rotated_at: Optional[float] = None + self.status = SessionStatus.ACTIVE + self.client_type = client_type # "browser" or "token" + self.previous_sessions: Set[str] = set() + self.workspace = workspace or "default" + + def is_expired(self) -> bool: + return time.time() > self.expires_at + + def has_role(self, role: SessionRole) -> bool: + return role in self.roles + + def to_dict(self) -> Dict: + return { + "user_id": self.user_id, + "session_id": self.session_id, + "roles": [r.value for r in self.roles], + "workspace": self.workspace, + "issued_at": self.issued_at, + "expires_at": self.expires_at, + "rotated_at": self.rotated_at, + "status": self.status.value, + "client_type": self.client_type, + } + + +class SessionStore: + """In-memory session store with rotation tracking. + + When a session is rotated (new refresh token), all sessions for the same + user under the old rotation lineage are invalidated, fulfilling the + 'bind refresh tokens to session rotation' requirement. + """ + + def __init__(self): + # access_token -> SessionInfo + self._by_access: Dict[str, SessionInfo] = {} + # refresh_token -> SessionInfo + self._by_refresh: Dict[str, SessionInfo] = {} + # session_id -> SessionInfo + self._by_session: Dict[str, SessionInfo] = {} + # user_id -> set of session_ids + self._by_user: Dict[str, Set[str]] = {} + + def _generate_token(self) -> str: + return secrets.token_hex(32) + + def create_session( + self, + user_id: str, + roles: List[SessionRole], + client_type: str = "browser", + ttl: int = 3600, + workspace: Optional[str] = None, + ) -> SessionInfo: + """Create a new session with access + refresh tokens.""" + session_id = secrets.token_hex(16) + access_token = self._generate_token() + refresh_token = self._generate_token() + now = time.time() + + session = SessionInfo( + user_id=user_id, + session_id=session_id, + refresh_token=refresh_token, + access_token=access_token, + roles=roles, + issued_at=now, + expires_at=now + ttl, + client_type=client_type, + workspace=workspace, + ) + + self._store_session(session) + return session + + def _store_session(self, session: SessionInfo) -> None: + self._by_access[session.access_token] = session + self._by_refresh[session.refresh_token] = session + self._by_session[session.session_id] = session + if session.user_id not in self._by_user: + self._by_user[session.user_id] = set() + self._by_user[session.user_id].add(session.session_id) + + def rotate_session( + self, current_refresh_token: str, ttl: int = 3600 + ) -> Optional[SessionInfo]: + """Rotate a session: issue new tokens, invalidate old lineage. + + This is the core of 'bind refresh tokens to session rotation'. + When a refresh token is used, we rotate the session: + - Old refresh token is invalidated + - Old access token is invalidated + - New tokens are issued + - All previously-rotated-out sessions for this user are revoked + """ + session = self._by_refresh.get(current_refresh_token) + if not session: + return None + if session.status != SessionStatus.ACTIVE: + return None + if session.is_expired(): + return None + + # Invalidate old tokens + old_session_id = session.session_id + old_access_token = session.access_token + + # Mark old session as rotated + session.status = SessionStatus.ROTATED + session.rotated_at = time.time() + + # Revoke all siblings (previous sessions in the rotation chain) + # This enforces that only the most recent session is valid + for prev_sid in session.previous_sessions: + prev_session = self._by_session.get(prev_sid) + if prev_session and prev_session.status == SessionStatus.ACTIVE: + prev_session.status = SessionStatus.REVOKED + + # Create new session + new_session = self.create_session( + user_id=session.user_id, + roles=session.roles, + client_type=session.client_type, + ttl=ttl, + ) + + # Link old session to new rotation chain + new_session.previous_sessions = {old_session_id} | session.previous_sessions + + # Remove old tokens from lookup + self._by_access.pop(old_access_token, None) + self._by_refresh.pop(current_refresh_token, None) + + return new_session + + def validate_access_token(self, token: str) -> Optional[SessionInfo]: + """Validate an access token. Returns session or None.""" + session = self._by_access.get(token) + if not session: + return None + if session.status != SessionStatus.ACTIVE: + return None + if session.is_expired(): + # Clean up expired session + self.revoke_session(session.session_id) + return None + return session + + def validate_refresh_token(self, token: str) -> Optional[SessionInfo]: + """Validate a refresh token. Returns session or None.""" + session = self._by_refresh.get(token) + if not session: + return None + if session.status != SessionStatus.ACTIVE: + return None + if session.is_expired(): + return None + return session + + def revoke_session(self, session_id: str) -> bool: + """Revoke a specific session.""" + session = self._by_session.get(session_id) + if not session: + return False + session.status = SessionStatus.REVOKED + self._by_access.pop(session.access_token, None) + self._by_refresh.pop(session.refresh_token, None) + return True + + def revoke_all_user_sessions(self, user_id: str) -> int: + """Revoke all sessions for a user. Returns count revoked.""" + count = 0 + session_ids = list(self._by_user.get(user_id, set())) + for sid in session_ids: + if self.revoke_session(sid): + count += 1 + return count + + def get_user_sessions(self, user_id: str) -> List[SessionInfo]: + """Get all sessions for a user.""" + session_ids = self._by_user.get(user_id, set()) + return [ + self._by_session[sid] + for sid in session_ids + if sid in self._by_session + ] + + def cleanup_expired(self) -> int: + """Remove expired sessions. Returns count cleaned.""" + now = time.time() + expired = [ + sid + for sid, s in self._by_session.items() + if s.expires_at < now + ] + for sid in expired: + self.revoke_session(sid) + return len(expired) + + +# Global session store instance +_session_store: Optional[SessionStore] = None + + +def get_session_store() -> SessionStore: + global _session_store + if _session_store is None: + _session_store = SessionStore() + return _session_store + + +def reset_session_store() -> None: + """Reset the global session store (useful for tests).""" + global _session_store + _session_store = SessionStore() diff --git a/tests/test_agent_routes.py b/tests/test_agent_routes.py new file mode 100644 index 000000000..a96eaa934 --- /dev/null +++ b/tests/test_agent_routes.py @@ -0,0 +1,307 @@ +"""Route-level tests for workspace-scoped agent operations. + +Verifies: +- Agent listing is scoped to the caller's workspace +- Agent CRUD operations respect workspace boundaries +- Cross-workspace access returns 404 (resource not found) +- Agents are created in the caller's workspace +- The workspace filter is enforced at the service layer +""" + +import pytest +from fastapi.testclient import TestClient + +from src.api.server import create_app +from src.api.session import ( + SessionStore, + SessionRole, + SessionInfo, + get_session_store, + reset_session_store, +) +from src.api.routes import registry +from src.agent.registry import AgentRegistry + + +@pytest.fixture +def app(): + """Create a fresh app with clean state for each test.""" + reset_session_store() + # Reset registry: clear internal state + registry._agents = {} + registry._index = {} + return create_app() + + +@pytest.fixture +def client(app): + """Test client bound to the fresh app.""" + return TestClient(app) + + +def _auth_header(store: SessionStore, user_id: str = "alice", + roles: list = None, workspace: str = "ws-alpha") -> str: + """Helper: create a session and return Bearer token header.""" + if roles is None: + roles = [SessionRole.ADMIN] + session = store.create_session( + user_id=user_id, + roles=roles, + workspace=workspace, + ) + return f"Bearer {session.access_token}" + + +# ── Workspace-scoped agent listing ─────────────────────────────────────── + + +class TestListAgentsWorkspaceScope: + """list_agents must return only agents in the caller's workspace.""" + + def test_list_returns_only_own_workspace(self, client): + store = get_session_store() + + # Register agents in workspace ws-alpha and ws-beta + registry.register("agent-alpha-1", "worker", workspace="ws-alpha") + registry.register("agent-alpha-2", "worker", workspace="ws-alpha") + registry.register("agent-beta-1", "worker", workspace="ws-beta") + + headers = {"Authorization": _auth_header(store, workspace="ws-alpha")} + resp = client.get("/api/v2/agents", headers=headers) + assert resp.status_code == 200 + data = resp.json() + agent_names = {a["name"] for a in data["agents"]} + assert "agent-alpha-1" in agent_names + assert "agent-alpha-2" in agent_names + assert "agent-beta-1" not in agent_names + + def test_different_workspace_sees_no_overlap(self, client): + store = get_session_store() + + registry.register("agent-alpha-1", "worker", workspace="ws-alpha") + registry.register("agent-beta-1", "worker", workspace="ws-beta") + + headers = {"Authorization": _auth_header(store, workspace="ws-beta")} + resp = client.get("/api/v2/agents", headers=headers) + data = resp.json() + agent_names = {a["name"] for a in data["agents"]} + assert "agent-beta-1" in agent_names + assert "agent-alpha-1" not in agent_names + + def test_empty_list_for_workspace_with_no_agents(self, client): + store = get_session_store() + + registry.register("agent-alpha-1", "worker", workspace="ws-alpha") + + headers = {"Authorization": _auth_header(store, workspace="ws-empty")} + resp = client.get("/api/v2/agents", headers=headers) + data = resp.json() + assert data["agents"] == [] + + def test_default_workspace_isolation(self, client): + store = get_session_store() + + # Agents without explicit workspace get "default" + registry.register("agent-1", "worker") + registry.register("agent-2", "worker", workspace="default") + registry.register("agent-other", "worker", workspace="other-ws") + + headers = {"Authorization": _auth_header(store, workspace="default")} + resp = client.get("/api/v2/agents", headers=headers) + data = resp.json() + agent_names = {a["name"] for a in data["agents"]} + assert "agent-1" in agent_names + assert "agent-2" in agent_names + assert "agent-other" not in agent_names + + +class TestGetAgentWorkspaceScope: + """Individual agent lookup must be scoped to workspace.""" + + def test_get_agent_in_own_workspace_succeeds(self, client): + store = get_session_store() + agent_id = registry.register("my-agent", "worker", workspace="ws-alpha") + + headers = {"Authorization": _auth_header(store, workspace="ws-alpha")} + resp = client.get(f"/api/v2/agents/{agent_id}", headers=headers) + assert resp.status_code == 200 + assert resp.json()["name"] == "my-agent" + + def test_get_agent_in_different_workspace_returns_404(self, client): + store = get_session_store() + agent_id = registry.register("my-agent", "worker", workspace="ws-alpha") + + headers = {"Authorization": _auth_header(store, workspace="ws-beta")} + resp = client.get(f"/api/v2/agents/{agent_id}", headers=headers) + assert resp.status_code == 404 + assert "not found" in resp.json()["detail"].lower() + + def test_get_nonexistent_agent_returns_404(self, client): + store = get_session_store() + headers = {"Authorization": _auth_header(store, workspace="ws-alpha")} + resp = client.get("/api/v2/agents/nonexistent-id", headers=headers) + assert resp.status_code == 404 + + +class TestRegisterAgentWorkspace: + """New agents should be registered in the caller's workspace.""" + + def test_agent_registered_in_callers_workspace(self, client): + store = get_session_store() + headers = {"Authorization": _auth_header(store, workspace="ws-alpha")} + + resp = client.post( + "/api/v2/agents?name=test-agent&agent_type=worker.processor", + headers=headers, + ) + assert resp.status_code == 200 + agent_id = resp.json()["agent_id"] + + # Verify it shows up in workspace listing + list_resp = client.get("/api/v2/agents", headers=headers) + agent_ids = {a["id"] for a in list_resp.json()["agents"]} + assert agent_id in agent_ids + + def test_agent_not_visible_in_other_workspace(self, client): + store = get_session_store() + headers_alpha = {"Authorization": _auth_header(store, workspace="ws-alpha")} + headers_beta = {"Authorization": _auth_header(store, workspace="ws-beta")} + + resp = client.post( + "/api/v2/agents?name=secret-agent&agent_type=worker", + headers=headers_alpha, + ) + agent_id = resp.json()["agent_id"] + + # Should not be visible from ws-beta + beta_resp = client.get(f"/api/v2/agents/{agent_id}", headers=headers_beta) + assert beta_resp.status_code == 404 + + +class TestDeleteAgentWorkspaceScope: + """Agent deletion must be scoped to workspace.""" + + def test_delete_agent_in_own_workspace(self, client): + store = get_session_store() + agent_id = registry.register("my-agent", "worker", workspace="ws-alpha") + + headers = {"Authorization": _auth_header(store, workspace="ws-alpha")} + resp = client.delete(f"/api/v2/agents/{agent_id}", headers=headers) + assert resp.status_code == 200 + + # Verify it's gone + get_resp = client.get(f"/api/v2/agents/{agent_id}", headers=headers) + assert get_resp.status_code == 404 + + def test_delete_agent_in_different_workspace_returns_404(self, client): + store = get_session_store() + agent_id = registry.register("my-agent", "worker", workspace="ws-alpha") + + headers = {"Authorization": _auth_header(store, workspace="ws-beta")} + resp = client.delete(f"/api/v2/agents/{agent_id}", headers=headers) + assert resp.status_code == 404 + + +class TestAgentStatusOperationsWorkspaceScope: + """Agent status operations must be scoped to workspace.""" + + @pytest.mark.parametrize("endpoint", ["start", "stop", "disable", "enable"]) + def test_status_operation_own_workspace(self, client, endpoint): + store = get_session_store() + agent_id = registry.register("my-agent", "worker", workspace="ws-alpha") + + headers = {"Authorization": _auth_header(store, workspace="ws-alpha")} + resp = client.post( + f"/api/v2/agents/{agent_id}/{endpoint}", + headers=headers, + ) + assert resp.status_code in (200, 409) # 409 ok if already in that state + + @pytest.mark.parametrize("endpoint", ["start", "stop", "disable", "enable"]) + def test_status_operation_wrong_workspace_returns_404(self, client, endpoint): + store = get_session_store() + agent_id = registry.register("my-agent", "worker", workspace="ws-alpha") + + headers = {"Authorization": _auth_header(store, workspace="ws-beta")} + resp = client.post( + f"/api/v2/agents/{agent_id}/{endpoint}", + headers=headers, + ) + assert resp.status_code == 404 + + def test_disable_in_wrong_workspace_does_not_leak_state(self, client): + store = get_session_store() + agent_id = registry.register("my-agent", "worker", workspace="ws-alpha") + + # Attempt disable from wrong workspace — should 404 + headers = {"Authorization": _auth_header(store, workspace="ws-beta")} + resp = client.post( + f"/api/v2/agents/{agent_id}/disable", + headers=headers, + ) + assert resp.status_code == 404 + + # Agent should still be enabled in its own workspace + own_headers = {"Authorization": _auth_header(store, workspace="ws-alpha")} + get_resp = client.get(f"/api/v2/agents/{agent_id}", headers=own_headers) + assert get_resp.json()["enabled"] is True + + +class TestAgentCountWorkspace: + """Agent count must be scoped to workspace.""" + + def test_count_respects_workspace(self, client): + store = get_session_store() + + registry.register("a1", "worker", workspace="ws-alpha") + registry.register("a2", "worker", workspace="ws-alpha") + registry.register("b1", "worker", workspace="ws-beta") + + headers = {"Authorization": _auth_header(store, workspace="ws-alpha")} + resp = client.get("/api/v2/agents/count", headers=headers) + assert resp.json()["count"] == 2 + + headers_beta = {"Authorization": _auth_header(store, workspace="ws-beta")} + resp = client.get("/api/v2/agents/count", headers=headers_beta) + assert resp.json()["count"] == 1 + + +class TestUnauthenticatedAccess: + """Unauthenticated requests should get 401.""" + + def test_list_agents_no_auth(self, client): + resp = client.get("/api/v2/agents") + assert resp.status_code == 401 + + def test_get_agent_no_auth(self, client): + resp = client.get("/api/v2/agents/some-id") + assert resp.status_code == 401 + + def test_delete_agent_no_auth(self, client): + resp = client.delete("/api/v2/agents/some-id") + assert resp.status_code == 401 + + +class TestMalformedRequests: + """Malformed requests should return appropriate 4xx responses.""" + + def test_malformed_token(self, client): + resp = client.get( + "/api/v2/agents", + headers={"Authorization": "Bearer not-a-valid-token"}, + ) + assert resp.status_code == 401 + + def test_invalid_auth_header_format(self, client): + resp = client.get( + "/api/v2/agents", + headers={"Authorization": "Basic dXNlcjpwYXNz"}, + ) + assert resp.status_code == 401 + + def test_empty_bearer_token(self, client): + resp = client.get( + "/api/v2/agents", + headers={"Authorization": "Bearer "}, + ) + assert resp.status_code == 401