diff --git a/src/ad_seller/interfaces/api/main.py b/src/ad_seller/interfaces/api/main.py index 7b58c55..cb1d7b6 100644 --- a/src/ad_seller/interfaces/api/main.py +++ b/src/ad_seller/interfaces/api/main.py @@ -17,7 +17,9 @@ from typing import Any, Optional from fastapi import Depends, FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel +from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware logger = logging.getLogger(__name__) @@ -50,6 +52,7 @@ version="1.0.0", contact={"name": "IAB Tech Lab", "url": "https://iabtechlab.com"}, license_info={"name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0"}, + root_path_in_servers=False, openapi_tags=[ {"name": "Core", "description": "Health check and API root"}, {"name": "Products", "description": "Product catalog browsing"}, @@ -84,6 +87,21 @@ # Lifecycle: start/stop background services # ============================================================================= +# Trust X-Forwarded-Proto / X-Forwarded-For from Cloud Run so that Starlette +# generates https:// redirects instead of http:// ones behind the TLS proxy. +app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") + +# Allow all browser-based clients — buyer UIs, claude.ai, SSP dashboards, etc. +# The MCP Streamable HTTP protocol requires CORS for browser-originated requests. +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["*"], + allow_credentials=False, + expose_headers=["*"], +) + _mcp_server_ref = None diff --git a/src/ad_seller/interfaces/mcp_server.py b/src/ad_seller/interfaces/mcp_server.py index 6395d7f..30daa4e 100644 --- a/src/ad_seller/interfaces/mcp_server.py +++ b/src/ad_seller/interfaces/mcp_server.py @@ -39,6 +39,13 @@ "and interact with buyer agents. On first connection, check setup status " "and offer the guided setup wizard if configuration is incomplete." ), + # streamable_http_path="/" so that when mounted at /mcp in FastAPI the + # endpoint resolves to /mcp (not /mcp/mcp which is the default). + streamable_http_path="/", + # host="0.0.0.0" disables the auto DNS-rebinding protection that FastMCP + # applies when host is 127.0.0.1/localhost. That protection blocks requests + # from Cloud Run (Host header is the public *.run.app domain) with 421. + host="0.0.0.0", ) diff --git a/tests/integration/test_mcp_streamable.py b/tests/integration/test_mcp_streamable.py new file mode 100644 index 0000000..0217f9c --- /dev/null +++ b/tests/integration/test_mcp_streamable.py @@ -0,0 +1,243 @@ +"""MCP Streamable HTTP Smoke Tests — /mcp endpoint. + +Tests the seller agent's primary MCP transport (Streamable HTTP at /mcp) +against a live running server. Separate from test_mcp_integration.py which +uses mocked backends. + +Usage: + # Start the seller server first: + # uvicorn ad_seller.interfaces.api.main:app --port 8000 + # + # Then run: + # pytest tests/integration/test_mcp_streamable.py -v + +Requires a running seller server on port 8000 (or set SELLER_MCP_HTTP_URL). + +Note: no @pytest.mark.asyncio decorators needed — pyproject.toml sets +asyncio_mode = "auto" which handles all async test functions automatically. +Adding the decorator alongside AUTO mode causes double collection. +""" + +import asyncio +import json +import os +from contextlib import asynccontextmanager + +import pytest + +# --------------------------------------------------------------------------- +# Optional MCP SDK imports +# --------------------------------------------------------------------------- +try: + from mcp.client.streamable_http import streamable_http_client + from mcp import ClientSession + MCP_HTTP_AVAILABLE = True +except ImportError: + try: + from mcp.client.streamable_http import streamablehttp_client as streamable_http_client # type: ignore[no-redef] + from mcp import ClientSession + MCP_HTTP_AVAILABLE = True + except ImportError: + MCP_HTTP_AVAILABLE = False + +MCP_HTTP_URL = os.environ.get("SELLER_MCP_HTTP_URL", "http://127.0.0.1:3000/mcp") +TOOL_TIMEOUT = float(os.environ.get("MCP_TOOL_TIMEOUT", "15")) + +pytestmark = [ + pytest.mark.integration, + pytest.mark.skipif(not MCP_HTTP_AVAILABLE, reason="mcp streamable_http client not available"), +] + + +# --------------------------------------------------------------------------- +# Session helper +# --------------------------------------------------------------------------- + +@asynccontextmanager +async def _mcp_session(): + """Open a fresh Streamable HTTP MCP session for one test.""" + try: + async with streamable_http_client(MCP_HTTP_URL) as (read, write, _): + async with ClientSession(read, write) as session: + await session.initialize() + yield session + except Exception as exc: + pytest.skip(f"Seller /mcp not reachable at {MCP_HTTP_URL}: {exc}") + + +async def _call(session: "ClientSession", name: str, args: dict | None = None): + """Call an MCP tool and return (is_error, data).""" + try: + result = await asyncio.wait_for( + session.call_tool(name, arguments=args or {}), + timeout=TOOL_TIMEOUT, + ) + except asyncio.TimeoutError: + pytest.fail(f"Tool '{name}' timed out after {TOOL_TIMEOUT}s on /mcp") + + content = result.content + if not content or not hasattr(content[0], "text"): + return False, {} + text = content[0].text + if text.startswith("Error executing tool"): + return True, {"raw_error": text} + try: + return False, json.loads(text) + except json.JSONDecodeError: + return False, {"raw_text": text} + + +# --------------------------------------------------------------------------- +# Connection +# --------------------------------------------------------------------------- + +async def test_streamable_http_connection(): + """/mcp must accept a session and initialize successfully.""" + async with _mcp_session() as session: + assert session is not None + + +async def test_streamable_http_tool_list(): + """/mcp must advertise all foundation tools.""" + async with _mcp_session() as session: + result = await asyncio.wait_for(session.list_tools(), timeout=TOOL_TIMEOUT) + tool_names = {t.name for t in result.tools} + for required in ("health_check", "get_setup_status", "get_config"): + assert required in tool_names, ( + f"Required tool '{required}' missing — got: {sorted(tool_names)}" + ) + + +# --------------------------------------------------------------------------- +# Foundation tools +# --------------------------------------------------------------------------- + +async def test_health_check(): + async with _mcp_session() as session: + err, data = await _call(session, "health_check") + assert not err, f"health_check error: {data}" + assert data.get("status") in ("healthy", "degraded") + assert "checks" in data + + +async def test_get_setup_status(): + async with _mcp_session() as session: + err, data = await _call(session, "get_setup_status") + assert not err, f"get_setup_status error: {data}" + assert "setup_complete" in data + assert "publisher_identity" in data + assert "ad_server" in data + + +async def test_get_config(): + async with _mcp_session() as session: + err, data = await _call(session, "get_config") + assert not err, f"get_config error: {data}" + assert "publisher" in data + assert "pricing" in data + assert "anthropic" not in str(data).lower(), "API key must not be exposed" + + +# --------------------------------------------------------------------------- +# Inventory & Products +# --------------------------------------------------------------------------- + +async def test_list_products(): + async with _mcp_session() as session: + err, data = await _call(session, "list_products") + assert not err, f"list_products error: {data}" + assert "products" in data + assert isinstance(data["products"], list) + + +async def test_list_packages(): + async with _mcp_session() as session: + err, data = await _call(session, "list_packages") + assert not err, f"list_packages error: {data}" + assert "packages" in data + assert isinstance(data["packages"], list) + + +async def test_get_rate_card(): + async with _mcp_session() as session: + err, data = await _call(session, "get_rate_card") + assert not err, f"get_rate_card error: {data}" + assert "entries" in data + assert isinstance(data["entries"], list) + + +async def test_get_sync_status(): + async with _mcp_session() as session: + err, data = await _call(session, "get_sync_status") + assert not err, f"get_sync_status error: {data}" + + +# --------------------------------------------------------------------------- +# Orders & Approvals +# --------------------------------------------------------------------------- + +async def test_list_orders(): + async with _mcp_session() as session: + err, data = await _call(session, "list_orders") + assert not err, f"list_orders error: {data}" + + +async def test_list_pending_approvals(): + async with _mcp_session() as session: + err, data = await _call(session, "list_pending_approvals") + assert not err, f"list_pending_approvals error: {data}" + + +async def test_get_inbound_queue(): + async with _mcp_session() as session: + err, data = await _call(session, "get_inbound_queue") + assert not err, f"get_inbound_queue error: {data}" + assert "items" in data + assert "count" in data + + +# --------------------------------------------------------------------------- +# Buyer agents & SSPs +# --------------------------------------------------------------------------- + +async def test_list_buyer_agents(): + async with _mcp_session() as session: + err, data = await _call(session, "list_buyer_agents") + assert not err, f"list_buyer_agents error: {data}" + + +async def test_list_ssps(): + async with _mcp_session() as session: + err, data = await _call(session, "list_ssps") + assert not err, f"list_ssps error: {data}" + assert "connectors" in data + + +async def test_list_agents(): + async with _mcp_session() as session: + err, data = await _call(session, "list_agents") + assert not err, f"list_agents error: {data}" + assert "hierarchy" in data + + +# --------------------------------------------------------------------------- +# API keys +# --------------------------------------------------------------------------- + +async def test_api_key_lifecycle(): + """Full create → list → revoke lifecycle over /mcp.""" + async with _mcp_session() as session: + err, created = await _call(session, "create_api_key", { + "name": "smoke-test-key", + "label": "mcp-streamable-smoke", + }) + assert not err, f"create_api_key failed: {created}" + key_id = created.get("key_id") + assert key_id, "Response must include key_id" + + err, listed = await _call(session, "list_api_keys") + assert not err + assert any(k.get("key_id") == key_id for k in listed.get("keys", [])) + + err, revoked = await _call(session, "revoke_api_key", {"key_id": key_id}) + assert not err, f"revoke_api_key failed: {revoked}"