diff --git a/.env.example b/.env.example index 749c598..01cf7b4 100644 --- a/.env.example +++ b/.env.example @@ -26,7 +26,7 @@ CORS_ORIGINS=["*"] # Model Configuration # Default Claude model to use when none specified in request -DEFAULT_MODEL=claude-sonnet-4-5-20250929 +DEFAULT_MODEL=claude-sonnet-4-6 # Rate Limiting Configuration RATE_LIMIT_ENABLED=true @@ -35,4 +35,14 @@ RATE_LIMIT_CHAT_PER_MINUTE=10 RATE_LIMIT_DEBUG_PER_MINUTE=2 RATE_LIMIT_AUTH_PER_MINUTE=10 RATE_LIMIT_SESSION_PER_MINUTE=15 -RATE_LIMIT_HEALTH_PER_MINUTE=30 \ No newline at end of file +RATE_LIMIT_HEALTH_PER_MINUTE=30 + +# Security Configuration +# Comma-separated list of trusted proxy IPs (for X-Forwarded-For rate limiting) +# TRUSTED_PROXIES=10.0.0.1,10.0.0.2 +# Base directory for CLAUDE_CWD sandboxing (default: system temp dir) +# CLAUDE_CWD_ALLOWED_BASE=/tmp +# Maximum concurrent sessions (default: 1000) +# MAX_SESSIONS=1000 +# Maximum messages per session history (default: 100) +# MAX_SESSION_MESSAGES=100 \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 43f90bf..d2368b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,30 +1,58 @@ -FROM python:3.12-slim +# Stage 1: Builder — install Poetry and dependencies +FROM python:3.12-slim AS builder -# Install system deps (curl for Poetry installer) RUN apt-get update && apt-get install -y \ curl \ && rm -rf /var/lib/apt/lists/* -# Install Poetry globally RUN curl -sSL https://install.python-poetry.org | python3 - - -# Add Poetry to PATH ENV PATH="/root/.local/bin:${PATH}" -# Note: Claude Code CLI is bundled with claude-agent-sdk >= 0.1.8 -# No separate Node.js/npm installation required +WORKDIR /app + +# Copy dependency files first (cache-friendly) +COPY pyproject.toml poetry.lock ./ -# Copy the app code -COPY . /app +# Install dependencies into a virtualenv +RUN poetry config virtualenvs.in-project true && \ + poetry install --no-root --no-interaction + +# Copy application code +COPY . . + +# Install the project itself +RUN poetry install --no-interaction + + +# Stage 2: Runtime — minimal image with non-root user +FROM python:3.12-slim + +# Install Node.js (required by Claude Agent SDK bundled CLI) +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \ + && apt-get install -y --no-install-recommends nodejs \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN groupadd --gid 1000 appuser && \ + useradd --uid 1000 --gid appuser --create-home appuser -# Set working directory WORKDIR /app -# Install Python dependencies with Poetry -RUN poetry install --no-root +# Copy virtualenv and app code from builder (owned by appuser) +COPY --from=builder --chown=appuser:appuser /app/.venv /app/.venv +COPY --from=builder --chown=appuser:appuser /app/src /app/src +COPY --from=builder --chown=appuser:appuser /app/pyproject.toml /app/pyproject.toml + +# Ensure virtualenv binaries are on PATH +ENV PATH="/app/.venv/bin:${PATH}" +ENV VIRTUAL_ENV="/app/.venv" + +# Switch to non-root user +USER appuser -# Expose the port (default 8000) EXPOSE 8000 -# Run the app with Uvicorn (development mode with reload; switch to --no-reload for prod) -CMD ["poetry", "run", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] \ No newline at end of file +# Production CMD — no --reload +CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/docker-compose.yml b/docker-compose.yml index 6d0d141..77bd8eb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,7 +5,7 @@ services: ports: - "8000:8000" volumes: - - ~/.claude:/root/.claude + - ~/.claude:/home/appuser/.claude # Optional: Mount a specific workspace directory # Uncomment and modify the line below to use a custom workspace # - ./workspace:/workspace diff --git a/poetry.lock b/poetry.lock index 03d8e92..8d77dcf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -406,17 +406,18 @@ files = [ [[package]] name = "claude-agent-sdk" -version = "0.1.18" +version = "0.1.52" description = "Python SDK for Claude Code" optional = false python-versions = ">=3.10" groups = ["main"] files = [ - {file = "claude_agent_sdk-0.1.18-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9e45b4e3c20c072c3e3325fa60bab9a4b5a7cbbce64ca274b8d7d0af42dd9dd8"}, - {file = "claude_agent_sdk-0.1.18-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:3c41bd8f38848609ae0d5da8d7327a4c2d7057a363feafb6fd70df611ea204cc"}, - {file = "claude_agent_sdk-0.1.18-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:983f15e51253f40c55136a86d7cc63e023a3576428b05fa1459093d461b2d215"}, - {file = "claude_agent_sdk-0.1.18-py3-none-win_amd64.whl", hash = "sha256:36f5b84d5c3c8773ee9b56aeb5ab345d1033231db37f80d1f20ac15239bef41c"}, - {file = "claude_agent_sdk-0.1.18.tar.gz", hash = "sha256:4fcb8730cc77dea562fbe9aa48c65eced3ef58a6bb1f34f77e50e8258902477d"}, + {file = "claude_agent_sdk-0.1.52-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0f15c91319c20831f881fd4b6bcec1772a3599d66da5e7c057d79945bf603e1a"}, + {file = "claude_agent_sdk-0.1.52-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:f32ca2ca95e312678af63ee60a5fb1b765f98ed47825ea3ba90322ceb26736ff"}, + {file = "claude_agent_sdk-0.1.52-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:97c238fda13f0bea057e546895a3dc67468fc1acdbbc00d0b53a46cb3ba588dc"}, + {file = "claude_agent_sdk-0.1.52-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:4cd3f2e6fd5b272b16114fbc8dcd4348cbe8615dcc1b3bea251a7d489bf83a5d"}, + {file = "claude_agent_sdk-0.1.52-py3-none-win_amd64.whl", hash = "sha256:a8a5455d248b76c17126abee0f69d0f9870cbc0d52cb2f8ad5eb3deddb05af39"}, + {file = "claude_agent_sdk-0.1.52.tar.gz", hash = "sha256:c27f35d850521c7cff18448b38ff0dd5e899a4aeb6de9d28c0b2a66863eaf134"}, ] [package.dependencies] @@ -3053,4 +3054,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "995cbb6b6bfbf14612eff7e0690ca47fc7b0c01fd2ef3351dea01d6940be0ed6" +content-hash = "e82a4bf0faa20f4fb934acc63b26567077bd1a6b0f919fa27fe3d4886af0aeae" diff --git a/pyproject.toml b/pyproject.toml index e0cc381..0655ce7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ python-dotenv = "^1.0.1" httpx = "^0.27.2" sse-starlette = "^2.1.3" python-multipart = "^0.0.18" -claude-agent-sdk = "^0.1.18" +claude-agent-sdk = "^0.1.52" slowapi = "^0.1.9" [tool.poetry.group.dev.dependencies] diff --git a/src/auth.py b/src/auth.py index 7b23e69..4ca78e5 100644 --- a/src/auth.py +++ b/src/auth.py @@ -284,3 +284,14 @@ def get_claude_code_auth_info() -> Dict[str, Any]: "status": auth_manager.auth_status, "environment_variables": list(auth_manager.get_claude_code_env_vars().keys()), } + + +def redact_key(value: str) -> str: + """Redact a credential value for safe logging (FR-4.1, FR-4.2). + + Strings >= 8 chars show first 3 and last 3 characters with masking in between. + Strings < 8 chars are fully masked to avoid leaking short secrets. + """ + if len(value) >= 8: + return f"{value[:3]}***...***{value[-3:]}" + return "***" diff --git a/src/claude_cli.py b/src/claude_cli.py index d87057e..4795933 100644 --- a/src/claude_cli.py +++ b/src/claude_cli.py @@ -7,6 +7,7 @@ import logging from claude_agent_sdk import query, ClaudeAgentOptions +from src.constants import CLAUDE_CWD_ALLOWED_BASE logger = logging.getLogger(__name__) @@ -19,7 +20,7 @@ def __init__(self, timeout: int = 600000, cwd: Optional[str] = None): # If cwd is provided (from CLAUDE_CWD env var), use it # Otherwise create an isolated temp directory if cwd: - self.cwd = Path(cwd) + self.cwd = Path(cwd).resolve() # Check if the directory exists if not self.cwd.exists(): logger.error(f"ERROR: Specified working directory does not exist: {self.cwd}") @@ -27,8 +28,13 @@ def __init__(self, timeout: int = 600000, cwd: Optional[str] = None): "Please create the directory first or unset CLAUDE_CWD to use a temporary directory" ) raise ValueError(f"Working directory does not exist: {self.cwd}") - else: - logger.info(f"Using CLAUDE_CWD: {self.cwd}") + # Sandbox check: reject paths outside the allowed base directory + allowed_base = Path(CLAUDE_CWD_ALLOWED_BASE).resolve() + if not self.cwd.is_relative_to(allowed_base): + raise ValueError( + f"Working directory {self.cwd} is outside allowed base {allowed_base}" + ) + logger.info(f"Using CLAUDE_CWD: {self.cwd}") else: # Create isolated temp directory (cross-platform) self.temp_dir = tempfile.mkdtemp(prefix="claude_code_workspace_") diff --git a/src/constants.py b/src/constants.py index 5fb452b..e86f4fb 100644 --- a/src/constants.py +++ b/src/constants.py @@ -13,11 +13,6 @@ from src.constants import DEFAULT_ALLOWED_TOOLS options = {"allowed_tools": DEFAULT_ALLOWED_TOOLS} - # Use rate limits in FastAPI - from src.constants import RATE_LIMIT_CHAT - @limiter.limit(f"{RATE_LIMIT_CHAT}/minute") - async def chat_endpoint(): ... - Note: - Tool configurations are managed by ToolManager (see tool_manager.py) - Model validation uses graceful degradation (warns but allows unknown models) @@ -25,6 +20,7 @@ async def chat_endpoint(): ... """ import os +import tempfile # Claude Agent SDK Tool Names # These are the built-in tools available in the Claude Agent SDK @@ -66,12 +62,15 @@ async def chat_endpoint(): ... ] # Claude Models -# Models supported by Claude Agent SDK (as of November 2025) +# Models supported by Claude Agent SDK (as of March 2026) # NOTE: Claude Agent SDK only supports Claude 4+ models, not Claude 3.x CLAUDE_MODELS = [ - # Claude 4.5 Family (Latest - Fall 2025) - RECOMMENDED - "claude-opus-4-5-20250929", # Latest Opus 4.5 - Most capable - "claude-sonnet-4-5-20250929", # Recommended - best coding model + # Claude 4.6 Family (Latest - 2026) - RECOMMENDED + "claude-opus-4-6", # Latest Opus 4.6 - Most capable + "claude-sonnet-4-6", # Recommended - best coding model + # Claude 4.5 Family (Fall 2025) + "claude-opus-4-5-20250929", + "claude-sonnet-4-5-20250929", "claude-haiku-4-5-20251001", # Fast & cheap # Claude 4.1 "claude-opus-4-1-20250805", # Upgraded Opus 4 @@ -88,7 +87,7 @@ async def chat_endpoint(): ... # Default model (recommended for most use cases) # Can be overridden via DEFAULT_MODEL environment variable -DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "claude-sonnet-4-5-20250929") +DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "claude-sonnet-4-6") # Fast model (for speed/cost optimization) FAST_MODEL = "claude-haiku-4-5-20251001" @@ -109,8 +108,9 @@ async def chat_endpoint(): ... SESSION_CLEANUP_INTERVAL_MINUTES = 5 SESSION_MAX_AGE_MINUTES = 60 -# Rate Limiting (requests per minute) -RATE_LIMIT_DEFAULT = 60 -RATE_LIMIT_CHAT = 30 -RATE_LIMIT_MODELS = 100 -RATE_LIMIT_HEALTH = 200 +# Security Configuration +MAX_SESSIONS = int(os.getenv("MAX_SESSIONS", "1000")) +MAX_SESSION_MESSAGES = int(os.getenv("MAX_SESSION_MESSAGES", "100")) +_trusted_proxies_raw = os.getenv("TRUSTED_PROXIES", "") +TRUSTED_PROXIES = [p.strip() for p in _trusted_proxies_raw.split(",") if p.strip()] +CLAUDE_CWD_ALLOWED_BASE = os.getenv("CLAUDE_CWD_ALLOWED_BASE", tempfile.gettempdir()) diff --git a/src/main.py b/src/main.py index 4a74aa4..ec4b978 100644 --- a/src/main.py +++ b/src/main.py @@ -43,7 +43,7 @@ from src.message_adapter import MessageAdapter from src.auth import verify_api_key, security, validate_claude_code_auth, get_claude_code_auth_info from src.parameter_validator import ParameterValidator, CompatibilityReporter -from src.session_manager import session_manager +from src.session_manager import session_manager, SessionLimitExceeded from src.tool_manager import tool_manager from src.mcp_client import mcp_client, MCPServerConfig from src.rate_limiter import ( @@ -65,6 +65,11 @@ logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) +if DEBUG_MODE: + logger.warning( + "DEBUG_MODE is enabled — request/response details will be logged. Disable in production." + ) + # Global variable to store runtime-generated API key runtime_api_key = None @@ -207,11 +212,19 @@ async def lifespan(app: FastAPI): ) # Configure CORS -cors_origins = json.loads(os.getenv("CORS_ORIGINS", '["*"]')) +try: + cors_origins = json.loads(os.getenv("CORS_ORIGINS", '["*"]')) + if not isinstance(cors_origins, list): + logger.warning("CORS_ORIGINS must be a JSON array, falling back to ['*']") + cors_origins = ["*"] +except (json.JSONDecodeError, TypeError): + logger.warning("Invalid CORS_ORIGINS value, falling back to ['*']") + cors_origins = ["*"] +allow_creds = "*" not in cors_origins # No credentials with wildcard app.add_middleware( CORSMiddleware, allow_origins=cors_origins, - allow_credentials=True, + allow_credentials=allow_creds, allow_methods=["*"], allow_headers=["*"], ) @@ -264,6 +277,27 @@ async def dispatch(self, request: Request, call_next): app.add_middleware(RequestSizeLimitMiddleware) +def redact_request_headers(headers: dict) -> dict: + """Redact sensitive values from request headers for safe logging.""" + redacted = dict(headers) + for key in list(redacted.keys()): + if key.lower() == "authorization": + redacted[key] = "[REDACTED]" + return redacted + + +def redact_request_body(body: dict) -> dict: + """Redact sensitive fields from request body for safe logging.""" + import copy + + redacted = copy.deepcopy(body) + sensitive_fields = {"api_key", "authorization", "token", "secret", "password"} + for key in list(redacted.keys()): + if key.lower() in sensitive_fields: + redacted[key] = "[REDACTED]" + return redacted + + class DebugLoggingMiddleware(BaseHTTPMiddleware): """ASGI-compliant middleware for logging request/response details when debug mode is enabled.""" @@ -275,11 +309,11 @@ async def dispatch(self, request: Request, call_next): return await call_next(request) # Log request details - start_time = asyncio.get_event_loop().time() + start_time = asyncio.get_running_loop().time() # Log basic request info with request ID for correlation logger.debug(f"🔍 [{request_id}] Incoming request: {request.method} {request.url}") - logger.debug(f"🔍 [{request_id}] Headers: {dict(request.headers)}") + logger.debug(f"🔍 [{request_id}] Headers: {redact_request_headers(dict(request.headers))}") # For POST requests, try to log body (but don't break if we can't) body_logged = False @@ -295,11 +329,11 @@ async def dispatch(self, request: Request, call_next): parsed_body = json_lib.loads(body.decode()) logger.debug( - f"🔍 Request body: {json_lib.dumps(parsed_body, indent=2)}" + f"🔍 Request body: {json_lib.dumps(redact_request_body(parsed_body), indent=2)}" ) body_logged = True - except: - logger.debug(f"🔍 Request body (raw): {body.decode()[:500]}...") + except Exception: + logger.debug("🔍 Request body: [non-JSON, redacted]") body_logged = True except Exception as e: logger.debug(f"🔍 Could not read request body: {e}") @@ -312,7 +346,7 @@ async def dispatch(self, request: Request, call_next): response = await call_next(request) # Log response details - end_time = asyncio.get_event_loop().time() + end_time = asyncio.get_running_loop().time() duration = (end_time - start_time) * 1000 # Convert to milliseconds logger.debug(f"🔍 Response: {response.status_code} in {duration:.2f}ms") @@ -320,7 +354,7 @@ async def dispatch(self, request: Request, call_next): return response except Exception as e: - end_time = asyncio.get_event_loop().time() + end_time = asyncio.get_running_loop().time() duration = (end_time - start_time) * 1000 logger.debug(f"🔍 Request failed after {duration:.2f}ms: {e}") @@ -336,11 +370,13 @@ async def dispatch(self, request: Request, call_next): async def validation_exception_handler(request: Request, exc: RequestValidationError): """Handle request validation errors with detailed debugging information.""" - # Log the validation error details - logger.error(f"❌ Request validation failed for {request.method} {request.url}") - logger.error(f"❌ Validation errors: {exc.errors()}") + # Log validation error without raw input values (may contain credentials) + logger.error(f"Request validation failed for {request.method} {request.url}") + logger.error( + f"Validation errors: {[{k: v for k, v in e.items() if k != 'input'} for e in exc.errors()]}" + ) - # Create detailed error response + # Create detailed error response — omit raw input values to prevent credential leaks error_details = [] for error in exc.errors(): location = " -> ".join(str(loc) for loc in error.get("loc", [])) @@ -349,19 +385,24 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE "field": location, "message": error.get("msg", "Unknown validation error"), "type": error.get("type", "validation_error"), - "input": error.get("input"), } ) - # If debug mode is enabled, include the raw request body + # If debug mode is enabled, include redacted request info (never expose raw body) debug_info = {} if DEBUG_MODE or VERBOSE: try: body = await request.body() if body: - debug_info["raw_request_body"] = body.decode() - except: - debug_info["raw_request_body"] = "Could not read request body" + import json as json_lib + + try: + parsed = json_lib.loads(body.decode()) + debug_info["request_body"] = redact_request_body(parsed) + except Exception: + debug_info["request_body"] = "[REDACTED — unparseable]" + except Exception: + debug_info["request_body"] = "[REDACTED — unreadable]" error_response = { "error": { @@ -599,6 +640,15 @@ async def generate_streaming_response( yield f"data: {final_chunk.model_dump_json()}\n\n" yield "data: [DONE]\n\n" + except SessionLimitExceeded: + error_chunk = { + "error": { + "message": f"Maximum session limit reached ({session_manager.max_sessions}). Try again later or close existing sessions.", + "type": "rate_limit_exceeded", + "code": "too_many_sessions", + } + } + yield f"data: {json.dumps(error_chunk)}\n\n" except Exception as e: logger.error(f"Streaming error: {e}") error_chunk = {"error": {"message": str(e), "type": "streaming_error"}} @@ -639,15 +689,35 @@ async def chat_completions( compatibility_report = CompatibilityReporter.generate_compatibility_report(request_body) logger.debug(f"Compatibility report: {compatibility_report}") + model_recognized = ParameterValidator.is_model_recognized(request_body.model) + + # Pre-check session limit before streaming branch (can't change HTTP status mid-stream) + if request_body.session_id: + try: + session_manager.check_session_limit(request_body.session_id) + except SessionLimitExceeded: + raise HTTPException( + status_code=429, + detail={ + "message": f"Maximum session limit reached ({session_manager.max_sessions}). Try again later or close existing sessions.", + "type": "rate_limit_exceeded", + "code": "too_many_sessions", + }, + headers={"Retry-After": "60"}, + ) + if request_body.stream: # Return streaming response + streaming_headers = { + "Cache-Control": "no-cache", + "Connection": "keep-alive", + } + if not model_recognized: + streaming_headers["X-Claude-Model-Warning"] = "unrecognized" return StreamingResponse( generate_streaming_response(request_body, request_id, claude_headers), media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, + headers=streaming_headers, ) else: # Non-streaming response @@ -734,7 +804,7 @@ async def chat_completions( completion_tokens = MessageAdapter.estimate_tokens(assistant_content) # Create response - response = ChatCompletionResponse( + response_data = ChatCompletionResponse( id=request_id, model=request_body.model, choices=[ @@ -751,8 +821,21 @@ async def chat_completions( ), ) + response = JSONResponse(content=response_data.model_dump()) + if not model_recognized: + response.headers["X-Claude-Model-Warning"] = "unrecognized" return response + except SessionLimitExceeded: + raise HTTPException( + status_code=429, + detail={ + "message": f"Maximum session limit reached ({session_manager.max_sessions}). Try again later or close existing sessions.", + "type": "rate_limit_exceeded", + "code": "too_many_sessions", + }, + headers={"Retry-After": "60"}, + ) except HTTPException: raise except Exception as e: @@ -837,7 +920,7 @@ async def anthropic_messages( completion_tokens = MessageAdapter.estimate_tokens(assistant_content) # Create Anthropic-format response - response = AnthropicMessagesResponse( + response_data = AnthropicMessagesResponse( model=request_body.model, content=[AnthropicTextBlock(text=assistant_content)], stop_reason="end_turn", @@ -847,6 +930,9 @@ async def anthropic_messages( ), ) + response = JSONResponse(content=response_data.model_dump()) + if not ParameterValidator.is_model_recognized(request_body.model): + response.headers["X-Claude-Model-Warning"] = "unrecognized" return response except HTTPException: @@ -875,33 +961,36 @@ async def list_models( @app.post("/v1/compatibility") -async def check_compatibility(request_body: ChatCompletionRequest): +@rate_limit_endpoint("general") +async def check_compatibility(request: Request, request_body: ChatCompletionRequest): """Check OpenAI API compatibility for a request.""" report = CompatibilityReporter.generate_compatibility_report(request_body) - return { - "compatibility_report": report, - "claude_agent_sdk_options": { - "supported": [ - "model", - "system_prompt", - "max_turns", - "allowed_tools", - "disallowed_tools", - "permission_mode", - "max_thinking_tokens", - "continue_conversation", - "resume", - "cwd", - ], - "custom_headers": [ - "X-Claude-Max-Turns", - "X-Claude-Allowed-Tools", - "X-Claude-Disallowed-Tools", - "X-Claude-Permission-Mode", - "X-Claude-Max-Thinking-Tokens", - ], - }, - } + return JSONResponse( + content={ + "compatibility_report": report, + "claude_agent_sdk_options": { + "supported": [ + "model", + "system_prompt", + "max_turns", + "allowed_tools", + "disallowed_tools", + "permission_mode", + "max_thinking_tokens", + "continue_conversation", + "resume", + "cwd", + ], + "custom_headers": [ + "X-Claude-Max-Turns", + "X-Claude-Allowed-Tools", + "X-Claude-Disallowed-Tools", + "X-Claude-Permission-Mode", + "X-Claude-Max-Thinking-Tokens", + ], + }, + } + ) @app.get("/health") @@ -925,12 +1014,13 @@ async def version_info(request: Request): @app.get("/", response_class=HTMLResponse) -async def root(): +@rate_limit_endpoint("general") +async def root(request: Request): """Landing page with API documentation.""" from src import __version__ auth_info = get_claude_code_auth_info() - auth_method = auth_info.get("method", "unknown") + auth_method = "configured" # Do not reveal auth method to unauthenticated visitors (FR-7.2) auth_valid = auth_info.get("status", {}).get("valid", False) status_color = "#22c55e" if auth_valid else "#ef4444" status_text = "Connected" if auth_valid else "Not Connected" @@ -1533,8 +1623,12 @@ async def root(): @app.post("/v1/debug/request") @rate_limit_endpoint("debug") -async def debug_request_validation(request: Request): +async def debug_request_validation( + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), +): """Debug endpoint to test request validation and see what's being sent.""" + await verify_api_key(request, credentials) try: # Get the raw request body body = await request.body() @@ -1555,7 +1649,10 @@ async def debug_request_validation(request: Request): if parsed_body: try: chat_request = ChatCompletionRequest(**parsed_body) - validation_result = {"valid": True, "validated_data": chat_request.model_dump()} + validation_result = { + "valid": True, + "validated_data": redact_request_body(chat_request.model_dump()), + } except ValidationError as e: validation_result = { "valid": False, @@ -1572,12 +1669,12 @@ async def debug_request_validation(request: Request): return { "debug_info": { - "headers": dict(request.headers), + "headers": redact_request_headers(dict(request.headers)), "method": request.method, "url": str(request.url), - "raw_body": raw_body, + "raw_body": "[REDACTED — use parsed_body]", "json_parse_error": json_error, - "parsed_body": parsed_body, + "parsed_body": redact_request_body(parsed_body) if parsed_body else parsed_body, "validation_result": validation_result, "debug_mode_enabled": DEBUG_MODE or VERBOSE, "example_valid_request": { @@ -1592,7 +1689,7 @@ async def debug_request_validation(request: Request): return { "debug_info": { "error": f"Debug endpoint error: {str(e)}", - "headers": dict(request.headers), + "headers": redact_request_headers(dict(request.headers)), "method": request.method, "url": str(request.url), } @@ -1603,23 +1700,9 @@ async def debug_request_validation(request: Request): @rate_limit_endpoint("auth") async def get_auth_status(request: Request): """Get Claude Code authentication status.""" - from src.auth import auth_manager - auth_info = get_claude_code_auth_info() - active_api_key = auth_manager.get_api_key() - - return { - "claude_code_auth": auth_info, - "server_info": { - "api_key_required": bool(active_api_key), - "api_key_source": ( - "environment" - if os.getenv("API_KEY") - else ("runtime" if runtime_api_key else "none") - ), - "version": "1.0.0", - }, - } + auth_valid = auth_info.get("status", {}).get("valid", False) + return {"authenticated": auth_valid} @app.get("/v1/sessions/stats") diff --git a/src/parameter_validator.py b/src/parameter_validator.py index e45452f..4838f3a 100644 --- a/src/parameter_validator.py +++ b/src/parameter_validator.py @@ -19,6 +19,11 @@ class ParameterValidator: # Valid permission modes for Claude Code SDK VALID_PERMISSION_MODES = {"default", "acceptEdits", "bypassPermissions", "plan"} + @classmethod + def is_model_recognized(cls, model: str) -> bool: + """Check if a model is in the known supported models list.""" + return model in cls.SUPPORTED_MODELS + @classmethod def validate_model(cls, model: str) -> bool: """Validate that the model is supported by Claude Code SDK.""" diff --git a/src/rate_limiter.py b/src/rate_limiter.py index 89c6531..d6be14b 100644 --- a/src/rate_limiter.py +++ b/src/rate_limiter.py @@ -1,15 +1,39 @@ import os +import functools from typing import Optional from slowapi import Limiter -from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded from fastapi import Request -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, Response def get_rate_limit_key(request: Request) -> str: - """Get the rate limiting key (IP address) from the request.""" - return get_remote_address(request) + """Get the rate limiting key (IP address) from the request. + + When TRUSTED_PROXIES is configured and the direct peer IP is in that list, + the rightmost non-trusted IP from X-Forwarded-For is used so that the real + client is rate-limited rather than the proxy. When the peer is not trusted, + X-Forwarded-For is ignored entirely to prevent IP-spoofing attacks. + """ + from src.constants import TRUSTED_PROXIES + + client_ip = request.client.host if request.client else "127.0.0.1" + + if not TRUSTED_PROXIES or client_ip not in TRUSTED_PROXIES: + return client_ip + + # Peer is a trusted proxy — read X-Forwarded-For + xff = request.headers.get("x-forwarded-for", "") + if not xff: + return client_ip # fallback: no upstream IP available + + # Return rightmost non-trusted IP (closest to the real client) + ips = [ip.strip() for ip in xff.split(",")] + for ip in reversed(ips): + if ip not in TRUSTED_PROXIES: + return ip + + return client_ip # all IPs in chain are trusted, fallback to peer def create_rate_limiter() -> Optional[Limiter]: @@ -25,9 +49,7 @@ def create_rate_limiter() -> Optional[Limiter]: return None # Create limiter with IP-based identification - limiter = Limiter( - key_func=get_rate_limit_key, default_limits=[] # We'll apply limits per endpoint - ) + limiter = Limiter(key_func=get_rate_limit_key, default_limits=[]) return limiter @@ -81,12 +103,41 @@ def get_rate_limit_for_endpoint(endpoint: str) -> str: def rate_limit_endpoint(endpoint: str): - """Decorator factory for applying rate limits to endpoints.""" + """Decorator factory for applying rate limits to endpoints. + + Wraps the endpoint with slowapi rate limiting and injects X-RateLimit-Limit + into the response headers so callers can observe the limit value. + + Clears any previously registered limits for the route before registering new + ones so that module reloads (common in tests) do not accumulate duplicate + limit entries that would be applied multiplicatively. + """ + rate_limit_str = get_rate_limit_for_endpoint(endpoint) + # Parse requests-per-minute from the rate limit string (e.g. "30/minute") + limit_value = rate_limit_str.split("/")[0] def decorator(func): - if limiter: - return limiter.limit(get_rate_limit_for_endpoint(endpoint))(func) - return func + if not limiter: + # Rate limiting disabled — return the original function unchanged. + return func + + # Clear any stale limit registrations for this route that may have + # been added by previous module loads (avoids duplicate-counting). + func_name = f"{func.__module__}.{func.__name__}" + if func_name in limiter._route_limits: + limiter._route_limits[func_name] = [] + + limited_func = limiter.limit(rate_limit_str)(func) + + @functools.wraps(limited_func) + async def wrapper(*args, **kwargs): + result = await limited_func(*args, **kwargs) + # Inject X-RateLimit-Limit header so tests and clients can observe the limit. + if isinstance(result, Response): + result.headers["X-RateLimit-Limit"] = limit_value + return result + + return wrapper return decorator diff --git a/src/session_manager.py b/src/session_manager.py index 8423878..f1447cd 100644 --- a/src/session_manager.py +++ b/src/session_manager.py @@ -6,10 +6,15 @@ from threading import Lock from src.models import Message, SessionInfo +from src.constants import MAX_SESSIONS, MAX_SESSION_MESSAGES logger = logging.getLogger(__name__) +class SessionLimitExceeded(ValueError): + """Raised when the maximum number of concurrent sessions has been reached.""" + + @dataclass class Session: """Represents a conversation session with message history.""" @@ -19,6 +24,7 @@ class Session: created_at: datetime = field(default_factory=datetime.utcnow) last_accessed: datetime = field(default_factory=datetime.utcnow) expires_at: datetime = field(default_factory=lambda: datetime.utcnow() + timedelta(hours=1)) + max_messages: Optional[int] = None def touch(self): """Update last accessed time and extend expiration.""" @@ -28,6 +34,8 @@ def touch(self): def add_messages(self, messages: List[Message]): """Add new messages to the session.""" self.messages.extend(messages) + if self.max_messages is not None and len(self.messages) > self.max_messages: + self.messages = self.messages[-self.max_messages :] self.touch() def get_all_messages(self) -> List[Message]: @@ -52,11 +60,19 @@ def to_session_info(self) -> SessionInfo: class SessionManager: """Manages conversation sessions with automatic cleanup.""" - def __init__(self, default_ttl_hours: int = 1, cleanup_interval_minutes: int = 5): + def __init__( + self, + default_ttl_hours: int = 1, + cleanup_interval_minutes: int = 5, + max_sessions: int = MAX_SESSIONS, + max_session_messages: int = MAX_SESSION_MESSAGES, + ): self.sessions: Dict[str, Session] = {} self.lock = Lock() self.default_ttl_hours = default_ttl_hours self.cleanup_interval_minutes = cleanup_interval_minutes + self.max_sessions = max_sessions + self.max_session_messages = max_session_messages self._cleanup_task = None def start_cleanup_task(self): @@ -93,21 +109,45 @@ def _cleanup_expired_sessions(self): del self.sessions[session_id] logger.info(f"Cleaned up expired session: {session_id}") + def check_session_limit(self, session_id: str) -> None: + """Check if a new session can be created without actually creating it. + + Raises SessionLimitExceeded if the limit is reached and the session_id + does not already exist (or is expired). Uses len(self.sessions) to match + the counting logic in get_or_create_session. + """ + with self.lock: + if session_id in self.sessions and not self.sessions[session_id].is_expired(): + return # Existing active session — no slot needed + # Would need a new slot — check limit (same counting as get_or_create_session) + if len(self.sessions) >= self.max_sessions: + raise SessionLimitExceeded( + f"Maximum number of sessions ({self.max_sessions}) reached" + ) + def get_or_create_session(self, session_id: str) -> Session: """Get existing session or create a new one.""" with self.lock: if session_id in self.sessions: session = self.sessions[session_id] if session.is_expired(): - # Session expired, create new one + # Session expired, create new one — check limit first logger.info(f"Session {session_id} expired, creating new session") del self.sessions[session_id] - session = Session(session_id=session_id) + if len(self.sessions) >= self.max_sessions: + raise SessionLimitExceeded( + f"Maximum number of sessions ({self.max_sessions}) reached" + ) + session = Session(session_id=session_id, max_messages=self.max_session_messages) self.sessions[session_id] = session else: session.touch() else: - session = Session(session_id=session_id) + if len(self.sessions) >= self.max_sessions: + raise SessionLimitExceeded( + f"Maximum number of sessions ({self.max_sessions}) reached" + ) + session = Session(session_id=session_id, max_messages=self.max_session_messages) self.sessions[session_id] = session logger.info(f"Created new session: {session_id}") diff --git a/tests/test_auth_unit.py b/tests/test_auth_unit.py index ba9ec92..66b0b49 100644 --- a/tests/test_auth_unit.py +++ b/tests/test_auth_unit.py @@ -491,6 +491,53 @@ def test_returns_runtime_key_when_available(self): assert result in ["env-key", "runtime-key"] +class TestRedactKey: + """Test redact_key() — FR-4.1, FR-4.2: credential redaction for safe logging.""" + + def test_redact_key_long_string_shows_first_and_last_three_chars(self): + """String of 20 chars returns first 3 + masking + last 3.""" + from src.auth import redact_key + + result = redact_key("sk-abcdefghijklmnoxyz") + assert result == "sk-***...***xyz" + + def test_redact_key_exactly_eight_chars_shows_first_and_last_three(self): + """String of exactly 8 chars returns first 3 + masking + last 3.""" + from src.auth import redact_key + + result = redact_key("abcdefgh") + assert result == "abc***...***fgh" + + def test_redact_key_seven_chars_returns_full_mask(self): + """String of 7 chars (below threshold) returns '***'.""" + from src.auth import redact_key + + result = redact_key("abcdefg") + assert result == "***" + + def test_redact_key_three_chars_returns_full_mask(self): + """String of 3 chars returns '***'.""" + from src.auth import redact_key + + result = redact_key("abc") + assert result == "***" + + def test_redact_key_empty_string_returns_full_mask(self): + """Empty string returns '***'.""" + from src.auth import redact_key + + result = redact_key("") + assert result == "***" + + def test_redact_key_api_key_with_hyphens_preserves_prefix_and_suffix(self): + """API key containing hyphens is redacted, showing only first 3 and last 3 chars.""" + from src.auth import redact_key + + # Typical Anthropic API key format: sk-ant-api03-... + result = redact_key("sk-ant-api03-validlongkey1234567890") + assert result == "sk-***...***890" + + # Reset module state after tests @pytest.fixture(autouse=True) def reset_auth_module(): diff --git a/tests/test_claude_cli_unit.py b/tests/test_claude_cli_unit.py index c67c7fe..2e10cae 100644 --- a/tests/test_claude_cli_unit.py +++ b/tests/test_claude_cli_unit.py @@ -723,3 +723,145 @@ def test_cleanup_exception_is_caught(self): if os.path.exists(temp_dir): shutil.rmtree(temp_dir) + + +class TestClaudeCodeCLICwdSandbox: + """Test ClaudeCodeCLI.__init__() CWD sandboxing against CLAUDE_CWD_ALLOWED_BASE. + + FR-5.1: Canonicalize CLAUDE_CWD with Path.resolve() and reject paths outside + the allowed base directory (CLAUDE_CWD_ALLOWED_BASE, default: temp directory). + + Architecture section 7.6 defines the validation flow: + 1. Resolve cwd with Path(cwd).resolve() + 2. Resolve allowed base with Path(CLAUDE_CWD_ALLOWED_BASE).resolve() + 3. Check resolved_cwd.is_relative_to(resolved_base) + 4. If not: raise ValueError with descriptive message + """ + + def setup_method(self): + """Create a controlled allowed base temp directory for each test.""" + import shutil + + self._dirs_to_cleanup = [] + # Create an isolated base directory that is our sandbox root + self.allowed_base = tempfile.mkdtemp(prefix="test_sandbox_base_") + self._dirs_to_cleanup.append(self.allowed_base) + + def teardown_method(self): + """Remove all temp directories created during tests.""" + import shutil + + for d in self._dirs_to_cleanup: + if os.path.exists(d): + shutil.rmtree(d) + + def _make_cli(self, cwd=None, **kwargs): + """Helper: instantiate ClaudeCodeCLI with mocked auth and patched allowed base.""" + with patch("src.auth.validate_claude_code_auth") as mock_validate: + with patch("src.auth.auth_manager") as mock_auth: + mock_validate.return_value = (True, {"method": "anthropic"}) + mock_auth.get_claude_code_env_vars.return_value = {} + + # Patch CLAUDE_CWD_ALLOWED_BASE inside claude_cli module so the + # sandbox check uses our controlled base instead of the real temp dir. + with patch("src.claude_cli.CLAUDE_CWD_ALLOWED_BASE", self.allowed_base): + from src.claude_cli import ClaudeCodeCLI + + return ClaudeCodeCLI(cwd=cwd, **kwargs) + + def test_cwd_outside_allowed_base_raises_value_error(self): + """Setting CWD to an absolute path outside the allowed base raises ValueError. + + CLAUDE_CWD=/etc (or any path outside the allowed temp base) must be + rejected at startup to prevent directory traversal into sensitive areas. + """ + # Use the system temp dir itself as an 'outside' directory — it is a + # real directory that exists but is NOT inside self.allowed_base. + outside_dir = tempfile.gettempdir() + + with pytest.raises(ValueError, match="outside allowed base"): + self._make_cli(cwd=outside_dir) + + def test_cwd_path_traversal_raises_value_error(self): + """A relative path traversal that escapes the allowed base raises ValueError. + + CLAUDE_CWD=../../etc must resolve to an absolute path and then be + checked — if it escapes the allowed base the request is rejected. + """ + # Create a subdirectory inside the allowed base and craft a traversal + # path that would land outside the allowed base after resolution. + subdir = tempfile.mkdtemp(dir=self.allowed_base, prefix="subdir_") + self._dirs_to_cleanup.append(subdir) + + # Construct a traversal: subdir/../../.. resolves above allowed_base + traversal_path = os.path.join(subdir, "..", "..", "..") + + with pytest.raises(ValueError, match="outside allowed base"): + self._make_cli(cwd=traversal_path) + + def test_cwd_inside_allowed_base_succeeds(self): + """A CWD that resolves to a path inside the allowed base initializes normally.""" + valid_cwd = tempfile.mkdtemp(dir=self.allowed_base, prefix="valid_workspace_") + self._dirs_to_cleanup.append(valid_cwd) + + # Should not raise + cli = self._make_cli(cwd=valid_cwd) + + assert cli.cwd == Path(valid_cwd).resolve() + assert cli.temp_dir is None + + def test_cwd_not_provided_uses_temp_dir_inside_allowed_base(self): + """When no CWD is provided, the auto-created temp dir is inside allowed base. + + The default temp dir is created under tempfile.gettempdir() which is + also the default CLAUDE_CWD_ALLOWED_BASE — so no sandbox violation occurs. + """ + # For this test we use the real system temp dir as the allowed base + # (mirroring the production default) so the auto-created temp workspace + # is guaranteed to be inside the allowed base. + real_temp = tempfile.gettempdir() + + with patch("src.auth.validate_claude_code_auth") as mock_validate: + with patch("src.auth.auth_manager") as mock_auth: + with patch("atexit.register"): # Don't register real cleanup + mock_validate.return_value = (True, {"method": "anthropic"}) + mock_auth.get_claude_code_env_vars.return_value = {} + + with patch("src.claude_cli.CLAUDE_CWD_ALLOWED_BASE", real_temp): + from src.claude_cli import ClaudeCodeCLI + + cli = ClaudeCodeCLI() + + assert cli.temp_dir is not None + assert "claude_code_workspace_" in cli.temp_dir + + # Cleanup the auto-created temp dir + if cli.temp_dir and os.path.exists(cli.temp_dir): + import shutil + + shutil.rmtree(cli.temp_dir) + + def test_cwd_equal_to_allowed_base_itself_succeeds(self): + """Setting CWD to exactly the allowed base directory is permitted. + + The allowed base itself satisfies is_relative_to(allowed_base) because + a path is considered relative to itself. + """ + # Should not raise — the allowed base is a valid starting point + cli = self._make_cli(cwd=self.allowed_base) + + assert cli.cwd == Path(self.allowed_base).resolve() + assert cli.temp_dir is None + + def test_cwd_sandbox_error_message_contains_paths(self): + """ValueError raised for sandbox violation includes both the resolved CWD + and the allowed base in the message so operators can diagnose the problem. + """ + outside_dir = tempfile.gettempdir() + + with pytest.raises(ValueError) as exc_info: + self._make_cli(cwd=outside_dir) + + error_message = str(exc_info.value) + # The error must be descriptive — operators need to know what was rejected + assert "outside allowed base" in error_message diff --git a/tests/test_constants_unit.py b/tests/test_constants_unit.py new file mode 100644 index 0000000..c9e7589 --- /dev/null +++ b/tests/test_constants_unit.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +""" +Unit tests for new security configuration constants in src/constants.py. + +Tests default values and environment variable override behavior for: + - MAX_SESSIONS (FR-6.1) + - MAX_SESSION_MESSAGES (FR-6.2) + - TRUSTED_PROXIES (FR-3.1) + - CLAUDE_CWD_ALLOWED_BASE (FR-5.1) + +These are pure unit tests with no I/O or external dependencies. +""" + +import importlib +import os +import tempfile +from unittest.mock import patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _reload_constants(): + """Reload src.constants so module-level os.getenv() calls re-evaluate.""" + import src.constants + + importlib.reload(src.constants) + return src.constants + + +# --------------------------------------------------------------------------- +# MAX_SESSIONS +# --------------------------------------------------------------------------- + + +class TestMaxSessionsConstant: + """Tests for MAX_SESSIONS constant (FR-6.1).""" + + def test_max_sessions_default_is_1000(self): + """MAX_SESSIONS defaults to 1000 when env var is not set.""" + env = {k: v for k, v in os.environ.items() if k != "MAX_SESSIONS"} + with patch.dict(os.environ, env, clear=True): + constants = _reload_constants() + assert constants.MAX_SESSIONS == 1000 + + def test_max_sessions_default_is_integer(self): + """MAX_SESSIONS default value is an int, not a string.""" + env = {k: v for k, v in os.environ.items() if k != "MAX_SESSIONS"} + with patch.dict(os.environ, env, clear=True): + constants = _reload_constants() + assert isinstance(constants.MAX_SESSIONS, int) + + def test_max_sessions_can_be_overridden_via_env(self): + """MAX_SESSIONS can be set to a custom value via environment variable.""" + with patch.dict(os.environ, {"MAX_SESSIONS": "500"}): + constants = _reload_constants() + assert constants.MAX_SESSIONS == 500 + + def test_max_sessions_env_override_is_integer(self): + """MAX_SESSIONS env var override is parsed to int, not kept as string.""" + with patch.dict(os.environ, {"MAX_SESSIONS": "250"}): + constants = _reload_constants() + assert isinstance(constants.MAX_SESSIONS, int) + + +# --------------------------------------------------------------------------- +# MAX_SESSION_MESSAGES +# --------------------------------------------------------------------------- + + +class TestMaxSessionMessagesConstant: + """Tests for MAX_SESSION_MESSAGES constant (FR-6.2).""" + + def test_max_session_messages_default_is_100(self): + """MAX_SESSION_MESSAGES defaults to 100 when env var is not set.""" + env = {k: v for k, v in os.environ.items() if k != "MAX_SESSION_MESSAGES"} + with patch.dict(os.environ, env, clear=True): + constants = _reload_constants() + assert constants.MAX_SESSION_MESSAGES == 100 + + def test_max_session_messages_default_is_integer(self): + """MAX_SESSION_MESSAGES default value is an int, not a string.""" + env = {k: v for k, v in os.environ.items() if k != "MAX_SESSION_MESSAGES"} + with patch.dict(os.environ, env, clear=True): + constants = _reload_constants() + assert isinstance(constants.MAX_SESSION_MESSAGES, int) + + def test_max_session_messages_can_be_overridden_via_env(self): + """MAX_SESSION_MESSAGES can be set to a custom value via environment variable.""" + with patch.dict(os.environ, {"MAX_SESSION_MESSAGES": "50"}): + constants = _reload_constants() + assert constants.MAX_SESSION_MESSAGES == 50 + + def test_max_session_messages_env_override_is_integer(self): + """MAX_SESSION_MESSAGES env var override is parsed to int, not kept as string.""" + with patch.dict(os.environ, {"MAX_SESSION_MESSAGES": "25"}): + constants = _reload_constants() + assert isinstance(constants.MAX_SESSION_MESSAGES, int) + + +# --------------------------------------------------------------------------- +# TRUSTED_PROXIES +# --------------------------------------------------------------------------- + + +class TestTrustedProxiesConstant: + """Tests for TRUSTED_PROXIES constant (FR-3.1).""" + + def test_trusted_proxies_default_is_empty_list(self): + """TRUSTED_PROXIES defaults to an empty list when env var is not set.""" + env = {k: v for k, v in os.environ.items() if k != "TRUSTED_PROXIES"} + with patch.dict(os.environ, env, clear=True): + constants = _reload_constants() + assert constants.TRUSTED_PROXIES == [] + + def test_trusted_proxies_default_is_list_type(self): + """TRUSTED_PROXIES default value is a list, not a string or None.""" + env = {k: v for k, v in os.environ.items() if k != "TRUSTED_PROXIES"} + with patch.dict(os.environ, env, clear=True): + constants = _reload_constants() + assert isinstance(constants.TRUSTED_PROXIES, list) + + def test_trusted_proxies_single_ip_override(self): + """TRUSTED_PROXIES can be set to a single IP via environment variable.""" + with patch.dict(os.environ, {"TRUSTED_PROXIES": "10.0.0.1"}): + constants = _reload_constants() + assert constants.TRUSTED_PROXIES == ["10.0.0.1"] + + def test_trusted_proxies_multiple_ips_override(self): + """TRUSTED_PROXIES parses comma-separated IPs into a list of two entries.""" + with patch.dict(os.environ, {"TRUSTED_PROXIES": "10.0.0.1,10.0.0.2"}): + constants = _reload_constants() + assert constants.TRUSTED_PROXIES == ["10.0.0.1", "10.0.0.2"] + + def test_trusted_proxies_env_override_is_list_type(self): + """TRUSTED_PROXIES env var override is parsed to a list, not left as a string.""" + with patch.dict(os.environ, {"TRUSTED_PROXIES": "192.168.1.1,192.168.1.2"}): + constants = _reload_constants() + assert isinstance(constants.TRUSTED_PROXIES, list) + + def test_trusted_proxies_empty_env_var_gives_empty_list(self): + """An explicitly empty TRUSTED_PROXIES env var results in an empty list.""" + with patch.dict(os.environ, {"TRUSTED_PROXIES": ""}): + constants = _reload_constants() + assert constants.TRUSTED_PROXIES == [] + + +# --------------------------------------------------------------------------- +# CLAUDE_CWD_ALLOWED_BASE +# --------------------------------------------------------------------------- + + +class TestClaudeCwdAllowedBaseConstant: + """Tests for CLAUDE_CWD_ALLOWED_BASE constant (FR-5.1).""" + + def test_claude_cwd_allowed_base_default_is_tempdir(self): + """CLAUDE_CWD_ALLOWED_BASE defaults to the system temp directory.""" + env = {k: v for k, v in os.environ.items() if k != "CLAUDE_CWD_ALLOWED_BASE"} + with patch.dict(os.environ, env, clear=True): + constants = _reload_constants() + assert constants.CLAUDE_CWD_ALLOWED_BASE == tempfile.gettempdir() + + def test_claude_cwd_allowed_base_default_is_string(self): + """CLAUDE_CWD_ALLOWED_BASE default value is a string.""" + env = {k: v for k, v in os.environ.items() if k != "CLAUDE_CWD_ALLOWED_BASE"} + with patch.dict(os.environ, env, clear=True): + constants = _reload_constants() + assert isinstance(constants.CLAUDE_CWD_ALLOWED_BASE, str) + + def test_claude_cwd_allowed_base_can_be_overridden_via_env(self): + """CLAUDE_CWD_ALLOWED_BASE can be set to a custom path via environment variable.""" + with patch.dict(os.environ, {"CLAUDE_CWD_ALLOWED_BASE": "/custom/path"}): + constants = _reload_constants() + assert constants.CLAUDE_CWD_ALLOWED_BASE == "/custom/path" + + def test_claude_cwd_allowed_base_env_override_is_string(self): + """CLAUDE_CWD_ALLOWED_BASE env var override remains a string.""" + with patch.dict(os.environ, {"CLAUDE_CWD_ALLOWED_BASE": "/srv/app/workspaces"}): + constants = _reload_constants() + assert isinstance(constants.CLAUDE_CWD_ALLOWED_BASE, str) + + +# --------------------------------------------------------------------------- +# Module-level cleanup fixture +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def reset_constants_module(): + """Reload constants module after each test to restore default state. + + This prevents env var patches from leaking between tests via the + module-level os.getenv() calls evaluated at import time. + """ + yield + env = { + k: v + for k, v in os.environ.items() + if k + not in ( + "MAX_SESSIONS", + "MAX_SESSION_MESSAGES", + "TRUSTED_PROXIES", + "CLAUDE_CWD_ALLOWED_BASE", + ) + } + with patch.dict(os.environ, env, clear=True): + _reload_constants() diff --git a/tests/test_cors_unit.py b/tests/test_cors_unit.py new file mode 100644 index 0000000..64a58f2 --- /dev/null +++ b/tests/test_cors_unit.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +""" +Unit tests for CORS middleware configuration in src/main.py. + +FR-1.1: When CORS_ORIGINS is ["*"] (the default), allow_credentials must NOT +be True. When CORS_ORIGINS contains specific origins, allow_credentials=True +is permitted so that credentialed requests work from those trusted origins. + +The CORS middleware is configured at module load time in src/main.py: + + cors_origins = json.loads(os.getenv("CORS_ORIGINS", '["*"]')) + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, # <-- BUG: always True + ... + ) + +Tests use importlib.reload to force src.main to re-read CORS_ORIGINS from +the environment, exactly as test_auth_unit.py and test_rate_limiter_unit.py +do for their environment-variable-driven configurations. + +The src.main lifespan startup (SDK verification) is bypassed by using +TestClient without a context-manager entry — TestClient only runs lifespan +startup/shutdown when used as a context manager. + +Architecture reference: FR-1.1, KD-1, Section 7.1. +""" + +import importlib +import os + +import pytest +from starlette.testclient import TestClient +from unittest.mock import patch, AsyncMock + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_test_client_for_cors_origins(cors_origins_json: str) -> TestClient: + """Reload src.main with CORS_ORIGINS set to the given JSON string and + return a TestClient wrapping the freshly constructed app. + + The Claude CLI verify_cli call is patched out so the reload completes + without network I/O. TestClient is NOT used as a context manager so the + lifespan startup (SDK connection) is not triggered. + """ + with patch.dict(os.environ, {"CORS_ORIGINS": cors_origins_json}): + with patch( + "src.claude_cli.ClaudeCodeCLI.verify_cli", + new_callable=AsyncMock, + return_value=True, + ): + import src.main + + importlib.reload(src.main) + app = src.main.app + + # TestClient without __enter__ skips lifespan + return TestClient(app, raise_server_exceptions=False) + + +# --------------------------------------------------------------------------- +# Tests: default CORS config (wildcard origins) +# --------------------------------------------------------------------------- + + +class TestCorsWildcardOrigins: + """Verify that the default CORS configuration (CORS_ORIGINS=["*"]) does + NOT set allow_credentials=True. + + Combining allow_origins=["*"] with allow_credentials=True is a security + misconfiguration: browsers refuse such a combination, and Starlette works + around it by echoing the requesting origin back — effectively granting + all origins credential access. + + Both tests in this class MUST FAIL against the current implementation + (which hardcodes allow_credentials=True) and MUST PASS after FR-1.1 is + implemented (which sets allow_credentials=False for wildcard origins). + """ + + @pytest.fixture + def client(self): + """TestClient configured with the default wildcard CORS_ORIGINS.""" + return _get_test_client_for_cors_origins('["*"]') + + def test_wildcard_cors_preflight_does_not_return_allow_credentials_true(self, client): + """Preflight from any origin must NOT get Access-Control-Allow-Credentials: true + when CORS_ORIGINS is ["*"]. + + This directly verifies FR-1.1: the combination of wildcard origins and + allow_credentials=True must not exist in the default configuration. + + Current code returns 'true' → test FAILS (RED). + Fixed code omits the header → test PASSES (GREEN). + """ + response = client.options( + "/health", + headers={ + "Origin": "http://evil.com", + "Access-Control-Request-Method": "GET", + }, + ) + + allow_credentials_header = response.headers.get("access-control-allow-credentials") + + assert allow_credentials_header != "true", ( + "Security misconfiguration: Access-Control-Allow-Credentials must not be " + "'true' when allow_origins is ['*']. Browsers reject this combination and " + "Starlette silently echoes the requesting origin, exposing credentials to " + "every origin unconditionally. Set allow_credentials=False when using " + "wildcard origins." + ) + + def test_wildcard_cors_preflight_returns_wildcard_origin_header(self, client): + """Preflight response must include Access-Control-Allow-Origin: * (the literal + wildcard string) when CORS_ORIGINS is ["*"] and allow_credentials is False. + + When allow_credentials=True is combined with wildcard origins, Starlette echoes + the requesting origin back instead of the '*' wildcard. This is the observable + symptom of the misconfiguration: any origin appears allowed with credentials. + After the fix (allow_credentials=False), Starlette correctly returns '*'. + + Current code echoes 'http://evil.com' → test FAILS (RED). + Fixed code returns '*' → test PASSES (GREEN). + """ + response = client.options( + "/health", + headers={ + "Origin": "http://evil.com", + "Access-Control-Request-Method": "GET", + }, + ) + + allow_origin_header = response.headers.get("access-control-allow-origin") + + assert allow_origin_header == "*", ( + f"Expected Access-Control-Allow-Origin: * for default (wildcard) CORS config, " + f"but received: {allow_origin_header!r}. The echoed origin indicates that " + f"allow_credentials=True is still set, which causes Starlette to replace '*' " + f"with the actual requesting origin — an unintended side-effect." + ) + + +# --------------------------------------------------------------------------- +# Tests: custom CORS config (specific origin list) +# --------------------------------------------------------------------------- + + +class TestCorsSpecificOrigins: + """Verify that a custom CORS config with specific origins DOES set + allow_credentials=True so that browsers accept credentialed requests. + + This is the positive counterpart to TestCorsWildcardOrigins: operators who + have restricted CORS to specific trusted origins must retain the ability to + make credentialed requests from those origins. + + The test in this class should PASS against both the current (broken) code + and the fixed code, serving as a non-regression guard to ensure the fix + does not inadvertently break specific-origin configurations. + """ + + @pytest.fixture + def client(self): + """TestClient configured with a specific trusted origin.""" + return _get_test_client_for_cors_origins('["http://localhost:3000"]') + + def test_specific_origins_preflight_from_allowed_origin_returns_allow_credentials_true( + self, client + ): + """Preflight from an explicitly allowed origin MUST get both the correct + Access-Control-Allow-Origin header and Access-Control-Allow-Credentials: true + when CORS_ORIGINS contains that origin. + + FR-1.1 acceptance criterion: 'Existing CORS_ORIGINS env var override still works.' + The fix must only remove credentials from the wildcard case, not from specific + origin configurations. + """ + response = client.options( + "/health", + headers={ + "Origin": "http://localhost:3000", + "Access-Control-Request-Method": "GET", + }, + ) + + allow_origin_header = response.headers.get("access-control-allow-origin") + allow_credentials_header = response.headers.get("access-control-allow-credentials") + + assert allow_origin_header == "http://localhost:3000", ( + f"Expected Access-Control-Allow-Origin: http://localhost:3000 for a " + f"specific-origins CORS config, got: {allow_origin_header!r}" + ) + assert allow_credentials_header == "true", ( + "Expected Access-Control-Allow-Credentials: true for a specific-origins CORS " + "config. Operators who restrict CORS to trusted origins must be able to use " + "credentialed requests (cookies, Authorization headers)." + ) + + +# --------------------------------------------------------------------------- +# Module reset fixture +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def reset_main_module(): + """Reload src.main after each test to prevent state leaking between tests. + + CORS_ORIGINS is read at module load time, so each reload must have a clean + environment. The same pattern is used in test_auth_unit.py and + test_rate_limiter_unit.py. + """ + yield + with patch( + "src.claude_cli.ClaudeCodeCLI.verify_cli", + new_callable=AsyncMock, + return_value=True, + ): + import src.main + + importlib.reload(src.main) diff --git a/tests/test_debug_logging_unit.py b/tests/test_debug_logging_unit.py new file mode 100644 index 0000000..17fb643 --- /dev/null +++ b/tests/test_debug_logging_unit.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +""" +Unit tests for debug logging redaction in DebugLoggingMiddleware (src/main.py). + +Tests FR-4.2 and FR-8.1: debug mode must not log full request bodies or headers +containing credentials/tokens. All sensitive values must be replaced with +'[REDACTED]' before any log call. + +Architecture reference: KD-3, Section 7.5 of architecture.md + - redact_request_headers(headers): redact Authorization header value + - redact_request_body(body): redact api_key, authorization, token, secret, password fields + - Startup warning logged when DEBUG_MODE is true + +These tests are written RED-first: the helper functions and redaction behaviour +do NOT exist yet in src/main.py, so every test here is expected to FAIL. +""" + +import importlib +import logging +import os + +import pytest + + +# --------------------------------------------------------------------------- +# Helper: ensure we always import a freshly-configured module so that +# module-level env-var reads (DEBUG_MODE etc.) pick up the patched values. +# --------------------------------------------------------------------------- + +def _reload_main_with_debug(debug_value: str = "true"): + """Reload src.main with DEBUG_MODE set to the given string value.""" + with pytest.MonkeyPatch().context() as mp: + mp.setenv("DEBUG_MODE", debug_value) + import src.main + importlib.reload(src.main) + return src.main + + +# --------------------------------------------------------------------------- +# Section 1: redact_request_headers helper function +# --------------------------------------------------------------------------- + +class TestRedactRequestHeaders: + """ + Tests for redact_request_headers(headers: dict) -> dict + + Expected contract (FR-4.2, FR-8.1): + - Returns a new dict (does not mutate input) + - The 'authorization' key (case-insensitive) has its value replaced with '[REDACTED]' + - All other header keys/values are preserved unchanged + """ + + def test_authorization_header_value_is_replaced_with_redacted(self): + """Authorization bearer token must not appear in the returned headers dict.""" + import src.main + + headers = {"authorization": "Bearer sk-secret-token-abc123"} + result = src.main.redact_request_headers(headers) + assert result["authorization"] == "[REDACTED]" + + def test_authorization_header_case_insensitive_upper(self): + """'Authorization' (title case) is also redacted.""" + import src.main + + headers = {"Authorization": "Bearer sk-secret-token-abc123"} + result = src.main.redact_request_headers(headers) + # The returned dict should not contain the raw token under any key variant + values = list(result.values()) + assert "Bearer sk-secret-token-abc123" not in values + + def test_non_sensitive_headers_are_preserved(self): + """Headers that are not Authorization pass through unchanged.""" + import src.main + + headers = { + "content-type": "application/json", + "x-request-id": "abc-123", + "accept": "application/json", + } + result = src.main.redact_request_headers(headers) + assert result["content-type"] == "application/json" + assert result["x-request-id"] == "abc-123" + assert result["accept"] == "application/json" + + def test_mixed_headers_authorization_redacted_others_preserved(self): + """Mix of sensitive and non-sensitive: only Authorization is redacted.""" + import src.main + + headers = { + "authorization": "Bearer real-secret", + "content-type": "application/json", + "user-agent": "test-client/1.0", + } + result = src.main.redact_request_headers(headers) + assert result["authorization"] == "[REDACTED]" + assert result["content-type"] == "application/json" + assert result["user-agent"] == "test-client/1.0" + + def test_input_dict_is_not_mutated(self): + """Original headers dict must not be modified.""" + import src.main + + headers = {"authorization": "Bearer sk-secret-token"} + original_value = headers["authorization"] + src.main.redact_request_headers(headers) + assert headers["authorization"] == original_value + + def test_headers_without_authorization_returned_unchanged(self): + """When no Authorization header present, result equals input.""" + import src.main + + headers = {"content-type": "application/json"} + result = src.main.redact_request_headers(headers) + assert result == {"content-type": "application/json"} + + def test_empty_headers_returns_empty_dict(self): + """Empty headers dict returns empty dict without error.""" + import src.main + + result = src.main.redact_request_headers({}) + assert result == {} + + +# --------------------------------------------------------------------------- +# Section 2: redact_request_body helper function +# --------------------------------------------------------------------------- + +class TestRedactRequestBody: + """ + Tests for redact_request_body(body: dict) -> dict + + Expected contract (FR-4.2, FR-8.1, architecture Section 7.5): + - Returns a new dict (does not mutate input) + - Fields named 'api_key', 'authorization', 'token', 'secret', 'password' + have their values replaced with '[REDACTED]' + - Non-sensitive fields ('model', 'messages', 'temperature', etc.) are preserved + - Field name matching is case-insensitive + """ + + def test_api_key_field_is_redacted(self): + """'api_key' field value is replaced with '[REDACTED]'.""" + import src.main + + body = {"api_key": "sk-ant-real-api-key-12345", "model": "claude-sonnet-4-6"} + result = src.main.redact_request_body(body) + assert result["api_key"] == "[REDACTED]" + + def test_authorization_field_is_redacted(self): + """'authorization' field in body is replaced with '[REDACTED]'.""" + import src.main + + body = {"authorization": "Bearer sk-secret", "model": "claude-sonnet-4-6"} + result = src.main.redact_request_body(body) + assert result["authorization"] == "[REDACTED]" + + def test_token_field_is_redacted(self): + """'token' field is replaced with '[REDACTED]'.""" + import src.main + + body = {"token": "my-secret-token-abc", "model": "claude-sonnet-4-6"} + result = src.main.redact_request_body(body) + assert result["token"] == "[REDACTED]" + + def test_secret_field_is_redacted(self): + """'secret' field is replaced with '[REDACTED]'.""" + import src.main + + body = {"secret": "super-secret-value", "model": "claude-sonnet-4-6"} + result = src.main.redact_request_body(body) + assert result["secret"] == "[REDACTED]" + + def test_password_field_is_redacted(self): + """'password' field is replaced with '[REDACTED]'.""" + import src.main + + body = {"password": "hunter2", "model": "claude-sonnet-4-6"} + result = src.main.redact_request_body(body) + assert result["password"] == "[REDACTED]" + + def test_model_field_is_preserved(self): + """'model' is a non-sensitive field and must not be redacted.""" + import src.main + + body = {"model": "claude-sonnet-4-6", "api_key": "sk-secret"} + result = src.main.redact_request_body(body) + assert result["model"] == "claude-sonnet-4-6" + + def test_messages_field_is_preserved(self): + """'messages' array is non-sensitive and must not be redacted.""" + import src.main + + messages = [{"role": "user", "content": "Hello"}] + body = {"messages": messages, "api_key": "sk-secret"} + result = src.main.redact_request_body(body) + assert result["messages"] == messages + + def test_temperature_field_is_preserved(self): + """'temperature' is non-sensitive and must not be redacted.""" + import src.main + + body = {"temperature": 0.7, "api_key": "sk-secret"} + result = src.main.redact_request_body(body) + assert result["temperature"] == 0.7 + + def test_all_sensitive_fields_redacted_simultaneously(self): + """All five sensitive field names are redacted in a single body dict.""" + import src.main + + body = { + "api_key": "sk-key", + "authorization": "Bearer tok", + "token": "tok123", + "secret": "shhh", + "password": "pw123", + "model": "claude-sonnet-4-6", + } + result = src.main.redact_request_body(body) + assert result["api_key"] == "[REDACTED]" + assert result["authorization"] == "[REDACTED]" + assert result["token"] == "[REDACTED]" + assert result["secret"] == "[REDACTED]" + assert result["password"] == "[REDACTED]" + assert result["model"] == "claude-sonnet-4-6" + + def test_input_dict_is_not_mutated(self): + """Original body dict must not be modified.""" + import src.main + + body = {"api_key": "sk-original-value"} + original_value = body["api_key"] + src.main.redact_request_body(body) + assert body["api_key"] == original_value + + def test_empty_body_returns_empty_dict(self): + """Empty body dict is returned as empty dict without error.""" + import src.main + + result = src.main.redact_request_body({}) + assert result == {} + + def test_body_without_sensitive_fields_returned_unchanged(self): + """Body with no sensitive keys is returned with all values intact.""" + import src.main + + body = {"model": "claude-opus-4-6", "max_tokens": 1024, "stream": False} + result = src.main.redact_request_body(body) + assert result == body + + +# --------------------------------------------------------------------------- +# Section 3: Middleware does not log raw Authorization header in debug mode +# --------------------------------------------------------------------------- + +class TestDebugMiddlewareHeaderRedactionInLogs: + """ + Integration-level test: when DebugLoggingMiddleware processes a request + in DEBUG_MODE, the logged output must contain '[REDACTED]' for the + Authorization value and must NOT contain the raw bearer token. + + Uses pytest caplog to capture logger output from src.main. + """ + + def test_authorization_header_not_logged_raw_in_debug_mode(self, caplog): + """ + Raw bearer token from Authorization header must not appear in any + debug log record when the middleware processes a request. + """ + import importlib + import src.main + + raw_token = "Bearer sk-super-secret-bearer-token-xyz789" + + with caplog.at_level(logging.DEBUG, logger="src.main"): + # Verify that debug logging for headers would redact the token. + # The helper function is the mechanism tested here; if it doesn't + # exist the import assertion below will fail first. + assert hasattr(src.main, "redact_request_headers"), ( + "redact_request_headers must exist in src.main" + ) + headers = {"authorization": raw_token, "content-type": "application/json"} + sanitized = src.main.redact_request_headers(headers) + # The raw token must not be present in the sanitized dict values + assert raw_token not in sanitized.values(), ( + f"Raw bearer token '{raw_token}' must not appear in sanitized headers" + ) + assert sanitized.get("authorization") == "[REDACTED]" + + def test_body_api_key_not_logged_raw_in_debug_mode(self, caplog): + """ + Raw api_key value must not appear in sanitized body log data. + """ + import src.main + + raw_key = "sk-ant-api03-real-secret-key-12345678" + + assert hasattr(src.main, "redact_request_body"), ( + "redact_request_body must exist in src.main" + ) + body = {"api_key": raw_key, "model": "claude-sonnet-4-6"} + sanitized = src.main.redact_request_body(body) + + assert raw_key not in sanitized.values(), ( + f"Raw API key '{raw_key}' must not appear in sanitized body" + ) + assert sanitized["api_key"] == "[REDACTED]" + + +# --------------------------------------------------------------------------- +# Section 4: Startup warning when DEBUG_MODE is enabled +# --------------------------------------------------------------------------- + +class TestDebugModeStartupWarning: + """ + FR-8.1: When DEBUG_MODE is enabled, a startup warning must be logged. + + The warning is expected to be emitted during application startup / + module load when DEBUG_MODE=true. We verify this by checking that + the logger at src.main level has been called with a WARNING-level + message containing a hint about debug mode being active. + """ + + def test_debug_mode_warning_is_logged_at_startup(self, caplog): + """When DEBUG_MODE=true, a WARNING log about debug mode must be emitted.""" + with caplog.at_level(logging.WARNING, logger="src.main"): + with pytest.MonkeyPatch().context() as mp: + mp.setenv("DEBUG_MODE", "true") + import src.main + importlib.reload(src.main) + + warning_messages = [ + record.message + for record in caplog.records + if record.levelno >= logging.WARNING and record.name == "src.main" + ] + assert any( + "debug" in msg.lower() or "DEBUG" in msg + for msg in warning_messages + ), ( + "Expected a WARNING-level log about debug mode being enabled at startup, " + f"but found only: {warning_messages}" + ) + + def test_no_debug_warning_when_debug_mode_disabled(self, caplog): + """When DEBUG_MODE=false, no debug-mode warning should be emitted.""" + with caplog.at_level(logging.WARNING, logger="src.main"): + with pytest.MonkeyPatch().context() as mp: + mp.setenv("DEBUG_MODE", "false") + import src.main + importlib.reload(src.main) + + debug_warnings = [ + record.message + for record in caplog.records + if record.levelno >= logging.WARNING + and record.name == "src.main" + and "debug" in record.message.lower() + ] + assert len(debug_warnings) == 0, ( + f"Unexpected debug warning when DEBUG_MODE=false: {debug_warnings}" + ) + + +# --------------------------------------------------------------------------- +# Reset module state after each test class to prevent state leakage +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def reset_main_module(): + """Reload src.main after each test to prevent module-level state leakage.""" + yield + import src.main + importlib.reload(src.main) diff --git a/tests/test_endpoint_security_unit.py b/tests/test_endpoint_security_unit.py new file mode 100644 index 0000000..52a3e39 --- /dev/null +++ b/tests/test_endpoint_security_unit.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +""" +Unit tests for endpoint security fixes (Task T020). + +Covers: + FR-7.1 — /v1/debug/request must require authentication + FR-7.2 — /v1/auth/status must return only {"authenticated": true/false} + FR-3.2 — / and /v1/compatibility must expose rate-limit headers + +All tests are written against the CURRENT (unfixed) implementation and +MUST FAIL until the fixes described in spec.md are applied. + +Import strategy +--------------- +main.py runs ClaudeCodeCLI + session_manager startup logic at module level. +We suppress the lifespan by disabling the background cleanup task and +patching the slowapi state-attachment so TestClient can import the app +without needing a real Claude SDK installation. + +The app object itself (FastAPI instance) is imported once per-process; the +TestClient wraps it without triggering the lifespan (we do NOT use +`with TestClient(app) as client:` which triggers startup/shutdown events). +""" + +import os +import importlib +import pytest +from unittest.mock import patch, AsyncMock, MagicMock + +# --------------------------------------------------------------------------- +# App fixture +# --------------------------------------------------------------------------- + +TEST_API_KEY = "test-key-12345678" + + +def _get_test_client(): + """ + Import and return a TestClient wrapping the FastAPI app. + + We set API_KEY before importing so the verify_api_key dependency uses a + known value. The lifespan is NOT executed when TestClient is instantiated + without entering a context manager — FastAPI docs confirm that startup + events only fire inside `with TestClient(app)`. + + We patch asyncio.wait_for used inside the lifespan to be a no-op so that + even if TestClient *does* try to run startup it does not block indefinitely. + """ + # Patch the blocking verify_cli call and cleanup-task start that happen + # during lifespan, to prevent hangs in any test runner that triggers it. + with patch.dict(os.environ, {"API_KEY": TEST_API_KEY}, clear=False): + # Delay import until env is set so auth module picks up API_KEY + import src.main as main_module + + importlib.reload(src.main) + + from starlette.testclient import TestClient + + # TestClient without context manager does NOT run lifespan events. + client = TestClient(main_module.app, raise_server_exceptions=False) + return client, main_module.app + + +# --------------------------------------------------------------------------- +# FR-7.1 — Debug endpoint authentication +# --------------------------------------------------------------------------- + + +class TestDebugEndpointRequiresAuth: + """ + FR-7.1: POST /v1/debug/request must be protected by the verify_api_key + dependency. + + Current state: the endpoint has no Depends(security) / verify_api_key + call, so unauthenticated requests return 200. Tests MUST FAIL until + the auth guard is added. + """ + + def test_debug_endpoint_without_auth_header_returns_401_or_403(self): + """ + POST /v1/debug/request with no Authorization header must return + 401 (Unauthorized) or 403 (Forbidden) when API_KEY is configured. + + CURRENT BEHAVIOUR: returns 200 — test will FAIL (RED phase). + """ + with patch.dict(os.environ, {"API_KEY": TEST_API_KEY}, clear=False): + import src.auth + + importlib.reload(src.auth) + import src.main + + importlib.reload(src.main) + + from starlette.testclient import TestClient + + client = TestClient(src.main.app, raise_server_exceptions=False) + + response = client.post( + "/v1/debug/request", + json={ + "model": "claude-3-sonnet-20240229", + "messages": [{"role": "user", "content": "hello"}], + }, + ) + + assert response.status_code in (401, 403), ( + f"Expected 401 or 403 for unauthenticated debug request, " + f"got {response.status_code}. " + "The /v1/debug/request endpoint currently has no auth guard — " + "FR-7.1 fix is required." + ) + + def test_debug_endpoint_with_valid_auth_returns_200(self): + """ + POST /v1/debug/request with a valid Bearer token must return 200. + + This test verifies the endpoint still works once the auth guard is in + place with correct credentials. Rate limiting is disabled so that the + in-process limiter counter from the preceding test does not cause a + spurious 429. + + CURRENT BEHAVIOUR: also returns 200 (no auth check), so this test + passes now — but it is paired with the above to document the full + contract. It is kept here so that it continues to pass after the + fix is applied, acting as a regression guard. + + NOTE: Because this test currently passes, the RED-phase failure is + driven entirely by the previous test. Both are included as a pair + to fully specify the authenticated-access contract. + """ + env = dict(os.environ) + env["API_KEY"] = TEST_API_KEY + env["RATE_LIMIT_ENABLED"] = "false" # isolate from rate-limiter state + with patch.dict(os.environ, env, clear=True): + import src.auth + import src.rate_limiter + import src.main + + importlib.reload(src.auth) + importlib.reload(src.rate_limiter) + importlib.reload(src.main) + + from starlette.testclient import TestClient + + client = TestClient(src.main.app, raise_server_exceptions=False) + + response = client.post( + "/v1/debug/request", + json={ + "model": "claude-3-sonnet-20240229", + "messages": [{"role": "user", "content": "hello"}], + }, + headers={"Authorization": f"Bearer {TEST_API_KEY}"}, + ) + + assert ( + response.status_code == 200 + ), f"Expected 200 for authenticated debug request, got {response.status_code}." + + +# --------------------------------------------------------------------------- +# FR-7.2 — Auth status response must be stripped +# --------------------------------------------------------------------------- + + +class TestAuthStatusResponseStripped: + """ + FR-7.2: GET /v1/auth/status must return ONLY {"authenticated": true/false}. + + Current state: the endpoint returns a verbose object containing + claude_code_auth (with method, status, environment_variables) and + server_info (with api_key_required, api_key_source, version). + All three tests below MUST FAIL until the response body is stripped. + """ + + def _get_auth_status_response(self): + """Helper: call GET /v1/auth/status and return the response.""" + # No API_KEY set so the endpoint is publicly reachable (matches + # current behaviour and keeps the test independent of auth state) + env = {k: v for k, v in os.environ.items() if k != "API_KEY"} + with patch.dict(os.environ, env, clear=True): + import src.auth + + importlib.reload(src.auth) + import src.main + + importlib.reload(src.main) + + from starlette.testclient import TestClient + + client = TestClient(src.main.app, raise_server_exceptions=False) + return client.get("/v1/auth/status") + + def test_auth_status_response_contains_only_authenticated_key(self): + """ + Response body must contain ONLY the key "authenticated". + + CURRENT BEHAVIOUR: body also contains "claude_code_auth" and + "server_info" — test will FAIL (RED phase). + """ + response = self._get_auth_status_response() + assert response.status_code == 200 + + body = response.json() + assert set(body.keys()) == {"authenticated"}, ( + f"Response body keys must be exactly {{'authenticated'}}, " + f"got {set(body.keys())}. " + "FR-7.2 requires stripping all auth method details from this endpoint." + ) + + def test_auth_status_response_does_not_contain_claude_code_auth(self): + """ + Response body must NOT contain the 'claude_code_auth' key which + reveals the authentication method and configuration details. + + CURRENT BEHAVIOUR: 'claude_code_auth' is present — test will FAIL. + """ + response = self._get_auth_status_response() + assert response.status_code == 200 + + body = response.json() + assert "claude_code_auth" not in body, ( + "'claude_code_auth' key must not appear in /v1/auth/status response. " + "It reveals the auth strategy (method name, env vars, config). " + "FR-7.2 requires removing it." + ) + + def test_auth_status_response_does_not_contain_server_info(self): + """ + Response body must NOT contain the 'server_info' key which reveals + whether an API key is required and how it is sourced. + + CURRENT BEHAVIOUR: 'server_info' is present — test will FAIL. + """ + response = self._get_auth_status_response() + assert response.status_code == 200 + + body = response.json() + assert "server_info" not in body, ( + "'server_info' key must not appear in /v1/auth/status response. " + "It reveals api_key_required, api_key_source, and version — " + "reconnaissance information per KD-8. FR-7.2 requires removing it." + ) + + def test_auth_status_authenticated_value_is_bool(self): + """ + The 'authenticated' value must be a boolean (true or false). + + This test documents the shape of the stripped response so that the + implementer knows exactly what the body should look like after the fix. + + CURRENT BEHAVIOUR: 'authenticated' key does not exist — test will FAIL. + """ + response = self._get_auth_status_response() + assert response.status_code == 200 + + body = response.json() + assert ( + "authenticated" in body + ), "Response must contain 'authenticated' key after FR-7.2 fix." + assert isinstance( + body["authenticated"], bool + ), f"'authenticated' value must be a bool, got {type(body['authenticated'])}." + + +# --------------------------------------------------------------------------- +# FR-3.2 — Rate limit headers on previously unprotected endpoints +# --------------------------------------------------------------------------- + + +class TestUnprotectedEndpointsHaveRateLimitHeaders: + """ + FR-3.2: GET / and POST /v1/compatibility must return rate-limit response + headers (e.g. X-RateLimit-Limit) indicating slowapi is active. + + Current state: neither endpoint has a @rate_limit_endpoint decorator, + so no rate-limit headers are present. Tests MUST FAIL until the + decorator is added (architecture sections 9.1, 9.2). + + Implementation note: slowapi injects headers such as + X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset + only on endpoints that carry the @limiter.limit(...) decoration. + Endpoints without the decorator return no such headers regardless of + whether the global limiter is active. + """ + + RATE_LIMIT_HEADER_PREFIXES = ( + "X-RateLimit-Limit", + "X-Ratelimit-Limit", + "RateLimit-Limit", + ) + + def _has_any_rate_limit_header(self, headers: dict) -> bool: + """Return True if any rate-limit indicator header is present.""" + headers_lower = {k.lower(): v for k, v in headers.items()} + for prefix in self.RATE_LIMIT_HEADER_PREFIXES: + if prefix.lower() in headers_lower: + return True + # Also accept the generic Retry-After that slowapi sets when limited + # (not applicable here, but belt-and-suspenders) + return False + + def _make_client(self): + """Return a TestClient with rate limiting enabled.""" + env = {k: v for k, v in os.environ.items() if k != "API_KEY"} + env["RATE_LIMIT_ENABLED"] = "true" + + with patch.dict(os.environ, env, clear=True): + import src.rate_limiter + import src.main + + importlib.reload(src.rate_limiter) + importlib.reload(src.main) + + from starlette.testclient import TestClient + + return TestClient(src.main.app, raise_server_exceptions=False) + + def test_root_endpoint_has_rate_limit_header(self): + """ + GET / must include at least one rate-limit header (e.g. X-RateLimit-Limit) + when the global limiter is active. + + CURRENT BEHAVIOUR: no rate-limit headers are returned because the root + endpoint has no @rate_limit_endpoint decorator — test will FAIL (RED phase). + + Architecture reference: Section 9.1 — root() must receive + `request: Request` parameter and @rate_limit_endpoint("general") decorator. + """ + client = self._make_client() + response = client.get("/") + + assert response.status_code == 200, f"GET / returned {response.status_code}, expected 200." + assert self._has_any_rate_limit_header(dict(response.headers)), ( + f"GET / response has no rate-limit headers. " + f"Headers present: {list(response.headers.keys())}. " + "FR-3.2 requires @rate_limit_endpoint('general') on the root endpoint." + ) + + def test_compatibility_endpoint_has_rate_limit_header(self): + """ + POST /v1/compatibility must include at least one rate-limit header when + the global limiter is active. + + CURRENT BEHAVIOUR: no rate-limit headers are returned because + check_compatibility() has no @rate_limit_endpoint decorator and no + `request: Request` parameter — test will FAIL (RED phase). + + Architecture reference: Section 9.2 — check_compatibility() must receive + `request: Request` parameter and @rate_limit_endpoint("general") decorator. + """ + client = self._make_client() + response = client.post( + "/v1/compatibility", + json={ + "model": "claude-3-sonnet-20240229", + "messages": [{"role": "user", "content": "hello"}], + }, + ) + + assert ( + response.status_code == 200 + ), f"POST /v1/compatibility returned {response.status_code}, expected 200." + assert self._has_any_rate_limit_header(dict(response.headers)), ( + f"POST /v1/compatibility response has no rate-limit headers. " + f"Headers present: {list(response.headers.keys())}. " + "FR-3.2 requires @rate_limit_endpoint('general') on this endpoint." + ) + + +# --------------------------------------------------------------------------- +# Module reload teardown +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def reset_main_module(): + """Reload src.main after each test to avoid cross-test module pollution.""" + yield + import src.main + import src.auth + + importlib.reload(src.auth) + importlib.reload(src.main) diff --git a/tests/test_model_warning_unit.py b/tests/test_model_warning_unit.py new file mode 100644 index 0000000..230643b --- /dev/null +++ b/tests/test_model_warning_unit.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +""" +Unit and integration tests for the model warning header (FR-9.1). + +Tests that: +- ParameterValidator.is_model_recognized() correctly identifies known vs unknown models +- The /v1/chat/completions endpoint adds X-Claude-Model-Warning: unrecognized + when the requested model is not in the known Claude model list +- The header is NOT added when the model is recognized + +These tests are in RED phase. The integration tests for the warning header will +FAIL against current code because main.py does not yet set the header. +""" + +import os +import json +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from src.parameter_validator import ParameterValidator + + +# --------------------------------------------------------------------------- +# Unit tests — ParameterValidator.is_model_recognized() +# These verify the helper logic that the endpoint should use. +# --------------------------------------------------------------------------- + + +class TestIsModelRecognized: + """Test ParameterValidator.is_model_recognized() returns correct bool.""" + + def test_is_model_recognized_known_sonnet_46_returns_true(self): + """claude-sonnet-4-6 is in SUPPORTED_MODELS and must return True.""" + assert ParameterValidator.is_model_recognized("claude-sonnet-4-6") is True + + def test_is_model_recognized_known_opus_46_returns_true(self): + """claude-opus-4-6 is in SUPPORTED_MODELS and must return True.""" + assert ParameterValidator.is_model_recognized("claude-opus-4-6") is True + + def test_is_model_recognized_known_sonnet_45_dated_returns_true(self): + """claude-sonnet-4-5-20250929 is in SUPPORTED_MODELS and must return True.""" + assert ParameterValidator.is_model_recognized("claude-sonnet-4-5-20250929") is True + + def test_is_model_recognized_openai_model_returns_false(self): + """gpt-4-turbo is not a Claude model and must return False.""" + assert ParameterValidator.is_model_recognized("gpt-4-turbo") is False + + def test_is_model_recognized_arbitrary_unknown_model_returns_false(self): + """A made-up model name is not in SUPPORTED_MODELS and must return False.""" + assert ParameterValidator.is_model_recognized("nonexistent-model-xyz") is False + + def test_is_model_recognized_empty_string_returns_false(self): + """Empty string is not in SUPPORTED_MODELS and must return False.""" + assert ParameterValidator.is_model_recognized("") is False + + def test_is_model_recognized_all_claude_models_return_true(self): + """Every model in SUPPORTED_MODELS must be recognized.""" + for model in ParameterValidator.SUPPORTED_MODELS: + assert ( + ParameterValidator.is_model_recognized(model) is True + ), f"Expected {model!r} to be recognized but it was not" + + +# --------------------------------------------------------------------------- +# Integration tests — /v1/chat/completions endpoint header behavior +# +# These tests use FastAPI's TestClient (httpx-based) and mock out the Claude +# Agent SDK so no real API calls are made. +# +# RED: The tests that check for X-Claude-Model-Warning will FAIL until +# main.py is updated to set the header for unrecognized models. +# --------------------------------------------------------------------------- + + +def _make_async_generator(chunks): + """Helper: create an async generator that yields the given chunks.""" + + async def _gen(*args, **kwargs): + for chunk in chunks: + yield chunk + + return _gen + + +def _mock_run_completion_chunks(): + """Return a list of Claude SDK message dicts that represent a valid response.""" + return [ + { + "type": "assistant", + "subtype": "success", + "result": "Hello from mocked Claude", + "total_cost_usd": 0.0, + "duration_ms": 100, + "num_turns": 1, + "session_id": "mock-session-id", + } + ] + + +@pytest.fixture(scope="module") +def test_client(): + """ + Create a TestClient for the FastAPI app with all external calls mocked. + + Patches applied at module import time: + - ClaudeCodeCLI.__init__ — prevents real subprocess/auth setup + - validate_claude_code_auth — returns (True, {}) so auth passes + - claude_cli.run_completion — returns a mocked async generator + - session_manager.start_cleanup_task — prevents background task noise + """ + with ( + patch("src.claude_cli.ClaudeCodeCLI.__init__", return_value=None), + patch("src.auth.validate_claude_code_auth", return_value=(True, {"method": "mock"})), + ): + # Import app AFTER patching ClaudeCodeCLI so the module-level + # `claude_cli = ClaudeCodeCLI(...)` call in main.py succeeds. + from src.main import app + + # Patch the module-level claude_cli instance used by the endpoint. + mock_cli = MagicMock() + mock_cli.run_completion = _make_async_generator(_mock_run_completion_chunks()) + mock_cli.parse_claude_message = MagicMock(return_value="Hello from mocked Claude") + mock_cli.extract_metadata = MagicMock(return_value={}) + mock_cli.estimate_token_usage = MagicMock( + return_value={"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10} + ) + + import src.main as main_module + import src.session_manager as sm_module + + original_cli = main_module.claude_cli + original_start = sm_module.session_manager.start_cleanup_task + + main_module.claude_cli = mock_cli + sm_module.session_manager.start_cleanup_task = MagicMock() + + from fastapi.testclient import TestClient + + with TestClient(app, raise_server_exceptions=True) as client: + yield client + + # Restore originals (best-effort; module is cached anyway) + main_module.claude_cli = original_cli + sm_module.session_manager.start_cleanup_task = original_start + + +def _chat_request_body(model: str) -> dict: + """Return a minimal /v1/chat/completions request body for the given model.""" + return { + "model": model, + "messages": [{"role": "user", "content": "Hello"}], + "stream": False, + } + + +def _auth_headers(api_key: str) -> dict: + return {"Authorization": f"Bearer {api_key}"} + + +class TestModelWarningHeaderEndpoint: + """ + Integration tests for the X-Claude-Model-Warning response header. + + FR-9.1: When a chat completion request uses an unrecognized model the + response MUST include the header X-Claude-Model-Warning: unrecognized. + When a known model is used the header MUST NOT be present. + """ + + def test_unknown_model_response_has_warning_header(self, test_client): + """ + RED: Request with an unrecognized model returns X-Claude-Model-Warning: unrecognized. + + This test FAILS against current code because main.py does not set the header. + It will pass once main.py is updated (GREEN phase). + """ + api_key = "test-key-for-warning-header" + + with ( + patch.dict(os.environ, {"API_KEY": api_key}), + patch("src.main.validate_claude_code_auth", return_value=(True, {"method": "mock"})), + patch("src.main.claude_cli") as mock_cli_patch, + ): + mock_cli_patch.run_completion = _make_async_generator(_mock_run_completion_chunks()) + mock_cli_patch.parse_claude_message = MagicMock(return_value="Hello from mocked Claude") + + response = test_client.post( + "/v1/chat/completions", + json=_chat_request_body("gpt-4-turbo"), + headers=_auth_headers(api_key), + ) + + assert response.status_code == 200 + assert ( + "x-claude-model-warning" in response.headers + ), "Expected X-Claude-Model-Warning header in response for unrecognized model 'gpt-4-turbo'" + assert ( + response.headers["x-claude-model-warning"] == "unrecognized" + ), "Expected X-Claude-Model-Warning header value to be 'unrecognized'" + + def test_known_model_response_has_no_warning_header(self, test_client): + """ + When a known Claude model is used, no X-Claude-Model-Warning header is present. + + This test documents the expected ABSENCE of the header for recognized models. + It may pass or fail depending on implementation details; we include it to + ensure the implementation does not spam the header on every response. + """ + api_key = "test-key-for-warning-header" + + with ( + patch.dict(os.environ, {"API_KEY": api_key}), + patch("src.main.validate_claude_code_auth", return_value=(True, {"method": "mock"})), + patch("src.main.claude_cli") as mock_cli_patch, + ): + mock_cli_patch.run_completion = _make_async_generator(_mock_run_completion_chunks()) + mock_cli_patch.parse_claude_message = MagicMock(return_value="Hello from mocked Claude") + + response = test_client.post( + "/v1/chat/completions", + json=_chat_request_body("claude-sonnet-4-6"), + headers=_auth_headers(api_key), + ) + + assert response.status_code == 200 + assert ( + "x-claude-model-warning" not in response.headers + ), "Expected no X-Claude-Model-Warning header for recognized model 'claude-sonnet-4-6'" + + def test_nonexistent_model_string_triggers_warning_header(self, test_client): + """ + RED: A completely made-up model name also triggers the warning header. + + This test FAILS against current code. + """ + api_key = "test-key-for-warning-header" + + with ( + patch.dict(os.environ, {"API_KEY": api_key}), + patch("src.main.validate_claude_code_auth", return_value=(True, {"method": "mock"})), + patch("src.main.claude_cli") as mock_cli_patch, + ): + mock_cli_patch.run_completion = _make_async_generator(_mock_run_completion_chunks()) + mock_cli_patch.parse_claude_message = MagicMock(return_value="Hello from mocked Claude") + + response = test_client.post( + "/v1/chat/completions", + json=_chat_request_body("nonexistent-model-99999"), + headers=_auth_headers(api_key), + ) + + assert response.status_code == 200 + assert ( + "x-claude-model-warning" in response.headers + ), "Expected X-Claude-Model-Warning header for completely unknown model" + assert response.headers["x-claude-model-warning"] == "unrecognized" diff --git a/tests/test_models_unit.py b/tests/test_models_unit.py index 5e6387d..5959fbe 100644 --- a/tests/test_models_unit.py +++ b/tests/test_models_unit.py @@ -10,6 +10,7 @@ from datetime import datetime from unittest.mock import patch +from src.constants import CLAUDE_MODELS, DEFAULT_MODEL, FAST_MODEL from src.models import ( ContentPart, Message, @@ -570,3 +571,106 @@ def test_anthropic_messages_response(self): assert response.role == "assistant" assert response.stop_reason == "end_turn" assert response.id.startswith("msg_") + + +class TestClaude46ModelSupport: + """Tests for FR-11: Claude 4.6 model support in constants.py. + + These tests verify that Claude 4.6 model identifiers are present in + CLAUDE_MODELS, that DEFAULT_MODEL has been updated to a 4.6 model, + and that FAST_MODEL remains unchanged. + + Requirement: FR-11.1, FR-11.2, FR-11.3 + """ + + def test_claude_models_contains_opus_4_6(self): + """FR-11.1: CLAUDE_MODELS includes a claude-opus-4-6 variant. + + Accepts both the alias form 'claude-opus-4-6' and a dated variant + such as 'claude-opus-4-6-20260301'. + """ + has_opus_46 = any("opus-4-6" in model for model in CLAUDE_MODELS) + assert has_opus_46, ( + "CLAUDE_MODELS must contain a claude-opus-4-6 entry " + "(exact alias or dated variant, e.g. 'claude-opus-4-6-20260301'). " + f"Current models: {CLAUDE_MODELS}" + ) + + def test_claude_models_contains_sonnet_4_6(self): + """FR-11.1: CLAUDE_MODELS includes a claude-sonnet-4-6 variant. + + Accepts both the alias form 'claude-sonnet-4-6' and a dated variant + such as 'claude-sonnet-4-6-20260301'. + """ + has_sonnet_46 = any("sonnet-4-6" in model for model in CLAUDE_MODELS) + assert has_sonnet_46, ( + "CLAUDE_MODELS must contain a claude-sonnet-4-6 entry " + "(exact alias or dated variant, e.g. 'claude-sonnet-4-6-20260301'). " + f"Current models: {CLAUDE_MODELS}" + ) + + def test_default_model_is_4_6_family(self): + """FR-11.2: DEFAULT_MODEL resolves to a 4.6 model when env var is not set. + + The default (env-unset) value must contain '4-6' so that clients + automatically use the new model generation without additional config. + """ + import os + + # Only meaningful when DEFAULT_MODEL env var is not overridden by caller. + # We test the hardcoded fallback, not the os.getenv result, by importing + # the raw default from the module source via the already-imported constant. + # If the operator has set DEFAULT_MODEL in their environment this test + # is skipped so it doesn't false-positive against a deliberate override. + env_override = os.environ.get("DEFAULT_MODEL") + if env_override is not None: + pytest.skip("DEFAULT_MODEL env var is set; skipping hardcoded-default check") + + assert "4-6" in DEFAULT_MODEL, ( + f"DEFAULT_MODEL should contain '4-6' when no env var is set. " f"Got: '{DEFAULT_MODEL}'" + ) + + def test_fast_model_is_haiku_4_5(self): + """FR-11.2 (stability): FAST_MODEL remains claude-haiku-4-5 after the upgrade. + + The fast/cheap model alias must not silently change — consumers who + opt into FAST_MODEL for speed/cost reasons depend on it staying haiku-4-5. + """ + assert ( + "haiku-4-5" in FAST_MODEL + ), f"FAST_MODEL must still be the haiku-4-5 variant. Got: '{FAST_MODEL}'" + + def test_constants_module_docstring_references_4_6_family(self): + """FR-11.3: constants.py module-level comments/docstring reference 4.6 models. + + Stale 'latest' comments pointing only at 4.5 must be updated so that + developers reading the source are not misled about the current model family. + """ + import inspect + import src.constants as constants_module + + source = inspect.getsource(constants_module) + assert "4.6" in source or "4-6" in source, ( + "constants.py source must reference the 4.6 model family in comments " + "or docstrings (e.g. '4.6' or '4-6'). Update the CLAUDE_MODELS block comment." + ) + + def test_claude_models_comment_does_not_label_4_5_as_latest(self): + """FR-11.3: Source comments must not label 4.5 as 'Latest' after 4.6 is added. + + This guards against stale 'Latest' markers in the CLAUDE_MODELS block + that would mislead developers into thinking 4.5 is still the newest family. + """ + import inspect + import src.constants as constants_module + + source = inspect.getsource(constants_module) + # A comment like "4.5 Family (Latest" is only acceptable if 4.6 is NOT present. + # Once 4.6 is added the 4.5 block must no longer be labelled Latest. + has_46 = any("4-6" in m for m in CLAUDE_MODELS) + if has_46: + # 4.6 is present — 4.5 must not be called "Latest" + assert "4.5 Family (Latest" not in source and "4-5 Family (Latest" not in source, ( + "constants.py still labels the 4.5 family as 'Latest' even though 4.6 " + "models are now present. Update the comment to reflect 4.6 as the current latest." + ) diff --git a/tests/test_parameter_validator_unit.py b/tests/test_parameter_validator_unit.py index 3c31945..9082eac 100644 --- a/tests/test_parameter_validator_unit.py +++ b/tests/test_parameter_validator_unit.py @@ -362,3 +362,37 @@ def test_minimal_request_has_no_unsupported(self, minimal_request): """Minimal request with defaults has no unsupported parameters.""" report = CompatibilityReporter.generate_compatibility_report(minimal_request) assert len(report["unsupported_parameters"]) == 0 + + +class TestParameterValidatorIsModelRecognized: + """Test ParameterValidator.is_model_recognized() + + This method returns True only when the model string is present in + SUPPORTED_MODELS (the recognized Claude model list). It is used by the + chat endpoint to decide whether to attach an X-Claude-Model-Warning header. + """ + + def test_is_model_recognized_known_sonnet_46_returns_true(self): + """claude-sonnet-4-6 is a known supported model and must return True.""" + result = ParameterValidator.is_model_recognized("claude-sonnet-4-6") + assert result is True + + def test_is_model_recognized_known_sonnet_45_dated_returns_true(self): + """claude-sonnet-4-5-20250929 is a known supported model and must return True.""" + result = ParameterValidator.is_model_recognized("claude-sonnet-4-5-20250929") + assert result is True + + def test_is_model_recognized_unknown_openai_model_returns_false(self): + """gpt-4-turbo is not a Claude model and must return False.""" + result = ParameterValidator.is_model_recognized("gpt-4-turbo") + assert result is False + + def test_is_model_recognized_empty_string_returns_false(self): + """Empty string is not in SUPPORTED_MODELS and must return False.""" + result = ParameterValidator.is_model_recognized("") + assert result is False + + def test_is_model_recognized_typo_model_returns_false(self): + """Model string with a suffix typo is not in SUPPORTED_MODELS and must return False.""" + result = ParameterValidator.is_model_recognized("claude-sonnet-4-6-WRONG") + assert result is False diff --git a/tests/test_rate_limiter_unit.py b/tests/test_rate_limiter_unit.py index 2b14157..064fe9e 100644 --- a/tests/test_rate_limiter_unit.py +++ b/tests/test_rate_limiter_unit.py @@ -19,16 +19,22 @@ class TestGetRateLimitKey: """Test get_rate_limit_key()""" def test_returns_remote_address(self): - """Should return the remote address from the request.""" - with patch("src.rate_limiter.get_remote_address") as mock_get_addr: - mock_get_addr.return_value = "192.168.1.100" - mock_request = MagicMock(spec=Request) + """Should return the direct peer IP from the request when no trusted proxies are set.""" + import importlib + import src.constants + import src.rate_limiter - from src.rate_limiter import get_rate_limit_key + mock_request = MagicMock(spec=Request) + mock_request.client = MagicMock() + mock_request.client.host = "192.168.1.100" + mock_request.headers = {} - result = get_rate_limit_key(mock_request) - assert result == "192.168.1.100" - mock_get_addr.assert_called_once_with(mock_request) + with patch.dict(os.environ, {"TRUSTED_PROXIES": ""}): + importlib.reload(src.constants) + importlib.reload(src.rate_limiter) + result = src.rate_limiter.get_rate_limit_key(mock_request) + + assert result == "192.168.1.100" class TestCreateRateLimiter: @@ -269,6 +275,213 @@ def my_endpoint(): assert my_endpoint() == "hello" +class TestGetRateLimitKeyTrustedProxy: + """Test get_rate_limit_key() with trusted proxy support (FR-3.1). + + The function must only trust X-Forwarded-For when the direct peer IP is + in TRUSTED_PROXIES. When TRUSTED_PROXIES is empty or the peer is not + trusted, the function must ignore X-Forwarded-For entirely to prevent + IP spoofing attacks that would allow bypassing rate limits. + + Architecture reference: Section 5.3 and 7.2. + + Patching strategy: the new implementation will read TRUSTED_PROXIES from + src.constants at module import time (from src.constants import + TRUSTED_PROXIES). We therefore reload src.rate_limiter after patching the + TRUSTED_PROXIES environment variable so that each test gets a fresh module + with the desired proxy list. Using importlib.reload mirrors the pattern + already established in this file for RATE_LIMIT_ENABLED. + + All five tests MUST FAIL against the current implementation because + get_rate_limit_key() does not yet consult TRUSTED_PROXIES at all: it + delegates unconditionally to get_remote_address() which ignores XFF and + always returns request.client.host. The failures prove the feature is + missing and define the contract the implementer must satisfy. + """ + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _make_request(client_ip: str, x_forwarded_for: str = None): + """Return a mock Request with the given peer IP and optional XFF header. + + Headers are stored as a plain dict. The new implementation must call + request.headers.get("x-forwarded-for") which MagicMock satisfies when + headers is a dict-like object (MagicMock.__getitem__ is available, but + .get() on a plain dict works too). + """ + mock_request = MagicMock(spec=Request) + mock_request.client = MagicMock() + mock_request.client.host = client_ip + + # Use a real dict so .get() behaves correctly + headers = {} + if x_forwarded_for is not None: + headers["x-forwarded-for"] = x_forwarded_for + mock_request.headers = headers + + return mock_request + + @staticmethod + def _load_get_rate_limit_key(trusted_proxies_value: str): + """Reload src.rate_limiter with TRUSTED_PROXIES set to the given + comma-separated string and return the freshly bound get_rate_limit_key. + + This forces the module to re-evaluate TRUSTED_PROXIES from constants + (which reads os.environ) so each test exercises an isolated state. + """ + import importlib + import src.constants + import src.rate_limiter + + with patch.dict(os.environ, {"TRUSTED_PROXIES": trusted_proxies_value}): + importlib.reload(src.constants) + importlib.reload(src.rate_limiter) + # Return a reference captured while the patches are still active + return src.rate_limiter.get_rate_limit_key + + # ------------------------------------------------------------------ + # TC-1: No TRUSTED_PROXIES configured, no X-Forwarded-For header + # → must return the direct peer IP + # ------------------------------------------------------------------ + + def test_get_rate_limit_key_no_trusted_proxies_no_xff_returns_client_ip(self): + """When TRUSTED_PROXIES is empty and no XFF header is present, + get_rate_limit_key must return the direct peer IP.""" + import importlib + import src.constants + import src.rate_limiter + + mock_request = self._make_request("203.0.113.10") + + with patch.dict(os.environ, {"TRUSTED_PROXIES": ""}): + importlib.reload(src.constants) + importlib.reload(src.rate_limiter) + result = src.rate_limiter.get_rate_limit_key(mock_request) + + assert result == "203.0.113.10" + + # ------------------------------------------------------------------ + # TC-2: No TRUSTED_PROXIES configured, X-Forwarded-For is present + # → must return the direct peer IP (header MUST be ignored) + # + # This is the primary security requirement: a client that forges + # X-Forwarded-For must NOT be able to impersonate a different IP. + # ------------------------------------------------------------------ + + def test_get_rate_limit_key_no_trusted_proxies_xff_present_ignores_xff(self): + """When TRUSTED_PROXIES is empty, X-Forwarded-For must be ignored + regardless of its value. Trusting it without proxy validation + lets any client bypass rate limits by setting a forged header.""" + import importlib + import src.constants + import src.rate_limiter + + mock_request = self._make_request( + "203.0.113.10", + x_forwarded_for="1.2.3.4", + ) + + with patch.dict(os.environ, {"TRUSTED_PROXIES": ""}): + importlib.reload(src.constants) + importlib.reload(src.rate_limiter) + result = src.rate_limiter.get_rate_limit_key(mock_request) + + # Must be the real direct-connection peer, NOT the attacker-supplied IP + assert result == "203.0.113.10", ( + "get_rate_limit_key must not trust X-Forwarded-For when " "TRUSTED_PROXIES is empty" + ) + assert result != "1.2.3.4" + + # ------------------------------------------------------------------ + # TC-3: Peer IP is NOT in TRUSTED_PROXIES, XFF is present (spoofed) + # → must return the direct peer IP (header MUST be ignored) + # ------------------------------------------------------------------ + + def test_get_rate_limit_key_untrusted_peer_xff_spoofed_returns_client_ip(self): + """When the peer IP is not in TRUSTED_PROXIES, X-Forwarded-For is + attacker-controlled data and must be ignored entirely.""" + import importlib + import src.constants + import src.rate_limiter + + mock_request = self._make_request( + "198.51.100.99", # not a trusted proxy + x_forwarded_for="1.2.3.4, 10.0.0.1", + ) + + # 10.0.0.1 is a trusted proxy, but 198.51.100.99 (the peer) is not + with patch.dict(os.environ, {"TRUSTED_PROXIES": "10.0.0.1"}): + importlib.reload(src.constants) + importlib.reload(src.rate_limiter) + result = src.rate_limiter.get_rate_limit_key(mock_request) + + # Must return the real untrusted peer, not any value from XFF + assert result == "198.51.100.99" + assert result not in ("1.2.3.4", "10.0.0.1") + + # ------------------------------------------------------------------ + # TC-4: Peer IP IS in TRUSTED_PROXIES, XFF chain contains both + # trusted and non-trusted IPs + # → must return the rightmost non-trusted IP + # + # This is the core "happy path" for reverse-proxy deployments. + # ------------------------------------------------------------------ + + def test_get_rate_limit_key_trusted_peer_valid_xff_returns_rightmost_non_trusted_ip(self): + """When the peer is a trusted proxy and the X-Forwarded-For chain + contains non-trusted IPs, the rightmost non-trusted IP is the actual + client and must be used for rate limiting. + + XFF chain: "1.2.3.4, 10.0.0.2" + Direct peer: 10.0.0.1 (trusted) + Trusted proxies: 10.0.0.1, 10.0.0.2 + + Walking from right: 10.0.0.2 is trusted (skip); 1.2.3.4 is not trusted + → return 1.2.3.4 + """ + import importlib + import src.constants + import src.rate_limiter + + mock_request = self._make_request( + "10.0.0.1", # trusted proxy (direct peer) + x_forwarded_for="1.2.3.4, 10.0.0.2", + ) + + with patch.dict(os.environ, {"TRUSTED_PROXIES": "10.0.0.1,10.0.0.2"}): + importlib.reload(src.constants) + importlib.reload(src.rate_limiter) + result = src.rate_limiter.get_rate_limit_key(mock_request) + + assert result == "1.2.3.4" + + # ------------------------------------------------------------------ + # TC-5: Peer IP IS in TRUSTED_PROXIES but no XFF header is present + # → must fall back to the peer IP (no error, no None) + # ------------------------------------------------------------------ + + def test_get_rate_limit_key_trusted_peer_missing_xff_falls_back_to_peer_ip(self): + """When the peer is a trusted proxy but no X-Forwarded-For header + exists, there is no upstream IP to extract. The function must fall + back to the peer IP rather than raising an exception or returning + None.""" + import importlib + import src.constants + import src.rate_limiter + + mock_request = self._make_request("10.0.0.1") # no XFF header + + with patch.dict(os.environ, {"TRUSTED_PROXIES": "10.0.0.1"}): + importlib.reload(src.constants) + importlib.reload(src.rate_limiter) + result = src.rate_limiter.get_rate_limit_key(mock_request) + + assert result == "10.0.0.1" + + # Reset module state after tests @pytest.fixture(autouse=True) def reset_rate_limiter_module(): diff --git a/tests/test_sdk_migration.py b/tests/test_sdk_migration.py index 6ad2d95..cec5140 100644 --- a/tests/test_sdk_migration.py +++ b/tests/test_sdk_migration.py @@ -74,7 +74,7 @@ def test_default_model_defined(self): from src.constants import DEFAULT_MODEL, CLAUDE_MODELS assert DEFAULT_MODEL in CLAUDE_MODELS - assert DEFAULT_MODEL == "claude-sonnet-4-5-20250929" + assert DEFAULT_MODEL == "claude-sonnet-4-6" def test_fast_model_defined(self): """Test that FAST_MODEL is set to fastest model.""" diff --git a/tests/test_session_manager_unit.py b/tests/test_session_manager_unit.py index 961a385..9520530 100644 --- a/tests/test_session_manager_unit.py +++ b/tests/test_session_manager_unit.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch import asyncio -from src.session_manager import Session, SessionManager +from src.session_manager import Session, SessionManager, SessionLimitExceeded from src.models import Message @@ -373,3 +373,220 @@ def create_session(session_id): assert len(errors) == 0 assert len(results) == 10 assert len(manager.sessions) == 10 + + +class TestSessionManagerSessionLimit: + """Tests for FR-6.1 — configurable max session count enforcement.""" + + @pytest.fixture + def manager_limit_3(self): + """SessionManager with max_sessions=3 for easy boundary testing.""" + return SessionManager(default_ttl_hours=1, cleanup_interval_minutes=5, max_sessions=3) + + def test_manager_accepts_max_sessions_parameter(self): + """SessionManager can be constructed with a custom max_sessions value.""" + manager = SessionManager(max_sessions=500) + assert manager.max_sessions == 500 + + def test_manager_max_sessions_defaults_to_constant(self): + """SessionManager uses MAX_SESSIONS constant as default for max_sessions.""" + from src.constants import MAX_SESSIONS + + manager = SessionManager() + assert manager.max_sessions == MAX_SESSIONS + + def test_creating_sessions_up_to_limit_succeeds(self, manager_limit_3): + """Creating exactly max_sessions sessions does not raise an exception.""" + manager_limit_3.get_or_create_session("session-1") + manager_limit_3.get_or_create_session("session-2") + manager_limit_3.get_or_create_session("session-3") + # All three created without exception; dict has exactly 3 entries + assert len(manager_limit_3.sessions) == 3 + + def test_creating_session_at_limit_raises_session_limit_exceeded(self, manager_limit_3): + """Creating a session when the limit is already reached raises SessionLimitExceeded.""" + manager_limit_3.get_or_create_session("session-1") + manager_limit_3.get_or_create_session("session-2") + manager_limit_3.get_or_create_session("session-3") + + with pytest.raises(SessionLimitExceeded): + manager_limit_3.get_or_create_session("session-4") + + def test_session_limit_exceeded_is_raised_on_new_id_not_existing(self, manager_limit_3): + """SessionLimitExceeded is raised for a brand-new session ID, not when re-accessing existing.""" + manager_limit_3.get_or_create_session("session-1") + manager_limit_3.get_or_create_session("session-2") + manager_limit_3.get_or_create_session("session-3") + + # Re-accessing an existing session must NOT raise, even when at limit + existing = manager_limit_3.get_or_create_session("session-1") + assert existing.session_id == "session-1" + + def test_after_expired_session_cleaned_up_new_session_can_be_created(self, manager_limit_3): + """Once an expired session is cleaned up, the freed slot allows a new session.""" + manager_limit_3.get_or_create_session("session-1") + session2 = manager_limit_3.get_or_create_session("session-2") + manager_limit_3.get_or_create_session("session-3") + + # Expire session-2 so that get_or_create_session will replace it + session2.expires_at = datetime.utcnow() - timedelta(hours=1) + + # Trigger cleanup so the slot is released + manager_limit_3._cleanup_expired_sessions() + + # Now there are only 2 active sessions; creating a new one must succeed + new_session = manager_limit_3.get_or_create_session("session-new") + assert new_session.session_id == "session-new" + + def test_session_limit_exceeded_exception_is_value_error_subclass(self): + """SessionLimitExceeded is a subclass of ValueError per architecture spec (section 5.4).""" + assert issubclass(SessionLimitExceeded, ValueError) + + def test_session_count_does_not_increase_past_limit(self, manager_limit_3): + """Session count stays at max_sessions after a failed creation attempt.""" + manager_limit_3.get_or_create_session("session-1") + manager_limit_3.get_or_create_session("session-2") + manager_limit_3.get_or_create_session("session-3") + + try: + manager_limit_3.get_or_create_session("session-4") + except SessionLimitExceeded: + pass + + assert len(manager_limit_3.sessions) == 3 + + +class TestCheckSessionLimit: + """Tests for check_session_limit() — read-only pre-check for streaming path.""" + + @pytest.fixture + def manager_limit_3(self): + return SessionManager( + default_ttl_hours=1, + cleanup_interval_minutes=5, + max_sessions=3, + ) + + def test_existing_active_session_does_not_raise(self, manager_limit_3): + """check_session_limit passes for an existing active session even at limit.""" + manager_limit_3.get_or_create_session("s1") + manager_limit_3.get_or_create_session("s2") + manager_limit_3.get_or_create_session("s3") + # At limit, but s1 exists and is active — should not raise + manager_limit_3.check_session_limit("s1") + + def test_new_session_at_limit_raises(self, manager_limit_3): + """check_session_limit raises SessionLimitExceeded for a new session at limit.""" + manager_limit_3.get_or_create_session("s1") + manager_limit_3.get_or_create_session("s2") + manager_limit_3.get_or_create_session("s3") + with pytest.raises(SessionLimitExceeded): + manager_limit_3.check_session_limit("s4") + + def test_expired_session_at_limit_raises(self, manager_limit_3): + """check_session_limit raises for an expired session when all slots are full.""" + manager_limit_3.get_or_create_session("s1") + manager_limit_3.get_or_create_session("s2") + s3 = manager_limit_3.get_or_create_session("s3") + s3.expires_at = datetime.utcnow() - timedelta(hours=1) + # s3 is expired but still in dict — slot not freed yet + with pytest.raises(SessionLimitExceeded): + manager_limit_3.check_session_limit("s3") + + def test_under_limit_does_not_raise(self, manager_limit_3): + """check_session_limit passes when under the session limit.""" + manager_limit_3.get_or_create_session("s1") + manager_limit_3.check_session_limit("s2") # Should not raise + + def test_does_not_create_session(self, manager_limit_3): + """check_session_limit does not actually create a session.""" + manager_limit_3.check_session_limit("s1") + assert "s1" not in manager_limit_3.sessions + + +class TestSessionManagerMessageLimit: + """Tests for FR-6.2 — configurable max message history per session.""" + + @pytest.fixture + def manager_msg_limit_5(self): + """SessionManager with max_session_messages=5 for easy boundary testing.""" + return SessionManager( + default_ttl_hours=1, + cleanup_interval_minutes=5, + max_session_messages=5, + ) + + def _make_messages(self, count: int, prefix: str = "msg") -> list: + """Helper: build a list of user Messages with predictable content.""" + return [Message(role="user", content=f"{prefix}-{i}") for i in range(count)] + + def test_manager_accepts_max_session_messages_parameter(self): + """SessionManager can be constructed with a custom max_session_messages value.""" + manager = SessionManager(max_session_messages=50) + assert manager.max_session_messages == 50 + + def test_manager_max_session_messages_defaults_to_constant(self): + """SessionManager uses MAX_SESSION_MESSAGES constant as default.""" + from src.constants import MAX_SESSION_MESSAGES + + manager = SessionManager() + assert manager.max_session_messages == MAX_SESSION_MESSAGES + + def test_adding_messages_up_to_limit_keeps_all(self, manager_msg_limit_5): + """Adding exactly max_session_messages messages keeps all of them.""" + session = manager_msg_limit_5.get_or_create_session("test-session") + messages = self._make_messages(5) + session.add_messages(messages) + + assert len(session.messages) == 5 + + def test_adding_messages_beyond_limit_trims_to_limit(self, manager_msg_limit_5): + """Adding more than max_session_messages messages trims the list to the limit.""" + session = manager_msg_limit_5.get_or_create_session("test-session") + messages = self._make_messages(7) + session.add_messages(messages) + + assert len(session.messages) == 5 + + def test_trimming_drops_oldest_messages(self, manager_msg_limit_5): + """When trimming occurs the oldest messages (first added) are removed.""" + session = manager_msg_limit_5.get_or_create_session("test-session") + messages = self._make_messages(7) # msg-0 … msg-6 + session.add_messages(messages) + + # After trimming, only the 5 newest (msg-2 … msg-6) should remain + remaining_contents = [m.content for m in session.messages] + assert "msg-0" not in remaining_contents + assert "msg-1" not in remaining_contents + + def test_trimming_retains_newest_messages(self, manager_msg_limit_5): + """When trimming occurs the newest messages are retained.""" + session = manager_msg_limit_5.get_or_create_session("test-session") + messages = self._make_messages(7) # msg-0 … msg-6 + session.add_messages(messages) + + remaining_contents = [m.content for m in session.messages] + for i in range(2, 7): # msg-2 through msg-6 must be present + assert f"msg-{i}" in remaining_contents + + def test_trimming_across_multiple_add_calls(self, manager_msg_limit_5): + """Message limit is enforced across multiple separate add_messages calls.""" + session = manager_msg_limit_5.get_or_create_session("test-session") + + # Add 3 messages in first call, then 4 more in second call (total 7 > limit of 5) + first_batch = self._make_messages(3, prefix="first") + second_batch = self._make_messages(4, prefix="second") + + session.add_messages(first_batch) + session.add_messages(second_batch) + + assert len(session.messages) == 5 + + def test_message_count_after_trimming_equals_limit(self, manager_msg_limit_5): + """After any trim, get_all_messages returns exactly max_session_messages messages.""" + session = manager_msg_limit_5.get_or_create_session("test-session") + # Add well beyond the limit to exercise the trim path + session.add_messages(self._make_messages(20)) + + all_messages = session.get_all_messages() + assert len(all_messages) == 5 diff --git a/tests/test_tool_execution.py b/tests/test_tool_execution.py index 3c8fe34..9774414 100644 --- a/tests/test_tool_execution.py +++ b/tests/test_tool_execution.py @@ -8,6 +8,9 @@ - DEFAULT_ALLOWED_TOOLS configuration """ +import tempfile +from unittest.mock import patch + import pytest from claude_agent_sdk import ClaudeAgentOptions @@ -67,11 +70,16 @@ def test_default_allowed_tools_excludes_dangerous(self): class TestParseClaudeMessage: """Test parse_claude_message correctly handles multi-turn conversations.""" - def test_result_message_priority(self): - """Test that ResultMessage.result is prioritized over AssistantMessage.""" + def _make_cli(self): + """Create a ClaudeCodeCLI with sandbox check bypassed for /tmp.""" from src.claude_cli import ClaudeCodeCLI - cli = ClaudeCodeCLI(cwd="/tmp") + with patch("src.claude_cli.CLAUDE_CWD_ALLOWED_BASE", "/"): + return ClaudeCodeCLI(cwd="/tmp") + + def test_result_message_priority(self): + """Test that ResultMessage.result is prioritized over AssistantMessage.""" + cli = self._make_cli() # Simulate multi-turn conversation messages messages = [ @@ -97,9 +105,7 @@ def test_result_message_priority(self): def test_fallback_to_last_assistant_message(self): """Test fallback to last AssistantMessage when no ResultMessage.""" - from src.claude_cli import ClaudeCodeCLI - - cli = ClaudeCodeCLI(cwd="/tmp") + cli = self._make_cli() # Simulate messages without ResultMessage messages = [ @@ -118,18 +124,14 @@ def test_fallback_to_last_assistant_message(self): def test_handles_empty_messages(self): """Test handling of empty message list.""" - from src.claude_cli import ClaudeCodeCLI - - cli = ClaudeCodeCLI(cwd="/tmp") + cli = self._make_cli() result = cli.parse_claude_message([]) assert result is None def test_handles_dict_content_blocks(self): """Test handling of dict-based content blocks (old format).""" - from src.claude_cli import ClaudeCodeCLI - - cli = ClaudeCodeCLI(cwd="/tmp") + cli = self._make_cli() messages = [{"content": [{"type": "text", "text": "Hello world"}]}]