diff --git a/.github/workflows/python-sdk.yml b/.github/workflows/python-sdk.yml
new file mode 100644
index 00000000..6310f0a8
--- /dev/null
+++ b/.github/workflows/python-sdk.yml
@@ -0,0 +1,120 @@
+name: Python SDK
+
+on:
+  push:
+    # Explicit ref filters: branch pushes remain path-filtered below, while
+    # only SDK release tags (python-sdk-v*) trigger this workflow — GitHub
+    # does not evaluate `paths` for tag pushes, so without a `tags` filter
+    # EVERY tag push in the repo would run this pipeline.
+    branches:
+      - '**'
+    tags:
+      - 'python-sdk-v*'
+    paths:
+      - 'sdk/python/**'
+      - '.github/workflows/python-sdk.yml'
+  pull_request:
+    paths:
+      - 'sdk/python/**'
+
+defaults:
+  run:
+    working-directory: sdk/python
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install dependencies
+        run: |
+          pip install -e ".[dev]"
+
+      - name: Run Ruff linter
+        run: ruff check .
+
+      - name: Run Ruff formatter check
+        run: ruff format --check .
+
+      - name: Run MyPy
+        run: mypy axonflow
+
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ['3.9', '3.10', '3.11', '3.12']
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: pip install -e ".[dev,all]"
+
+      - name: Run tests
+        run: pytest --cov-report=xml --cov-fail-under=90
+
+      - name: Upload coverage
+        if: matrix.python-version == '3.11'
+        uses: codecov/codecov-action@v3
+        with:
+          files: ./sdk/python/coverage.xml
+          flags: python-sdk
+
+  build:
+    runs-on: ubuntu-latest
+    needs: [lint, test]
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install build tools
+        run: pip install build
+
+      - name: Build package
+        run: python -m build
+
+      - name: Check package
+        run: |
+          pip install twine
+          twine check dist/*
+
+      - name: Upload artifacts
+        uses: actions/upload-artifact@v4
+        with:
+          name: python-sdk-dist
+          path: sdk/python/dist/
+
+  publish:
+    needs: [build]
+    if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/python-sdk-v')
runs-on: ubuntu-latest + environment: pypi + permissions: + id-token: write + steps: + - uses: actions/checkout@v4 + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + name: python-sdk-dist + path: dist/ + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: dist/ diff --git a/sdk/python/CHANGELOG.md b/sdk/python/CHANGELOG.md new file mode 100644 index 00000000..8792c6d3 --- /dev/null +++ b/sdk/python/CHANGELOG.md @@ -0,0 +1,34 @@ +# Changelog + +All notable changes to the AxonFlow Python SDK will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [0.1.0] - 2025-12-04 + +### Added + +- Initial release of AxonFlow Python SDK +- Async-first client with sync wrappers +- Full type hints with Pydantic v2 models +- Gateway Mode support for lowest-latency LLM calls + - `get_policy_approved_context()` for pre-checks + - `audit_llm_call()` for compliance logging +- OpenAI interceptor for transparent governance +- Anthropic interceptor for transparent governance +- MCP connector operations + - `list_connectors()` + - `install_connector()` + - `query_connector()` +- Multi-agent planning + - `generate_plan()` + - `execute_plan()` + - `get_plan_status()` +- Comprehensive exception hierarchy +- Response caching with TTL +- Retry logic with exponential backoff +- Structured logging with structlog +- 95%+ test coverage +- mypy strict mode compatible +- ruff linting compatible diff --git a/sdk/python/README.md b/sdk/python/README.md new file mode 100644 index 00000000..6c4017a5 --- /dev/null +++ b/sdk/python/README.md @@ -0,0 +1,257 @@ +# AxonFlow Python SDK + +Enterprise AI Governance in 3 Lines of Code. 
+ +[![PyPI version](https://badge.fury.io/py/axonflow.svg)](https://badge.fury.io/py/axonflow) +[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/) +[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +[![Type hints](https://img.shields.io/badge/type%20hints-mypy-brightgreen.svg)](http://mypy-lang.org/) + +## Installation + +```bash +pip install axonflow +``` + +With LLM provider support: +```bash +pip install axonflow[openai] # OpenAI integration +pip install axonflow[anthropic] # Anthropic integration +pip install axonflow[all] # All integrations +``` + +## Quick Start + +### Async Usage (Recommended) + +```python +import asyncio +from axonflow import AxonFlow + +async def main(): + async with AxonFlow( + agent_url="https://your-agent.axonflow.com", + client_id="your-client-id", + client_secret="your-client-secret" + ) as client: + # Execute a governed query + response = await client.execute_query( + user_token="user-jwt-token", + query="What is AI governance?", + request_type="chat" + ) + print(response.data) + +asyncio.run(main()) +``` + +### Sync Usage + +```python +from axonflow import AxonFlow + +with AxonFlow.sync( + agent_url="https://your-agent.axonflow.com", + client_id="your-client-id", + client_secret="your-client-secret" +) as client: + response = client.execute_query( + user_token="user-jwt-token", + query="What is AI governance?", + request_type="chat" + ) + print(response.data) +``` + +## Features + +### Gateway Mode + +For lowest-latency LLM calls with full governance and audit compliance: + +```python +from axonflow import AxonFlow, TokenUsage + +async with AxonFlow(...) as client: + # 1. 
Pre-check: Get policy approval + ctx = await client.get_policy_approved_context( + user_token="user-jwt", + query="Find patient records", + data_sources=["postgres"] + ) + + if not ctx.approved: + raise Exception(f"Blocked: {ctx.block_reason}") + + # 2. Make LLM call directly (your code) + llm_response = await openai.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": str(ctx.approved_data)}] + ) + + # 3. Audit the call + await client.audit_llm_call( + context_id=ctx.context_id, + response_summary=llm_response.choices[0].message.content[:100], + provider="openai", + model="gpt-4", + token_usage=TokenUsage( + prompt_tokens=llm_response.usage.prompt_tokens, + completion_tokens=llm_response.usage.completion_tokens, + total_tokens=llm_response.usage.total_tokens + ), + latency_ms=250 + ) +``` + +### OpenAI Integration + +Transparent governance for existing OpenAI code: + +```python +from openai import OpenAI +from axonflow import AxonFlow +from axonflow.interceptors.openai import wrap_openai_client + +openai = OpenAI() +axonflow = AxonFlow(...) 
+ +# Wrap client - governance is now automatic +wrapped = wrap_openai_client(openai, axonflow, user_token="user-123") + +# Use as normal +response = wrapped.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello!"}] +) +``` + +### MCP Connectors + +Query data through MCP connectors: + +```python +# List available connectors +connectors = await client.list_connectors() + +# Query a connector +result = await client.query_connector( + user_token="user-jwt", + connector_name="postgres", + operation="query", + params={"sql": "SELECT * FROM users LIMIT 10"} +) +``` + +### Multi-Agent Planning + +Generate and execute multi-agent plans: + +```python +# Generate a plan +plan = await client.generate_plan( + query="Book a flight and hotel for my trip to Paris", + domain="travel" +) + +print(f"Plan has {len(plan.steps)} steps") + +# Execute the plan +result = await client.execute_plan(plan.plan_id) +print(f"Result: {result.result}") +``` + +## Configuration + +```python +from axonflow import AxonFlow, Mode, RetryConfig + +client = AxonFlow( + agent_url="https://your-agent.axonflow.com", + client_id="your-client-id", + client_secret="your-client-secret", + license_key="optional-license-key", # For enterprise features + mode=Mode.PRODUCTION, # or Mode.SANDBOX + debug=True, # Enable debug logging + timeout=60.0, # Request timeout in seconds + retry_config=RetryConfig( # Retry configuration + enabled=True, + max_attempts=3, + initial_delay=1.0, + max_delay=30.0, + ), + cache_enabled=True, # Enable response caching + cache_ttl=60.0, # Cache TTL in seconds +) +``` + +## Error Handling + +```python +from axonflow.exceptions import ( + AxonFlowError, + PolicyViolationError, + AuthenticationError, + RateLimitError, + TimeoutError, +) + +try: + response = await client.execute_query(...) 
+except PolicyViolationError as e: + print(f"Blocked by policy: {e.block_reason}") +except RateLimitError as e: + print(f"Rate limited: {e.limit}/{e.remaining}, resets at {e.reset_at}") +except AuthenticationError: + print("Invalid credentials") +except TimeoutError: + print("Request timed out") +except AxonFlowError as e: + print(f"AxonFlow error: {e.message}") +``` + +## Response Types + +All responses are Pydantic models with full type hints: + +```python +from axonflow import ( + ClientResponse, + PolicyApprovalResult, + PlanResponse, + ConnectorResponse, +) + +# Full autocomplete and type checking support +response: ClientResponse = await client.execute_query(...) +print(response.success) +print(response.data) +print(response.policy_info.policies_evaluated) +``` + +## Development + +```bash +# Install dev dependencies +pip install -e ".[dev]" + +# Run tests +pytest + +# Run linting +ruff check . +ruff format . + +# Run type checking +mypy axonflow +``` + +## Documentation + +- [API Reference](https://docs.getaxonflow.com/sdk/python/api) +- [Gateway Mode Guide](https://docs.getaxonflow.com/sdk/python/gateway-mode) +- [Examples](https://github.com/getaxonflow/axonflow/tree/main/sdk/python/examples) + +## License + +Apache 2.0 - See [LICENSE](LICENSE) for details. diff --git a/sdk/python/axonflow/__init__.py b/sdk/python/axonflow/__init__.py new file mode 100644 index 00000000..6ed1972f --- /dev/null +++ b/sdk/python/axonflow/__init__.py @@ -0,0 +1,95 @@ +"""AxonFlow Python SDK - Enterprise AI Governance in 3 Lines of Code. + +This SDK provides a simple, async-first interface for integrating AI governance +into your Python applications. It supports policy enforcement, audit logging, +MCP connectors, and multi-agent planning. + +Example: + >>> from axonflow import AxonFlow + >>> + >>> # Async usage + >>> async with AxonFlow( + ... agent_url="https://your-agent.axonflow.com", + ... client_id="your-client-id", + ... client_secret="your-client-secret" + ... 
) as client: + ... result = await client.execute_query("user-token", "What is AI?", "chat") + ... print(result.data) + >>> + >>> # Sync usage + >>> client = AxonFlow.sync( + ... agent_url="https://your-agent.axonflow.com", + ... client_id="your-client-id", + ... client_secret="your-client-secret" + ... ) + >>> result = client.execute_query("user-token", "What is AI?", "chat") +""" + +from axonflow.client import AxonFlow, SyncAxonFlow +from axonflow.exceptions import ( + AuthenticationError, + AxonFlowError, + ConfigurationError, + ConnectionError, + ConnectorError, + PlanExecutionError, + PolicyViolationError, + RateLimitError, + TimeoutError, +) +from axonflow.types import ( + AuditResult, + CacheConfig, + ClientRequest, + ClientResponse, + ConnectorInstallRequest, + ConnectorMetadata, + ConnectorResponse, + Mode, + PlanExecutionResponse, + PlanResponse, + PlanStep, + PolicyApprovalResult, + PolicyEvaluationInfo, + RateLimitInfo, + RetryConfig, + TokenUsage, +) + +__version__ = "0.1.0" +__all__ = [ + # Main client + "AxonFlow", + "SyncAxonFlow", + # Configuration + "Mode", + "RetryConfig", + "CacheConfig", + # Request/Response types + "ClientRequest", + "ClientResponse", + "PolicyEvaluationInfo", + # Connector types + "ConnectorMetadata", + "ConnectorInstallRequest", + "ConnectorResponse", + # Planning types + "PlanStep", + "PlanResponse", + "PlanExecutionResponse", + # Gateway Mode types + "RateLimitInfo", + "PolicyApprovalResult", + "TokenUsage", + "AuditResult", + # Exceptions + "AxonFlowError", + "ConfigurationError", + "AuthenticationError", + "PolicyViolationError", + "RateLimitError", + "ConnectionError", + "TimeoutError", + "ConnectorError", + "PlanExecutionError", +] diff --git a/sdk/python/axonflow/client.py b/sdk/python/axonflow/client.py new file mode 100644 index 00000000..aa16af0b --- /dev/null +++ b/sdk/python/axonflow/client.py @@ -0,0 +1,813 @@ +"""AxonFlow SDK Main Client. 
+ +The primary interface for interacting with AxonFlow governance platform. +Supports both async and sync usage patterns. + +Example: + >>> from axonflow import AxonFlow + >>> + >>> # Async usage + >>> async with AxonFlow(agent_url="...", client_id="...", client_secret="...") as client: + ... result = await client.execute_query("user-token", "What is AI?", "chat") + ... print(result.data) + >>> + >>> # Sync usage + >>> client = AxonFlow.sync(agent_url="...", client_id="...", client_secret="...") + >>> result = client.execute_query("user-token", "What is AI?", "chat") +""" + +from __future__ import annotations + +import asyncio +import hashlib +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import httpx +import structlog +from cachetools import TTLCache +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from axonflow.exceptions import ( + AuthenticationError, + AxonFlowError, + ConnectionError, + PolicyViolationError, + TimeoutError, +) +from axonflow.types import ( + AuditResult, + AxonFlowConfig, + CacheConfig, + ClientRequest, + ClientResponse, + ConnectorInstallRequest, + ConnectorMetadata, + ConnectorResponse, + Mode, + PlanExecutionResponse, + PlanResponse, + PlanStep, + PolicyApprovalResult, + RateLimitInfo, + RetryConfig, + TokenUsage, +) + +if TYPE_CHECKING: + from types import TracebackType + +logger = structlog.get_logger(__name__) + + +class AxonFlow: + """Main AxonFlow client for AI governance. + + This client provides async-first API for interacting with AxonFlow Agent. + All methods are async by default, with sync wrappers available via `.sync()`. 
+ + Attributes: + config: Client configuration + """ + + __slots__ = ("_config", "_http_client", "_cache", "_logger") + + def __init__( + self, + agent_url: str, + client_id: str, + client_secret: str, + *, + license_key: str | None = None, + mode: Mode | str = Mode.PRODUCTION, + debug: bool = False, + timeout: float = 60.0, + insecure_skip_verify: bool = False, + retry_config: RetryConfig | None = None, + cache_enabled: bool = True, + cache_ttl: float = 60.0, + cache_max_size: int = 1000, + ) -> None: + """Initialize AxonFlow client. + + Args: + agent_url: AxonFlow Agent URL + client_id: Client ID for authentication + client_secret: Client secret for authentication + license_key: Optional license key for organization-level auth + mode: Operation mode (production or sandbox) + debug: Enable debug logging + timeout: Request timeout in seconds + insecure_skip_verify: Skip TLS verification (dev only) + retry_config: Retry configuration + cache_enabled: Enable response caching + cache_ttl: Cache TTL in seconds + cache_max_size: Maximum cache entries + """ + if isinstance(mode, str): + mode = Mode(mode) + + self._config = AxonFlowConfig( + agent_url=agent_url.rstrip("/"), + client_id=client_id, + client_secret=client_secret, + license_key=license_key, + mode=mode, + debug=debug, + timeout=timeout, + insecure_skip_verify=insecure_skip_verify, + retry=retry_config or RetryConfig(), + cache=CacheConfig(enabled=cache_enabled, ttl=cache_ttl, max_size=cache_max_size), + ) + + # Configure SSL verification + verify_ssl: bool = not insecure_skip_verify + + # Build headers + headers: dict[str, str] = { + "Content-Type": "application/json", + "X-Client-Secret": client_secret, + } + if license_key: + headers["X-License-Key"] = license_key + + # Initialize HTTP client + self._http_client = httpx.AsyncClient( + timeout=httpx.Timeout(timeout), + verify=verify_ssl, + headers=headers, + ) + + # Initialize cache + self._cache: TTLCache[str, ClientResponse] | None = None + if 
cache_enabled: + self._cache = TTLCache(maxsize=cache_max_size, ttl=cache_ttl) + + # Initialize logger + self._logger = structlog.get_logger(__name__).bind( + client_id=client_id, + mode=mode.value, + ) + + if debug: + self._logger.info( + "AxonFlow client initialized", + agent_url=agent_url, + ) + + @property + def config(self) -> AxonFlowConfig: + """Get client configuration.""" + return self._config + + async def __aenter__(self) -> AxonFlow: + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Async context manager exit.""" + await self.close() + + async def close(self) -> None: + """Close the HTTP client.""" + await self._http_client.aclose() + + @classmethod + def sync( + cls, + agent_url: str, + client_id: str, + client_secret: str, + **kwargs: Any, + ) -> SyncAxonFlow: + """Create a synchronous client wrapper. + + Example: + >>> client = AxonFlow.sync(agent_url="...", client_id="...", client_secret="...") + >>> result = client.execute_query("token", "query", "chat") + """ + return SyncAxonFlow(cls(agent_url, client_id, client_secret, **kwargs)) + + @classmethod + def sandbox(cls, api_key: str = "demo-key") -> AxonFlow: + """Create a sandbox client for testing. 
+ + Args: + api_key: Optional API key (defaults to demo-key) + + Returns: + Configured AxonFlow client for sandbox environment + """ + return cls( + agent_url="https://staging-eu.getaxonflow.com", + client_id=api_key, + client_secret=api_key, + mode=Mode.SANDBOX, + debug=True, + ) + + def _get_cache_key(self, request_type: str, query: str, user_token: str) -> str: + """Generate cache key for a request.""" + key = f"{request_type}:{query}:{user_token}" + return hashlib.sha256(key.encode()).hexdigest()[:32] + + async def _request( + self, + method: str, + path: str, + *, + json_data: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Make HTTP request to Agent.""" + url = f"{self._config.agent_url}{path}" + + try: + if self._config.retry.enabled: + response = await self._request_with_retry(method, url, json_data) + else: + response = await self._http_client.request(method, url, json=json_data) + + response.raise_for_status() + return response.json() # type: ignore[no-any-return] + + except httpx.ConnectError as e: + msg = f"Failed to connect to AxonFlow Agent: {e}" + raise ConnectionError(msg) from e + except httpx.TimeoutException as e: + msg = f"Request timed out: {e}" + raise TimeoutError(msg) from e + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: # noqa: PLR2004 + msg = "Invalid credentials" + raise AuthenticationError(msg) from e + if e.response.status_code == 403: # noqa: PLR2004 + body = e.response.json() + raise PolicyViolationError( + body.get("message", "Request blocked by policy"), + policy=body.get("policy"), + block_reason=body.get("block_reason"), + ) from e + msg = f"HTTP {e.response.status_code}: {e.response.text}" + raise AxonFlowError(msg) from e + + async def _request_with_retry( + self, + method: str, + url: str, + json_data: dict[str, Any] | None, + ) -> httpx.Response: + """Make request with retry logic.""" + + @retry( + stop=stop_after_attempt(self._config.retry.max_attempts), + wait=wait_exponential( + 
multiplier=self._config.retry.initial_delay, + max=self._config.retry.max_delay, + exp_base=self._config.retry.exponential_base, + ), + retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)), + reraise=True, + ) + async def _do_request() -> httpx.Response: + return await self._http_client.request(method, url, json=json_data) + + return await _do_request() + + async def health_check(self) -> bool: + """Check if AxonFlow Agent is healthy. + + Returns: + True if agent is healthy, False otherwise + """ + try: + response = await self._request("GET", "/health") + return response.get("status") == "healthy" + except AxonFlowError: + return False + + async def execute_query( + self, + user_token: str, + query: str, + request_type: str, + context: dict[str, Any] | None = None, + ) -> ClientResponse: + """Execute a query through AxonFlow with policy enforcement. + + Args: + user_token: User authentication token + query: The query or prompt + request_type: Type of request (chat, sql, mcp-query, multi-agent-plan) + context: Optional additional context + + Returns: + ClientResponse with results or error + + Raises: + PolicyViolationError: If request is blocked by policy + AuthenticationError: If credentials are invalid + TimeoutError: If request times out + """ + # Check cache + if self._cache is not None: + cache_key = self._get_cache_key(request_type, query, user_token) + if cache_key in self._cache: + if self._config.debug: + self._logger.debug("Cache hit", query=query[:50]) + return self._cache[cache_key] + else: + cache_key = "" + + request = ClientRequest( + query=query, + user_token=user_token, + client_id=self._config.client_id, + request_type=request_type, + context=context or {}, + ) + + if self._config.debug: + self._logger.debug( + "Executing query", + request_type=request_type, + query=query[:50] if query else "", + ) + + response_data = await self._request( + "POST", + "/api/request", + json_data=request.model_dump(), + ) + + response = 
ClientResponse.model_validate(response_data) + + # Check for policy violation + if response.blocked: + raise PolicyViolationError( + response.block_reason or "Request blocked by policy", + block_reason=response.block_reason, + ) + + # Cache successful responses + if self._cache is not None and response.success and cache_key: + self._cache[cache_key] = response + + return response + + async def list_connectors(self) -> list[ConnectorMetadata]: + """List all available MCP connectors. + + Returns: + List of connector metadata + """ + response = await self._request("GET", "/api/connectors") + return [ConnectorMetadata.model_validate(c) for c in response] + + async def install_connector(self, request: ConnectorInstallRequest) -> None: + """Install an MCP connector. + + Args: + request: Connector installation request + """ + await self._request( + "POST", + "/api/connectors/install", + json_data=request.model_dump(), + ) + + if self._config.debug: + self._logger.info("Connector installed", name=request.name) + + async def query_connector( + self, + user_token: str, + connector_name: str, + operation: str, + params: dict[str, Any] | None = None, + ) -> ConnectorResponse: + """Query an MCP connector directly. 
+ + Args: + user_token: User authentication token + connector_name: Name of the connector + operation: Operation to perform + params: Operation parameters + + Returns: + ConnectorResponse with results + """ + request_data: dict[str, Any] = { + "client_id": self._config.client_id, + "user_token": user_token, + "connector": connector_name, + "operation": operation, + "parameters": params or {}, + } + + if self._config.license_key: + request_data["license_key"] = self._config.license_key + + response = await self._request( + "POST", + "/mcp/resources/query", + json_data=request_data, + ) + + return ConnectorResponse.model_validate(response) + + async def generate_plan( + self, + query: str, + domain: str | None = None, + ) -> PlanResponse: + """Generate a multi-agent execution plan. + + Args: + query: Natural language query describing the task + domain: Optional domain hint (travel, healthcare, etc.) + + Returns: + PlanResponse with generated plan + """ + context = {"domain": domain} if domain else {} + + response = await self.execute_query( + user_token="", + query=query, + request_type="multi-agent-plan", + context=context, + ) + + if not response.success: + msg = f"Plan generation failed: {response.error}" + raise AxonFlowError(msg) + + # Extract steps from response data + steps: list[PlanStep] = [] + if response.data and isinstance(response.data, dict): + steps_data = response.data.get("steps", []) + steps = [PlanStep.model_validate(s) for s in steps_data] + + return PlanResponse( + plan_id=response.plan_id or "", + steps=steps, + domain=response.data.get("domain", domain or "generic") + if response.data and isinstance(response.data, dict) + else (domain or "generic"), + complexity=response.data.get("complexity", 0) + if response.data and isinstance(response.data, dict) + else 0, + parallel=response.data.get("parallel", False) + if response.data and isinstance(response.data, dict) + else False, + metadata=response.metadata, + ) + + async def execute_plan(self, 
plan_id: str) -> PlanExecutionResponse: + """Execute a previously generated plan. + + Args: + plan_id: ID of the plan to execute + + Returns: + PlanExecutionResponse with results + """ + response = await self.execute_query( + user_token="", + query="", + request_type="execute-plan", + context={"plan_id": plan_id}, + ) + + return PlanExecutionResponse( + plan_id=plan_id, + status="completed" if response.success else "failed", + result=response.result, + step_results=response.metadata.get("step_results", {}), + error=response.error, + duration=response.metadata.get("duration"), + ) + + async def get_plan_status(self, plan_id: str) -> PlanExecutionResponse: + """Get status of a running or completed plan. + + Args: + plan_id: ID of the plan + + Returns: + PlanExecutionResponse with current status + """ + response = await self._request("GET", f"/api/plans/{plan_id}") + return PlanExecutionResponse.model_validate(response) + + # ========================================================================= + # Gateway Mode Methods + # ========================================================================= + + async def get_policy_approved_context( + self, + user_token: str, + query: str, + data_sources: list[str] | None = None, + context: dict[str, Any] | None = None, + ) -> PolicyApprovalResult: + """Perform policy pre-check before making LLM call. + + This is the first step in Gateway Mode. Call this before making your + LLM call to ensure policy compliance. 
+ + Args: + user_token: JWT token for the user making the request + query: The query/prompt that will be sent to the LLM + data_sources: Optional list of MCP connectors to fetch data from + context: Optional additional context for policy evaluation + + Returns: + PolicyApprovalResult with context ID and approved data + + Raises: + AuthenticationError: If user token is invalid + ConnectionError: If unable to reach AxonFlow Agent + TimeoutError: If request times out + + Example: + >>> result = await client.get_policy_approved_context( + ... user_token="user-jwt", + ... query="Find patients with diabetes", + ... data_sources=["postgres"] + ... ) + >>> if not result.approved: + ... raise PolicyViolationError(result.block_reason) + """ + request_body = { + "user_token": user_token, + "client_id": self._config.client_id, + "query": query, + "data_sources": data_sources or [], + "context": context or {}, + } + + if self._config.debug: + self._logger.debug( + "Gateway pre-check request", + query=query[:50] if query else "", + data_sources=data_sources, + ) + + response = await self._request( + "POST", + "/api/policy/pre-check", + json_data=request_body, + ) + + if self._config.debug: + self._logger.debug( + "Gateway pre-check complete", + context_id=response.get("context_id"), + approved=response.get("approved"), + ) + + rate_limit = None + if response.get("rate_limit"): + rate_limit = RateLimitInfo( + limit=response["rate_limit"]["limit"], + remaining=response["rate_limit"]["remaining"], + reset_at=datetime.fromisoformat(response["rate_limit"]["reset_at"]), + ) + + return PolicyApprovalResult( + context_id=response["context_id"], + approved=response["approved"], + approved_data=response.get("approved_data", {}), + policies=response.get("policies", []), + rate_limit_info=rate_limit, + expires_at=datetime.fromisoformat(response["expires_at"]), + block_reason=response.get("block_reason"), + ) + + async def audit_llm_call( + self, + context_id: str, + response_summary: str, + 
provider: str, + model: str, + token_usage: TokenUsage, + latency_ms: int, + metadata: dict[str, Any] | None = None, + ) -> AuditResult: + """Report LLM call details for audit logging. + + This is the second step in Gateway Mode. Call this after making your + LLM call to record it in the audit trail. + + Args: + context_id: Context ID from get_policy_approved_context() + response_summary: Brief summary of the LLM response (not full response) + provider: LLM provider name (openai, anthropic, bedrock, ollama) + model: Model name (gpt-4, claude-3-sonnet, etc.) + token_usage: Token counts from the LLM response + latency_ms: Time taken for the LLM call in milliseconds + metadata: Optional additional metadata to log + + Returns: + AuditResult confirming the audit was recorded + + Raises: + AxonFlowError: If audit recording fails + + Example: + >>> result = await client.audit_llm_call( + ... context_id=ctx.context_id, + ... response_summary="Found 5 patients with recent lab results", + ... provider="openai", + ... model="gpt-4", + ... token_usage=TokenUsage( + ... prompt_tokens=100, + ... completion_tokens=50, + ... total_tokens=150 + ... ), + ... latency_ms=250 + ... 
) + """ + request_body = { + "context_id": context_id, + "client_id": self._config.client_id, + "response_summary": response_summary, + "provider": provider, + "model": model, + "token_usage": { + "prompt_tokens": token_usage.prompt_tokens, + "completion_tokens": token_usage.completion_tokens, + "total_tokens": token_usage.total_tokens, + }, + "latency_ms": latency_ms, + "metadata": metadata or {}, + } + + if self._config.debug: + self._logger.debug( + "Gateway audit request", + context_id=context_id, + provider=provider, + model=model, + tokens=token_usage.total_tokens, + ) + + response = await self._request( + "POST", + "/api/audit/llm-call", + json_data=request_body, + ) + + if self._config.debug: + self._logger.debug( + "Gateway audit complete", + audit_id=response.get("audit_id"), + ) + + return AuditResult( + success=response["success"], + audit_id=response["audit_id"], + ) + + +class SyncAxonFlow: + """Synchronous wrapper for AxonFlow client. + + Wraps all async methods for synchronous usage. 
+ """ + + __slots__ = ("_async_client", "_loop") + + def __init__(self, async_client: AxonFlow) -> None: + self._async_client = async_client + self._loop: asyncio.AbstractEventLoop | None = None + + def _get_loop(self) -> asyncio.AbstractEventLoop: + """Get or create event loop.""" + if self._loop is None or self._loop.is_closed(): + try: + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + return self._loop + + def __enter__(self) -> SyncAxonFlow: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + """Close the client.""" + self._get_loop().run_until_complete(self._async_client.close()) + + @property + def config(self) -> AxonFlowConfig: + """Get client configuration.""" + return self._async_client.config + + def health_check(self) -> bool: + """Check if AxonFlow Agent is healthy.""" + return self._get_loop().run_until_complete(self._async_client.health_check()) + + def execute_query( + self, + user_token: str, + query: str, + request_type: str, + context: dict[str, Any] | None = None, + ) -> ClientResponse: + """Execute a query through AxonFlow.""" + return self._get_loop().run_until_complete( + self._async_client.execute_query(user_token, query, request_type, context) + ) + + def list_connectors(self) -> list[ConnectorMetadata]: + """List all available MCP connectors.""" + return self._get_loop().run_until_complete(self._async_client.list_connectors()) + + def install_connector(self, request: ConnectorInstallRequest) -> None: + """Install an MCP connector.""" + return self._get_loop().run_until_complete( + self._async_client.install_connector(request) + ) + + def query_connector( + self, + user_token: str, + connector_name: str, + operation: str, + params: dict[str, Any] | None = None, + ) -> ConnectorResponse: + """Query an 
MCP connector directly.""" + return self._get_loop().run_until_complete( + self._async_client.query_connector(user_token, connector_name, operation, params) + ) + + def generate_plan( + self, + query: str, + domain: str | None = None, + ) -> PlanResponse: + """Generate a multi-agent execution plan.""" + return self._get_loop().run_until_complete( + self._async_client.generate_plan(query, domain) + ) + + def execute_plan(self, plan_id: str) -> PlanExecutionResponse: + """Execute a previously generated plan.""" + return self._get_loop().run_until_complete(self._async_client.execute_plan(plan_id)) + + def get_plan_status(self, plan_id: str) -> PlanExecutionResponse: + """Get status of a running or completed plan.""" + return self._get_loop().run_until_complete(self._async_client.get_plan_status(plan_id)) + + # Gateway Mode sync wrappers + + def get_policy_approved_context( + self, + user_token: str, + query: str, + data_sources: list[str] | None = None, + context: dict[str, Any] | None = None, + ) -> PolicyApprovalResult: + """Perform policy pre-check before making LLM call.""" + return self._get_loop().run_until_complete( + self._async_client.get_policy_approved_context( + user_token, query, data_sources, context + ) + ) + + def audit_llm_call( + self, + context_id: str, + response_summary: str, + provider: str, + model: str, + token_usage: TokenUsage, + latency_ms: int, + metadata: dict[str, Any] | None = None, + ) -> AuditResult: + """Report LLM call details for audit logging.""" + return self._get_loop().run_until_complete( + self._async_client.audit_llm_call( + context_id, response_summary, provider, model, token_usage, latency_ms, metadata + ) + ) diff --git a/sdk/python/axonflow/exceptions.py b/sdk/python/axonflow/exceptions.py new file mode 100644 index 00000000..7f166673 --- /dev/null +++ b/sdk/python/axonflow/exceptions.py @@ -0,0 +1,103 @@ +"""AxonFlow SDK Exceptions. + +Custom exception hierarchy for clear error handling. 
+""" + +from __future__ import annotations + +from typing import Any + + +class AxonFlowError(Exception): + """Base exception for all AxonFlow errors.""" + + def __init__(self, message: str, details: dict[str, Any] | None = None) -> None: + self.message = message + self.details = details or {} + super().__init__(message) + + +class ConfigurationError(AxonFlowError): + """Invalid configuration.""" + + +class AuthenticationError(AxonFlowError): + """Authentication failed.""" + + +class PolicyViolationError(AxonFlowError): + """Request blocked by policy.""" + + def __init__( + self, + message: str, + policy: str | None = None, + block_reason: str | None = None, + ) -> None: + super().__init__( + message, + details={"policy": policy, "block_reason": block_reason}, + ) + self.policy = policy + self.block_reason = block_reason + + +class RateLimitError(AxonFlowError): + """Rate limit exceeded.""" + + def __init__( + self, + message: str, + limit: int, + remaining: int, + reset_at: str | None = None, + ) -> None: + super().__init__( + message, + details={"limit": limit, "remaining": remaining, "reset_at": reset_at}, + ) + self.limit = limit + self.remaining = remaining + self.reset_at = reset_at + + +class ConnectionError(AxonFlowError): + """Connection to AxonFlow Agent failed.""" + + +class TimeoutError(AxonFlowError): + """Request timed out.""" + + +class ConnectorError(AxonFlowError): + """MCP connector error.""" + + def __init__( + self, + message: str, + connector: str | None = None, + operation: str | None = None, + ) -> None: + super().__init__( + message, + details={"connector": connector, "operation": operation}, + ) + self.connector = connector + self.operation = operation + + +class PlanExecutionError(AxonFlowError): + """Multi-agent plan execution failed.""" + + def __init__( + self, + message: str, + plan_id: str | None = None, + step: str | None = None, + ) -> None: + super().__init__( + message, + details={"plan_id": plan_id, "step": step}, + ) + 
self.plan_id = plan_id + self.step = step diff --git a/sdk/python/axonflow/interceptors/__init__.py b/sdk/python/axonflow/interceptors/__init__.py new file mode 100644 index 00000000..9a596523 --- /dev/null +++ b/sdk/python/axonflow/interceptors/__init__.py @@ -0,0 +1,14 @@ +"""AxonFlow LLM Provider Interceptors. + +Interceptors allow transparent governance integration with popular LLM providers. +""" + +from axonflow.interceptors.anthropic import wrap_anthropic_client +from axonflow.interceptors.base import BaseInterceptor +from axonflow.interceptors.openai import wrap_openai_client + +__all__ = [ + "BaseInterceptor", + "wrap_openai_client", + "wrap_anthropic_client", +] diff --git a/sdk/python/axonflow/interceptors/anthropic.py b/sdk/python/axonflow/interceptors/anthropic.py new file mode 100644 index 00000000..e12a27c6 --- /dev/null +++ b/sdk/python/axonflow/interceptors/anthropic.py @@ -0,0 +1,190 @@ +"""Anthropic Interceptor for transparent governance. + +Wraps Anthropic client to automatically apply AxonFlow governance +without changing application code. + +Example: + >>> from anthropic import Anthropic + >>> from axonflow import AxonFlow + >>> from axonflow.interceptors.anthropic import wrap_anthropic_client + >>> + >>> anthropic = Anthropic() + >>> axonflow = AxonFlow(...) + >>> + >>> # Wrap the client - governance is now automatic + >>> wrapped = wrap_anthropic_client(anthropic, axonflow) + >>> + >>> # Use as normal - governance happens invisibly + >>> response = wrapped.messages.create( + ... model="claude-3-sonnet-20240229", + ... max_tokens=1024, + ... messages=[{"role": "user", "content": "Hello!"}] + ... 
) +""" + +from __future__ import annotations + +import asyncio +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +from axonflow.exceptions import PolicyViolationError +from axonflow.interceptors.base import BaseInterceptor + +if TYPE_CHECKING: + from axonflow.client import AxonFlow + +T = TypeVar("T") + + +class AnthropicInterceptor(BaseInterceptor): + """Interceptor for Anthropic client.""" + + def get_provider_name(self) -> str: + """Get the provider name.""" + return "anthropic" + + def extract_prompt(self, *_args: Any, **kwargs: Any) -> str: + """Extract prompt from messages arguments.""" + messages = kwargs.get("messages", []) + parts = [] + for m in messages: + if isinstance(m, dict): + content = m.get("content", "") + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + # Handle content blocks + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + parts.append(block.get("text", "")) + return " ".join(parts) + + def wrap(self, client: Any) -> Any: + """Wrap Anthropic client with governance.""" + return wrap_anthropic_client(client, self.axonflow, user_token=self.user_token) + + +def wrap_anthropic_client( + anthropic_client: Any, + axonflow: AxonFlow, + *, + user_token: str = "", +) -> Any: + """Wrap Anthropic client with AxonFlow governance. 
+ + Args: + anthropic_client: Anthropic client to wrap + axonflow: AxonFlow client for governance + user_token: User token for policy evaluation + + Returns: + Wrapped Anthropic client with automatic governance + """ + original_create = anthropic_client.messages.create + + def _extract_prompt(kwargs: dict[str, Any]) -> str: + """Extract prompt from messages.""" + messages = kwargs.get("messages", []) + parts = [] + for m in messages: + if isinstance(m, dict): + content = m.get("content", "") + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + parts.append(block.get("text", "")) + return " ".join(parts) + + def _get_loop() -> asyncio.AbstractEventLoop: + """Get or create event loop.""" + try: + return asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + if asyncio.iscoroutinefunction(original_create): + + @wraps(original_create) + async def async_wrapped_create(*args: Any, **kwargs: Any) -> Any: + prompt = _extract_prompt(kwargs) + + # Check with AxonFlow + response = await axonflow.execute_query( + user_token=user_token, + query=prompt, + request_type="llm_chat", + context={ + "provider": "anthropic", + "model": kwargs.get("model", "claude-3-sonnet"), + "parameters": { + k: v + for k, v in kwargs.items() + if k not in ("messages", "model", "max_tokens") + }, + }, + ) + + if response.blocked: + raise PolicyViolationError( + response.block_reason or "Request blocked by policy" + ) + + # Call original + return await original_create(*args, **kwargs) + + anthropic_client.messages.create = async_wrapped_create + else: + + @wraps(original_create) + def sync_wrapped_create(*args: Any, **kwargs: Any) -> Any: + prompt = _extract_prompt(kwargs) + + # Check with AxonFlow (sync) + loop = _get_loop() + response = loop.run_until_complete( + axonflow.execute_query( + 
user_token=user_token, + query=prompt, + request_type="llm_chat", + context={ + "provider": "anthropic", + "model": kwargs.get("model", "claude-3-sonnet"), + }, + ) + ) + + if response.blocked: + raise PolicyViolationError( + response.block_reason or "Request blocked by policy" + ) + + # Call original + return original_create(*args, **kwargs) + + anthropic_client.messages.create = sync_wrapped_create + + return anthropic_client + + +def create_anthropic_wrapper( + axonflow: AxonFlow, user_token: str = "" +) -> Callable[[Any], Any]: + """Create a wrapper function for Anthropic client. + + Args: + axonflow: AxonFlow client for governance + user_token: User token for policy evaluation + + Returns: + Wrapper function that takes an Anthropic client + """ + + def wrapper(anthropic_client: Any) -> Any: + return wrap_anthropic_client(anthropic_client, axonflow, user_token=user_token) + + return wrapper diff --git a/sdk/python/axonflow/interceptors/base.py b/sdk/python/axonflow/interceptors/base.py new file mode 100644 index 00000000..ff77d49b --- /dev/null +++ b/sdk/python/axonflow/interceptors/base.py @@ -0,0 +1,58 @@ +"""Base Interceptor class for LLM providers.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from axonflow.client import AxonFlow + + +class BaseInterceptor(ABC): + """Base class for LLM provider interceptors. + + Interceptors wrap LLM client methods to automatically apply + AxonFlow governance without changing application code. + """ + + def __init__(self, axonflow: AxonFlow, user_token: str = "") -> None: + """Initialize interceptor. + + Args: + axonflow: AxonFlow client for governance + user_token: User token for policy evaluation + """ + self.axonflow = axonflow + self.user_token = user_token + + @abstractmethod + def wrap(self, client: Any) -> Any: + """Wrap an LLM client with governance. 
+ + Args: + client: The LLM client to wrap + + Returns: + Wrapped client with automatic governance + """ + + @abstractmethod + def extract_prompt(self, *args: Any, **kwargs: Any) -> str: + """Extract prompt from method arguments. + + Args: + args: Positional arguments + kwargs: Keyword arguments + + Returns: + Extracted prompt string + """ + + @abstractmethod + def get_provider_name(self) -> str: + """Get the provider name. + + Returns: + Provider name (e.g., "openai", "anthropic") + """ diff --git a/sdk/python/axonflow/interceptors/openai.py b/sdk/python/axonflow/interceptors/openai.py new file mode 100644 index 00000000..24869fbe --- /dev/null +++ b/sdk/python/axonflow/interceptors/openai.py @@ -0,0 +1,164 @@ +"""OpenAI Interceptor for transparent governance. + +Wraps OpenAI client to automatically apply AxonFlow governance +without changing application code. + +Example: + >>> from openai import OpenAI + >>> from axonflow import AxonFlow + >>> from axonflow.interceptors.openai import wrap_openai_client + >>> + >>> openai = OpenAI() + >>> axonflow = AxonFlow(...) + >>> + >>> # Wrap the client - governance is now automatic + >>> wrapped = wrap_openai_client(openai, axonflow) + >>> + >>> # Use as normal - governance happens invisibly + >>> response = wrapped.chat.completions.create( + ... model="gpt-4", + ... messages=[{"role": "user", "content": "Hello!"}] + ... 
) +""" + +from __future__ import annotations + +import asyncio +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +from axonflow.exceptions import PolicyViolationError +from axonflow.interceptors.base import BaseInterceptor + +if TYPE_CHECKING: + from axonflow.client import AxonFlow + +T = TypeVar("T") + + +class OpenAIInterceptor(BaseInterceptor): + """Interceptor for OpenAI client.""" + + def get_provider_name(self) -> str: + """Get the provider name.""" + return "openai" + + def extract_prompt(self, *_args: Any, **kwargs: Any) -> str: + """Extract prompt from chat completions arguments.""" + messages = kwargs.get("messages", []) + return " ".join(m.get("content", "") for m in messages if isinstance(m, dict)) + + def wrap(self, client: Any) -> Any: + """Wrap OpenAI client with governance.""" + return wrap_openai_client(client, self.axonflow, user_token=self.user_token) + + +def wrap_openai_client( + openai_client: Any, + axonflow: AxonFlow, + *, + user_token: str = "", +) -> Any: + """Wrap OpenAI client with AxonFlow governance. 
+ + Args: + openai_client: OpenAI client to wrap + axonflow: AxonFlow client for governance + user_token: User token for policy evaluation + + Returns: + Wrapped OpenAI client with automatic governance + """ + original_create = openai_client.chat.completions.create + + def _extract_prompt(kwargs: dict[str, Any]) -> str: + """Extract prompt from messages.""" + messages = kwargs.get("messages", []) + return " ".join(m.get("content", "") for m in messages if isinstance(m, dict)) + + def _get_loop() -> asyncio.AbstractEventLoop: + """Get or create event loop.""" + try: + return asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + if asyncio.iscoroutinefunction(original_create): + + @wraps(original_create) + async def async_wrapped_create(*args: Any, **kwargs: Any) -> Any: + prompt = _extract_prompt(kwargs) + + # Check with AxonFlow + response = await axonflow.execute_query( + user_token=user_token, + query=prompt, + request_type="llm_chat", + context={ + "provider": "openai", + "model": kwargs.get("model", "gpt-4"), + "parameters": { + k: v for k, v in kwargs.items() if k not in ("messages", "model") + }, + }, + ) + + if response.blocked: + raise PolicyViolationError( + response.block_reason or "Request blocked by policy" + ) + + # Call original + return await original_create(*args, **kwargs) + + openai_client.chat.completions.create = async_wrapped_create + else: + + @wraps(original_create) + def sync_wrapped_create(*args: Any, **kwargs: Any) -> Any: + prompt = _extract_prompt(kwargs) + + # Check with AxonFlow (sync) + loop = _get_loop() + response = loop.run_until_complete( + axonflow.execute_query( + user_token=user_token, + query=prompt, + request_type="llm_chat", + context={ + "provider": "openai", + "model": kwargs.get("model", "gpt-4"), + }, + ) + ) + + if response.blocked: + raise PolicyViolationError( + response.block_reason or "Request blocked by policy" + ) + + # Call original + 
return original_create(*args, **kwargs) + + openai_client.chat.completions.create = sync_wrapped_create + + return openai_client + + +def create_openai_wrapper(axonflow: AxonFlow, user_token: str = "") -> Callable[[Any], Any]: + """Create a wrapper function for OpenAI client. + + Args: + axonflow: AxonFlow client for governance + user_token: User token for policy evaluation + + Returns: + Wrapper function that takes an OpenAI client + """ + + def wrapper(openai_client: Any) -> Any: + return wrap_openai_client(openai_client, axonflow, user_token=user_token) + + return wrapper diff --git a/sdk/python/axonflow/py.typed b/sdk/python/axonflow/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/sdk/python/axonflow/types.py b/sdk/python/axonflow/types.py new file mode 100644 index 00000000..7e5b3e4c --- /dev/null +++ b/sdk/python/axonflow/types.py @@ -0,0 +1,213 @@ +"""AxonFlow SDK Type Definitions. + +All types are defined using Pydantic v2 for runtime validation +and automatic JSON serialization/deserialization. 
+""" + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class Mode(str, Enum): + """SDK operation mode.""" + + PRODUCTION = "production" + SANDBOX = "sandbox" + + +class RetryConfig(BaseModel): + """Retry configuration with exponential backoff.""" + + model_config = ConfigDict(frozen=True) + + enabled: bool = Field(default=True, description="Enable retry logic") + max_attempts: int = Field(default=3, ge=1, le=10, description="Max retry attempts") + initial_delay: float = Field(default=1.0, gt=0, description="Initial delay (seconds)") + max_delay: float = Field(default=30.0, gt=0, description="Max delay (seconds)") + exponential_base: float = Field(default=2.0, gt=1, description="Backoff multiplier") + + +class CacheConfig(BaseModel): + """Cache configuration.""" + + model_config = ConfigDict(frozen=True) + + enabled: bool = Field(default=True, description="Enable caching") + ttl: float = Field(default=60.0, gt=0, description="Cache TTL (seconds)") + max_size: int = Field(default=1000, gt=0, description="Max cache entries") + + +class AxonFlowConfig(BaseModel): + """Configuration for AxonFlow client. 
+ + Attributes: + agent_url: AxonFlow Agent URL (required) + client_id: Client ID for authentication (required) + client_secret: Client secret for authentication (required) + license_key: Optional license key for organization-level auth + mode: Operation mode (production or sandbox) + debug: Enable debug logging + timeout: Request timeout in seconds + insecure_skip_verify: Skip TLS verification (dev only) + retry: Retry configuration + cache: Cache configuration + """ + + model_config = ConfigDict(frozen=True) + + agent_url: str = Field(..., min_length=1, description="AxonFlow Agent URL") + client_id: str = Field(..., min_length=1, description="Client ID") + client_secret: str = Field(..., min_length=1, description="Client secret") + license_key: str | None = Field(default=None, description="License key") + mode: Mode = Field(default=Mode.PRODUCTION, description="Operation mode") + debug: bool = Field(default=False, description="Enable debug logging") + timeout: float = Field(default=60.0, gt=0, description="Request timeout (seconds)") + insecure_skip_verify: bool = Field(default=False, description="Skip TLS verify") + retry: RetryConfig = Field(default_factory=RetryConfig) + cache: CacheConfig = Field(default_factory=CacheConfig) + + +class ClientRequest(BaseModel): + """Request to AxonFlow Agent.""" + + query: str = Field(..., description="Query or prompt") + user_token: str = Field(..., description="User token for auth") + client_id: str = Field(..., description="Client ID") + request_type: str = Field(..., description="Request type") + context: dict[str, Any] = Field(default_factory=dict, description="Additional context") + + +class PolicyEvaluationInfo(BaseModel): + """Policy evaluation metadata.""" + + policies_evaluated: list[str] = Field(default_factory=list) + static_checks: list[str] = Field(default_factory=list) + processing_time: str = Field(default="0ms") + tenant_id: str = Field(default="") + + +class ClientResponse(BaseModel): + """Response from 
AxonFlow Agent.""" + + success: bool = Field(..., description="Whether request succeeded") + data: Any | None = Field(default=None, description="Response data") + result: str | None = Field(default=None, description="Result for planning") + plan_id: str | None = Field(default=None, description="Plan ID if applicable") + metadata: dict[str, Any] = Field(default_factory=dict) + error: str | None = Field(default=None, description="Error message if failed") + blocked: bool = Field(default=False, description="Whether request was blocked") + block_reason: str | None = Field(default=None, description="Block reason") + policy_info: PolicyEvaluationInfo | None = Field(default=None) + + +class ConnectorMetadata(BaseModel): + """MCP connector metadata.""" + + id: str + name: str + type: str + version: str = "" + description: str = "" + category: str = "" + icon: str = "" + tags: list[str] = Field(default_factory=list) + capabilities: list[str] = Field(default_factory=list) + config_schema: dict[str, Any] = Field(default_factory=dict) + installed: bool = False + healthy: bool = False + + +class ConnectorInstallRequest(BaseModel): + """Request to install an MCP connector.""" + + connector_id: str + name: str + tenant_id: str + options: dict[str, Any] = Field(default_factory=dict) + credentials: dict[str, str] = Field(default_factory=dict) + + +class ConnectorResponse(BaseModel): + """Response from MCP connector query.""" + + success: bool + data: Any | None = None + error: str | None = None + meta: dict[str, Any] = Field(default_factory=dict) + + +class PlanStep(BaseModel): + """A step in a multi-agent plan.""" + + id: str + name: str + type: str + description: str = "" + depends_on: list[str] = Field(default_factory=list) + agent: str = "" + parameters: dict[str, Any] = Field(default_factory=dict) + + +class PlanResponse(BaseModel): + """Multi-agent plan response.""" + + plan_id: str + steps: list[PlanStep] = Field(default_factory=list) + domain: str = "generic" + complexity: 
int = 0 + parallel: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class PlanExecutionResponse(BaseModel): + """Plan execution result.""" + + plan_id: str + status: str # "running", "completed", "failed" + result: str | None = None + step_results: dict[str, Any] = Field(default_factory=dict) + error: str | None = None + duration: str | None = None + + +# Gateway Mode Types + + +class RateLimitInfo(BaseModel): + """Rate limiting status.""" + + limit: int + remaining: int + reset_at: datetime + + +class PolicyApprovalResult(BaseModel): + """Pre-check result from Gateway Mode.""" + + context_id: str = Field(..., description="Context ID for audit linking") + approved: bool = Field(..., description="Whether request is approved") + approved_data: dict[str, Any] = Field(default_factory=dict) + policies: list[str] = Field(default_factory=list) + rate_limit_info: RateLimitInfo | None = None + expires_at: datetime + block_reason: str | None = None + + +class TokenUsage(BaseModel): + """LLM token usage tracking.""" + + prompt_tokens: int = Field(ge=0) + completion_tokens: int = Field(ge=0) + total_tokens: int = Field(ge=0) + + +class AuditResult(BaseModel): + """Audit confirmation.""" + + success: bool + audit_id: str diff --git a/sdk/python/axonflow/utils/__init__.py b/sdk/python/axonflow/utils/__init__.py new file mode 100644 index 00000000..00e5aebb --- /dev/null +++ b/sdk/python/axonflow/utils/__init__.py @@ -0,0 +1,12 @@ +"""AxonFlow SDK Utilities.""" + +from axonflow.utils.cache import CacheManager +from axonflow.utils.logging import configure_logging, get_logger +from axonflow.utils.retry import RetryHandler + +__all__ = [ + "CacheManager", + "RetryHandler", + "configure_logging", + "get_logger", +] diff --git a/sdk/python/axonflow/utils/cache.py b/sdk/python/axonflow/utils/cache.py new file mode 100644 index 00000000..59bb9e14 --- /dev/null +++ b/sdk/python/axonflow/utils/cache.py @@ -0,0 +1,101 @@ +"""Cache utilities for AxonFlow SDK.""" + 
+from __future__ import annotations + +from typing import Any, Generic, TypeVar + +from cachetools import TTLCache + +T = TypeVar("T") + + +class CacheManager(Generic[T]): + """Generic cache manager with TTL support. + + Provides a simple interface for caching with automatic expiration. + """ + + def __init__(self, maxsize: int = 1000, ttl: float = 60.0) -> None: + """Initialize cache manager. + + Args: + maxsize: Maximum number of entries + ttl: Time-to-live in seconds + """ + self._cache: TTLCache[str, T] = TTLCache(maxsize=maxsize, ttl=ttl) + self._ttl = ttl + self._maxsize = maxsize + + def get(self, key: str) -> T | None: + """Get a value from cache. + + Args: + key: Cache key + + Returns: + Cached value or None if not found + """ + return self._cache.get(key) + + def set(self, key: str, value: T) -> None: + """Set a value in cache. + + Args: + key: Cache key + value: Value to cache + """ + self._cache[key] = value + + def delete(self, key: str) -> None: + """Delete a value from cache. + + Args: + key: Cache key + """ + self._cache.pop(key, None) + + def clear(self) -> None: + """Clear all cached values.""" + self._cache.clear() + + def contains(self, key: str) -> bool: + """Check if key exists in cache. + + Args: + key: Cache key + + Returns: + True if key exists + """ + return key in self._cache + + @property + def size(self) -> int: + """Get current cache size.""" + return len(self._cache) + + @property + def ttl(self) -> float: + """Get cache TTL.""" + return self._ttl + + @property + def maxsize(self) -> int: + """Get maximum cache size.""" + return self._maxsize + + def get_or_set(self, key: str, factory: Any) -> T: + """Get from cache or create using factory. 
+ + Args: + key: Cache key + factory: Callable to create value if not cached + + Returns: + Cached or newly created value + """ + value = self.get(key) + if value is None: + value = factory() + self.set(key, value) + return value diff --git a/sdk/python/axonflow/utils/logging.py b/sdk/python/axonflow/utils/logging.py new file mode 100644 index 00000000..4f8324ab --- /dev/null +++ b/sdk/python/axonflow/utils/logging.py @@ -0,0 +1,88 @@ +"""Logging utilities for AxonFlow SDK.""" + +from __future__ import annotations + +import logging +import sys +from typing import Any + +import structlog + + +def configure_logging( + level: int = logging.INFO, + json_format: bool = False, +) -> None: + """Configure structured logging for AxonFlow SDK. + + Args: + level: Logging level (default: INFO) + json_format: Use JSON format for logs (default: False) + """ + # Configure structlog processors + processors: list[Any] = [ + structlog.stdlib.filter_by_level, + structlog.stdlib.add_logger_name, + structlog.stdlib.add_log_level, + structlog.stdlib.PositionalArgumentsFormatter(), + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.UnicodeDecoder(), + ] + + if json_format: + processors.append(structlog.processors.JSONRenderer()) + else: + processors.append(structlog.dev.ConsoleRenderer()) + + structlog.configure( + processors=processors, + wrapper_class=structlog.stdlib.BoundLogger, + context_class=dict, + logger_factory=structlog.stdlib.LoggerFactory(), + cache_logger_on_first_use=True, + ) + + # Configure stdlib logging + logging.basicConfig( + format="%(message)s", + stream=sys.stdout, + level=level, + ) + + +def get_logger(name: str) -> structlog.stdlib.BoundLogger: + """Get a structured logger. 
+ + Args: + name: Logger name + + Returns: + Configured structured logger + """ + return structlog.get_logger(name) + + +class LogContext: + """Context manager for adding log context.""" + + def __init__(self, logger: structlog.stdlib.BoundLogger, **context: Any) -> None: + """Initialize log context. + + Args: + logger: Logger to bind context to + **context: Context key-value pairs + """ + self._logger = logger + self._context = context + self._original_context: dict[str, Any] = {} + + def __enter__(self) -> structlog.stdlib.BoundLogger: + """Enter context and bind values.""" + self._original_context = dict(self._logger._context) # type: ignore[attr-defined] + return self._logger.bind(**self._context) + + def __exit__(self, *args: Any) -> None: + """Exit context and restore original values.""" + # Restore original context + self._logger._context = self._original_context # type: ignore[attr-defined] diff --git a/sdk/python/axonflow/utils/retry.py b/sdk/python/axonflow/utils/retry.py new file mode 100644 index 00000000..c56dddf7 --- /dev/null +++ b/sdk/python/axonflow/utils/retry.py @@ -0,0 +1,113 @@ +"""Retry utilities for AxonFlow SDK.""" + +from __future__ import annotations + +from typing import Any, Callable, TypeVar + +from tenacity import ( + RetryCallState, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from axonflow.types import RetryConfig + +T = TypeVar("T") + + +class RetryHandler: + """Handles retry logic with exponential backoff.""" + + def __init__(self, config: RetryConfig) -> None: + """Initialize retry handler. + + Args: + config: Retry configuration + """ + self.config = config + + def create_decorator( + self, + retry_on: tuple[type[Exception], ...], + ) -> Callable[[Callable[..., T]], Callable[..., T]]: + """Create retry decorator based on config. 
+ + Args: + retry_on: Exception types to retry on + + Returns: + Decorator function + """ + if not self.config.enabled: + return lambda f: f + + return retry( + stop=stop_after_attempt(self.config.max_attempts), + wait=wait_exponential( + multiplier=self.config.initial_delay, + max=self.config.max_delay, + exp_base=self.config.exponential_base, + ), + retry=retry_if_exception_type(retry_on), + reraise=True, + ) + + @staticmethod + def log_retry(retry_state: RetryCallState) -> None: + """Log retry attempt. + + Args: + retry_state: Current retry state + """ + if retry_state.outcome and retry_state.outcome.failed: + exception = retry_state.outcome.exception() + attempt = retry_state.attempt_number + print(f"Retry attempt {attempt} failed: {exception}") + + +def with_retry( + max_attempts: int = 3, + initial_delay: float = 1.0, + max_delay: float = 30.0, + exponential_base: float = 2.0, + retry_on: tuple[type[Exception], ...] = (Exception,), +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """Decorator for adding retry logic to a function. + + Args: + max_attempts: Maximum retry attempts + initial_delay: Initial delay between retries + max_delay: Maximum delay between retries + exponential_base: Exponential backoff base + retry_on: Exception types to retry on + + Returns: + Decorator function + """ + config = RetryConfig( + enabled=True, + max_attempts=max_attempts, + initial_delay=initial_delay, + max_delay=max_delay, + exponential_base=exponential_base, + ) + handler = RetryHandler(config) + return handler.create_decorator(retry_on) + + +def create_retry_decorator( + config: RetryConfig, retry_on: tuple[type[Exception], ...] +) -> Any: + """Create a retry decorator from config. 
+ + Args: + config: Retry configuration + retry_on: Exception types to retry on + + Returns: + Retry decorator + """ + handler = RetryHandler(config) + return handler.create_decorator(retry_on) diff --git a/sdk/python/examples/gateway_mode.py b/sdk/python/examples/gateway_mode.py new file mode 100644 index 00000000..07b04e2b --- /dev/null +++ b/sdk/python/examples/gateway_mode.py @@ -0,0 +1,170 @@ +"""Gateway Mode Example - Direct LLM calls with AxonFlow governance. + +This example shows how to use Gateway Mode for lowest-latency LLM calls +while maintaining full governance and audit compliance. + +Gateway Mode Flow: +1. Pre-check: Get policy approval and filtered data +2. LLM Call: Make direct call to your LLM provider +3. Audit: Report the call for compliance + +Run with: python gateway_mode.py +""" + +import asyncio +import os +import time + +from axonflow import AxonFlow, TokenUsage + +# Simulated LLM response (replace with actual OpenAI/Anthropic call) +MOCK_LLM_RESPONSE = { + "choices": [ + { + "message": { + "content": "Found 5 patients with recent lab results: P001, P002, P003, P004, P005" + } + } + ], + "usage": { + "prompt_tokens": 150, + "completion_tokens": 45, + "total_tokens": 195, + }, +} + + +async def main() -> None: + """Run Gateway Mode example.""" + async with AxonFlow( + agent_url=os.environ.get("AXONFLOW_AGENT_URL", "https://staging-eu.getaxonflow.com"), + client_id=os.environ.get("AXONFLOW_CLIENT_ID", "demo-client"), + client_secret=os.environ.get("AXONFLOW_CLIENT_SECRET", "demo-secret"), + license_key=os.environ.get("AXONFLOW_LICENSE_KEY"), + debug=True, + ) as axonflow: + print("=== Gateway Mode Example ===\n") + + # ===================================================================== + # Step 1: Pre-check - Get policy approval before LLM call + # ===================================================================== + print("Step 1: Policy Pre-Check") + print("-" * 40) + + ctx = await axonflow.get_policy_approved_context( + 
user_token="user-jwt-token", # Your user's JWT + query="Find patients with recent lab results", + data_sources=["postgres"], # MCP connectors to fetch data from + context={"department": "cardiology"}, # Additional context + ) + + print(f" Context ID: {ctx.context_id}") + print(f" Approved: {ctx.approved}") + print(f" Policies: {ctx.policies}") + print(f" Expires: {ctx.expires_at}") + + if ctx.rate_limit_info: + print(f" Rate Limit: {ctx.rate_limit_info.remaining}/{ctx.rate_limit_info.limit}") + + if not ctx.approved: + print(f"\n❌ Request blocked: {ctx.block_reason}") + return + + print(f" Approved Data Keys: {list(ctx.approved_data.keys())}") + print("\n✅ Pre-check passed!\n") + + # ===================================================================== + # Step 2: Make LLM call directly (lowest latency) + # ===================================================================== + print("Step 2: Direct LLM Call") + print("-" * 40) + + # Build prompt using approved data (filtered by policies) + prompt = f"""Based on this data: {ctx.approved_data} + +Please summarize the patient results.""" + + print(f" Prompt length: {len(prompt)} chars") + + # Time the LLM call + start_time = time.time() + + # In production, this would be: + # response = await openai.chat.completions.create( + # model="gpt-4", + # messages=[ + # {"role": "system", "content": "You are a helpful healthcare assistant."}, + # {"role": "user", "content": prompt}, + # ], + # ) + + # Simulated response for demo + llm_response = MOCK_LLM_RESPONSE + await asyncio.sleep(0.1) # Simulate latency + + latency_ms = int((time.time() - start_time) * 1000) + content = llm_response["choices"][0]["message"]["content"] + + print(f" Latency: {latency_ms}ms") + print(f" Response: {content[:100]}...") + print(f" Tokens: {llm_response['usage']['total_tokens']}") + print("\n✅ LLM call complete!\n") + + # ===================================================================== + # Step 3: Audit the call for compliance + # 
            print(f"❌ Request blocked!")
+ +This example shows how to wrap your OpenAI client with AxonFlow +governance without changing your existing code. + +Run with: python openai_integration.py +""" + +import asyncio +import os + +from axonflow import AxonFlow +from axonflow.interceptors.openai import wrap_openai_client + + +async def main() -> None: + """Run OpenAI integration example.""" + # Check if openai is installed + try: + from openai import AsyncOpenAI + except ImportError: + print("OpenAI not installed. Run: pip install axonflow[openai]") + return + + print("=== OpenAI Integration Example ===\n") + + # Initialize both clients + openai_client = AsyncOpenAI() + + async with AxonFlow( + agent_url=os.environ.get("AXONFLOW_AGENT_URL", "https://staging-eu.getaxonflow.com"), + client_id=os.environ.get("AXONFLOW_CLIENT_ID", "demo-client"), + client_secret=os.environ.get("AXONFLOW_CLIENT_SECRET", "demo-secret"), + debug=True, + ) as axonflow: + # Wrap OpenAI client with governance + wrapped_openai = wrap_openai_client( + openai_client, + axonflow, + user_token="user-123", # Your user's token + ) + + print("OpenAI client wrapped with AxonFlow governance\n") + + # Use OpenAI as normal - governance happens automatically + print("Making governed OpenAI call...") + try: + response = await wrapped_openai.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ], + max_tokens=100, + ) + + print(f"\nResponse: {response.choices[0].message.content}") + print(f"Tokens used: {response.usage.total_tokens}") + except Exception as e: + print(f"\nError (expected if no OpenAI key): {e}") + + # Example of a blocked request + print("\n--- Testing policy block ---") + try: + # This might be blocked by policies + await wrapped_openai.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "user", "content": "Tell me how to hack a system"}, + ], + ) + except Exception as e: + 
print(f"Request handled: {type(e).__name__}: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/python/examples/quickstart.py b/sdk/python/examples/quickstart.py new file mode 100644 index 00000000..f1acce5f --- /dev/null +++ b/sdk/python/examples/quickstart.py @@ -0,0 +1,70 @@ +"""AxonFlow Python SDK Quickstart. + +This example shows the simplest way to get started with AxonFlow. +Run with: python quickstart.py +""" + +import asyncio +import os + +from axonflow import AxonFlow + + +async def main() -> None: + """Run quickstart example.""" + # Initialize client from environment variables + async with AxonFlow( + agent_url=os.environ.get("AXONFLOW_AGENT_URL", "https://staging-eu.getaxonflow.com"), + client_id=os.environ.get("AXONFLOW_CLIENT_ID", "demo-client"), + client_secret=os.environ.get("AXONFLOW_CLIENT_SECRET", "demo-secret"), + debug=True, + ) as client: + # Check agent health + healthy = await client.health_check() + print(f"Agent healthy: {healthy}") + + if not healthy: + print("Agent not available, exiting") + return + + # Execute a simple query with governance + print("\n--- Executing governed query ---") + response = await client.execute_query( + user_token="demo-user", + query="What is the capital of France?", + request_type="chat", + ) + + print(f"Success: {response.success}") + print(f"Blocked: {response.blocked}") + if response.data: + print(f"Result: {response.data}") + + # Policy info shows what was evaluated + if response.policy_info: + print(f"Policies evaluated: {response.policy_info.policies_evaluated}") + print(f"Processing time: {response.policy_info.processing_time}") + + +def sync_example() -> None: + """Synchronous usage example.""" + # Create sync client + with AxonFlow.sync( + agent_url=os.environ.get("AXONFLOW_AGENT_URL", "https://staging-eu.getaxonflow.com"), + client_id=os.environ.get("AXONFLOW_CLIENT_ID", "demo-client"), + client_secret=os.environ.get("AXONFLOW_CLIENT_SECRET", "demo-secret"), + ) as client: 
+ result = client.execute_query( + user_token="demo-user", + query="Hello, world!", + request_type="chat", + ) + print(f"Sync result: {result.success}") + + +if __name__ == "__main__": + print("=== Async Example ===") + asyncio.run(main()) + + print("\n=== Sync Example ===") + sync_example() diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml new file mode 100644 index 00000000..647a5d31 --- /dev/null +++ b/sdk/python/pyproject.toml @@ -0,0 +1,197 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "axonflow" +version = "0.1.0" +description = "AxonFlow Python SDK - Enterprise AI Governance in 3 Lines of Code" +readme = "README.md" +license = {text = "Apache-2.0"} +authors = [ + {name = "AxonFlow", email = "engineering@getaxonflow.com"} +] +maintainers = [ + {name = "AxonFlow", email = "engineering@getaxonflow.com"} +] +keywords = [ + "ai", "governance", "llm", "openai", "anthropic", "bedrock", + "policy", "compliance", "enterprise", "mcp", "multi-agent" +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Typing :: Typed", +] +requires-python = ">=3.9" +dependencies = [ + "httpx>=0.25.0", + "pydantic>=2.0.0", + "tenacity>=8.0.0", + "structlog>=23.0.0", + "cachetools>=5.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.0.0", + "pytest-httpx>=0.22.0", + "mypy>=1.5.0", + "ruff>=0.1.0", + "black>=23.0.0", + "isort>=5.12.0", + 
"pre-commit>=3.0.0", +] +docs = [ + "sphinx>=7.0.0", + "sphinx-rtd-theme>=1.3.0", + "sphinx-autodoc-typehints>=1.24.0", + "myst-parser>=2.0.0", +] +openai = ["openai>=1.0.0"] +anthropic = ["anthropic>=0.18.0"] +all = [ + "openai>=1.0.0", + "anthropic>=0.18.0", +] + +[project.urls] +Homepage = "https://getaxonflow.com" +Documentation = "https://docs.getaxonflow.com/sdk/python" +Repository = "https://github.com/getaxonflow/axonflow" +Changelog = "https://github.com/getaxonflow/axonflow/blob/main/sdk/python/CHANGELOG.md" + +[tool.setuptools.packages.find] +where = ["."] +include = ["axonflow*"] + +[tool.setuptools.package-data] +axonflow = ["py.typed"] + +[tool.ruff] +target-version = "py39" +line-length = 100 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "C", # flake8-comprehensions + "B", # flake8-bugbear + "UP", # pyupgrade + "ANN", # flake8-annotations + "S", # flake8-bandit (security) + "BLE", # flake8-blind-except + "A", # flake8-builtins + "C4", # flake8-comprehensions + "DTZ", # flake8-datetimez + "T10", # flake8-debugger + "EM", # flake8-errmsg + "ISC", # flake8-implicit-str-concat + "ICN", # flake8-import-conventions + "G", # flake8-logging-format + "PIE", # flake8-pie + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RSE", # flake8-raise + "RET", # flake8-return + "SLF", # flake8-self + "SIM", # flake8-simplify + "TID", # flake8-tidy-imports + "TCH", # flake8-type-checking + "ARG", # flake8-unused-arguments + "PTH", # flake8-use-pathlib + "ERA", # eradicate + "PGH", # pygrep-hooks + "PL", # pylint + "TRY", # tryceratops + "RUF", # Ruff-specific rules +] +ignore = [ + "ANN401", # Dynamically typed expressions (Any) + "COM812", # Trailing comma (conflicts with formatter) + "ISC001", # Single line implicit string concatenation + "PLR0913", # Too many arguments + "TRY003", # Avoid long exception messages + "A001", # Variable shadowing builtin + "A004", # Import shadowing 
builtin + "RUF022", # Unsorted __all__ + "RUF023", # Unsorted __slots__ + "RUF100", # Unused noqa + "TC003", # Move stdlib import to type-checking (false positive for pydantic) + "SLF001", # Private member access (needed for structlog context) + "C901", # Too complex (interceptors need nested structure) +] + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = ["S101", "S105", "S106", "ANN", "PLR2004", "ARG", "PLC0415", "PT011", "PT012", "TC002", "UP035", "F401", "F841", "EM101", "TRY301"] +"examples/**/*.py" = ["T201", "ANN", "S106", "ERA001", "BLE001", "PLC0415", "F541"] + +[tool.ruff.lint.isort] +known-first-party = ["axonflow"] + +[tool.mypy] +python_version = "3.9" +strict = true +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true +show_error_codes = true +show_column_numbers = true + +[[tool.mypy.overrides]] +module = "tests.*" +disallow_untyped_defs = false + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +addopts = [ + "-v", + "--strict-markers", + "--cov=axonflow", + "--cov-report=term-missing", + "--cov-report=xml", + "--cov-fail-under=95", +] +markers = [ + "unit: Unit tests", + "integration: Integration tests (requires live agent)", + "slow: Slow tests", +] + +[tool.coverage.run] +branch = true +source = ["axonflow"] +omit = ["tests/*", "examples/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise NotImplementedError", + "if TYPE_CHECKING:", + "if __name__ == .__main__.:", +] diff --git a/sdk/python/setup.py b/sdk/python/setup.py new file mode 100644 index 00000000..ea8f0fba --- /dev/null +++ b/sdk/python/setup.py @@ -0,0 +1,5 @@ +"""Legacy setup.py for editable installs.""" + +from setuptools import 
setup + +setup() diff --git a/sdk/python/tests/__init__.py b/sdk/python/tests/__init__.py new file mode 100644 index 00000000..ea657b26 --- /dev/null +++ b/sdk/python/tests/__init__.py @@ -0,0 +1 @@ +"""AxonFlow SDK Test Suite.""" diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py new file mode 100644 index 00000000..6620c83e --- /dev/null +++ b/sdk/python/tests/conftest.py @@ -0,0 +1,172 @@ +"""Pytest fixtures for AxonFlow SDK tests.""" + +from __future__ import annotations + +from typing import Any, AsyncGenerator + +import pytest +import pytest_asyncio +from pytest_httpx import HTTPXMock + +from axonflow import AxonFlow + + +@pytest.fixture +def config_dict() -> dict[str, Any]: + """Base configuration dictionary.""" + return { + "agent_url": "https://test.axonflow.com", + "client_id": "test-client", + "client_secret": "test-secret", + "debug": True, + } + + +@pytest_asyncio.fixture +async def client(config_dict: dict[str, Any]) -> AsyncGenerator[AxonFlow, None]: + """Create test AxonFlow client.""" + async with AxonFlow(**config_dict) as c: + yield c + + +@pytest.fixture +def sync_client(config_dict: dict[str, Any]): + """Create sync test AxonFlow client.""" + with AxonFlow.sync(**config_dict) as c: + yield c + + +@pytest.fixture +def mock_health_response() -> dict[str, Any]: + """Mock health check response.""" + return { + "status": "healthy", + "version": "1.0.0", + "components": { + "database": "connected", + "orchestrator": "reachable", + }, + } + + +@pytest.fixture +def mock_query_response() -> dict[str, Any]: + """Mock successful query response.""" + return { + "success": True, + "data": {"result": "test result"}, + "blocked": False, + "metadata": {}, + "policy_info": { + "policies_evaluated": ["default"], + "static_checks": [], + "processing_time": "5ms", + "tenant_id": "test", + }, + } + + +@pytest.fixture +def mock_blocked_response() -> dict[str, Any]: + """Mock blocked query response.""" + return { + "success": False, + "blocked": 
True, + "block_reason": "Rate limit exceeded", + "error": "Request blocked by policy", + } + + +@pytest.fixture +def mock_connector_list() -> list[dict[str, Any]]: + """Mock connector list response.""" + return [ + { + "id": "postgres", + "name": "PostgreSQL", + "type": "database", + "version": "1.0.0", + "description": "PostgreSQL database connector", + "category": "database", + "tags": ["sql", "relational"], + "capabilities": ["read", "write"], + "config_schema": {}, + "installed": True, + "healthy": True, + }, + { + "id": "salesforce", + "name": "Salesforce", + "type": "crm", + "version": "1.0.0", + "description": "Salesforce CRM connector", + "category": "crm", + "tags": ["crm", "sales"], + "capabilities": ["read"], + "config_schema": {}, + "installed": False, + "healthy": False, + }, + ] + + +@pytest.fixture +def mock_plan_response() -> dict[str, Any]: + """Mock plan generation response.""" + return { + "success": True, + "plan_id": "plan-123", + "data": { + "steps": [ + { + "id": "step-1", + "name": "Fetch data", + "type": "data", + "description": "Fetch customer data", + "depends_on": [], + "agent": "data-agent", + "parameters": {}, + }, + { + "id": "step-2", + "name": "Process data", + "type": "process", + "description": "Process the data", + "depends_on": ["step-1"], + "agent": "process-agent", + "parameters": {}, + }, + ], + "domain": "generic", + "complexity": 2, + "parallel": False, + }, + "metadata": {}, + } + + +@pytest.fixture +def mock_pre_check_response() -> dict[str, Any]: + """Mock Gateway Mode pre-check response.""" + return { + "context_id": "ctx-123", + "approved": True, + "approved_data": {"patients": ["patient-1", "patient-2"]}, + "policies": ["hipaa", "gdpr"], + "rate_limit": { + "limit": 100, + "remaining": 99, + "reset_at": "2025-12-05T00:00:00Z", + }, + "expires_at": "2025-12-04T13:00:00Z", + "block_reason": None, + } + + +@pytest.fixture +def mock_audit_response() -> dict[str, Any]: + """Mock Gateway Mode audit response.""" + return { + 
"success": True, + "audit_id": "audit-456", + } diff --git a/sdk/python/tests/integration/__init__.py b/sdk/python/tests/integration/__init__.py new file mode 100644 index 00000000..b222b132 --- /dev/null +++ b/sdk/python/tests/integration/__init__.py @@ -0,0 +1,4 @@ +"""Integration tests for AxonFlow SDK. + +These tests require a live AxonFlow Agent and are marked with the 'integration' marker. +""" diff --git a/sdk/python/tests/test_client.py b/sdk/python/tests/test_client.py new file mode 100644 index 00000000..a6954f91 --- /dev/null +++ b/sdk/python/tests/test_client.py @@ -0,0 +1,492 @@ +"""Unit tests for AxonFlow client.""" + +from __future__ import annotations + +from typing import Any + +import pytest +from pytest_httpx import HTTPXMock + +from axonflow import AxonFlow, Mode +from axonflow.exceptions import ( + AuthenticationError, + AxonFlowError, + PolicyViolationError, +) + + +class TestClientInitialization: + """Test client initialization.""" + + def test_creates_with_required_params(self, config_dict: dict[str, Any]) -> None: + """Test client creates with required parameters.""" + client = AxonFlow(**config_dict) + assert client.config.agent_url == config_dict["agent_url"] + assert client.config.client_id == config_dict["client_id"] + + def test_default_values_applied(self, config_dict: dict[str, Any]) -> None: + """Test default configuration values.""" + client = AxonFlow(**config_dict) + assert client.config.timeout == 60.0 + assert client.config.retry.enabled is True + assert client.config.retry.max_attempts == 3 + + def test_sandbox_mode(self) -> None: + """Test sandbox client creation.""" + client = AxonFlow.sandbox() + assert "staging" in client.config.agent_url + assert client.config.debug is True + assert client.config.mode == Mode.SANDBOX + + def test_mode_string_conversion(self, config_dict: dict[str, Any]) -> None: + """Test mode string is converted to enum.""" + client = AxonFlow(**config_dict, mode="sandbox") + assert client.config.mode == 
Mode.SANDBOX + + def test_url_trailing_slash_stripped(self) -> None: + """Test trailing slash is stripped from URL.""" + client = AxonFlow( + agent_url="https://test.axonflow.com/", + client_id="test", + client_secret="test", + ) + assert client.config.agent_url == "https://test.axonflow.com" + + def test_license_key_optional(self, config_dict: dict[str, Any]) -> None: + """Test license key is optional.""" + client = AxonFlow(**config_dict) + assert client.config.license_key is None + + client_with_license = AxonFlow(**config_dict, license_key="license-123") + assert client_with_license.config.license_key == "license-123" + + +class TestHealthCheck: + """Test health check functionality.""" + + @pytest.mark.asyncio + async def test_health_check_success( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_health_response: dict[str, Any], + ) -> None: + """Test successful health check.""" + httpx_mock.add_response(json=mock_health_response) + result = await client.health_check() + assert result is True + + @pytest.mark.asyncio + async def test_health_check_failure( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test health check returns False on error.""" + httpx_mock.add_response(status_code=500) + result = await client.health_check() + assert result is False + + @pytest.mark.asyncio + async def test_health_check_unhealthy( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test health check returns False when unhealthy.""" + httpx_mock.add_response(json={"status": "unhealthy"}) + result = await client.health_check() + assert result is False + + +class TestExecuteQuery: + """Test query execution.""" + + @pytest.mark.asyncio + async def test_successful_query( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_query_response: dict[str, Any], + ) -> None: + """Test successful query execution.""" + httpx_mock.add_response(json=mock_query_response) + + result = await client.execute_query( + 
user_token="test-token", + query="What is AI?", + request_type="chat", + ) + + assert result.success is True + assert result.blocked is False + assert result.data == {"result": "test result"} + + @pytest.mark.asyncio + async def test_blocked_by_policy( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_blocked_response: dict[str, Any], + ) -> None: + """Test query blocked by policy raises exception.""" + httpx_mock.add_response(json=mock_blocked_response) + + with pytest.raises(PolicyViolationError) as exc_info: + await client.execute_query( + user_token="test-token", + query="What is AI?", + request_type="chat", + ) + + assert "Rate limit" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_authentication_error( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test authentication error handling.""" + httpx_mock.add_response(status_code=401) + + with pytest.raises(AuthenticationError): + await client.execute_query( + user_token="bad-token", + query="test", + request_type="chat", + ) + + @pytest.mark.asyncio + async def test_policy_violation_403( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test 403 response raises PolicyViolationError.""" + httpx_mock.add_response( + status_code=403, + json={ + "message": "Access denied", + "policy": "access-control", + "block_reason": "Insufficient permissions", + }, + ) + + with pytest.raises(PolicyViolationError) as exc_info: + await client.execute_query( + user_token="test", + query="test", + request_type="chat", + ) + + assert exc_info.value.policy == "access-control" + + @pytest.mark.asyncio + async def test_generic_http_error( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test generic HTTP error handling.""" + httpx_mock.add_response(status_code=500, text="Internal Server Error") + + with pytest.raises(AxonFlowError) as exc_info: + await client.execute_query( + user_token="test", + query="test", + request_type="chat", + ) + + 
assert "HTTP 500" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_query_with_context( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_query_response: dict[str, Any], + ) -> None: + """Test query with additional context.""" + httpx_mock.add_response(json=mock_query_response) + + result = await client.execute_query( + user_token="test-token", + query="What is AI?", + request_type="chat", + context={"session_id": "123", "user_role": "admin"}, + ) + + assert result.success is True + + +class TestCaching: + """Test response caching.""" + + @pytest.mark.asyncio + async def test_cache_hit( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + mock_query_response: dict[str, Any], + ) -> None: + """Test cache returns same response.""" + httpx_mock.add_response(json=mock_query_response) + + async with AxonFlow(**config_dict) as client: + # First call + result1 = await client.execute_query( + user_token="test", + query="cached query", + request_type="chat", + ) + + # Second call - should hit cache + result2 = await client.execute_query( + user_token="test", + query="cached query", + request_type="chat", + ) + + assert result1.data == result2.data + # Only one HTTP request should have been made + assert len(httpx_mock.get_requests()) == 1 + + @pytest.mark.asyncio + async def test_cache_miss_different_query( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + mock_query_response: dict[str, Any], + ) -> None: + """Test cache miss for different queries.""" + httpx_mock.add_response(json=mock_query_response) + httpx_mock.add_response(json=mock_query_response) + + async with AxonFlow(**config_dict) as client: + await client.execute_query("test", "query1", "chat") + await client.execute_query("test", "query2", "chat") + + # Two HTTP requests should have been made + assert len(httpx_mock.get_requests()) == 2 + + @pytest.mark.asyncio + async def test_cache_disabled( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + 
mock_query_response: dict[str, Any], + ) -> None: + """Test caching can be disabled.""" + httpx_mock.add_response(json=mock_query_response) + httpx_mock.add_response(json=mock_query_response) + + async with AxonFlow(**config_dict, cache_enabled=False) as client: + await client.execute_query("test", "query", "chat") + await client.execute_query("test", "query", "chat") + + # Both requests should be made + assert len(httpx_mock.get_requests()) == 2 + + +class TestConnectors: + """Test connector operations.""" + + @pytest.mark.asyncio + async def test_list_connectors( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_connector_list: list[dict[str, Any]], + ) -> None: + """Test listing connectors.""" + httpx_mock.add_response(json=mock_connector_list) + + connectors = await client.list_connectors() + + assert len(connectors) == 2 + assert connectors[0].id == "postgres" + assert connectors[0].installed is True + assert connectors[1].id == "salesforce" + assert connectors[1].installed is False + + @pytest.mark.asyncio + async def test_install_connector( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test installing a connector.""" + from axonflow import ConnectorInstallRequest + + httpx_mock.add_response(status_code=201) + + await client.install_connector( + ConnectorInstallRequest( + connector_id="salesforce", + name="My Salesforce", + tenant_id="tenant-123", + options={"api_version": "v55.0"}, + credentials={"api_key": "secret"}, + ) + ) + + request = httpx_mock.get_requests()[0] + assert "/api/connectors/install" in str(request.url) + + @pytest.mark.asyncio + async def test_query_connector( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test querying a connector.""" + httpx_mock.add_response( + json={ + "success": True, + "data": {"rows": [{"id": 1}, {"id": 2}]}, + } + ) + + result = await client.query_connector( + user_token="test", + connector_name="postgres", + operation="query", + params={"sql": "SELECT * 
FROM users"}, + ) + + assert result.success is True + assert result.data["rows"][0]["id"] == 1 + + +class TestPlanning: + """Test multi-agent planning.""" + + @pytest.mark.asyncio + async def test_generate_plan( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_plan_response: dict[str, Any], + ) -> None: + """Test plan generation.""" + httpx_mock.add_response(json=mock_plan_response) + + plan = await client.generate_plan( + query="Book a flight and hotel for my trip", + domain="travel", + ) + + assert plan.plan_id == "plan-123" + assert len(plan.steps) == 2 + assert plan.steps[0].name == "Fetch data" + assert plan.steps[1].depends_on == ["step-1"] + + @pytest.mark.asyncio + async def test_execute_plan( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test plan execution.""" + httpx_mock.add_response( + json={ + "success": True, + "result": "Trip booked successfully", + "metadata": { + "duration": "5.2s", + "step_results": {"step-1": "done", "step-2": "done"}, + }, + } + ) + + result = await client.execute_plan("plan-123") + + assert result.status == "completed" + assert result.result == "Trip booked successfully" + + @pytest.mark.asyncio + async def test_get_plan_status( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test getting plan status.""" + httpx_mock.add_response( + json={ + "plan_id": "plan-123", + "status": "running", + "step_results": {"step-1": "done"}, + } + ) + + result = await client.get_plan_status("plan-123") + + assert result.status == "running" + + +class TestSyncClient: + """Test synchronous client wrapper.""" + + def test_sync_health_check( + self, + sync_client, + httpx_mock: HTTPXMock, + mock_health_response: dict[str, Any], + ) -> None: + """Test sync health check.""" + httpx_mock.add_response(json=mock_health_response) + result = sync_client.health_check() + assert result is True + + def test_sync_execute_query( + self, + sync_client, + httpx_mock: HTTPXMock, + mock_query_response: 
dict[str, Any], + ) -> None: + """Test sync query execution.""" + httpx_mock.add_response(json=mock_query_response) + result = sync_client.execute_query("test", "query", "chat") + assert result.success is True + + def test_sync_context_manager( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + mock_health_response: dict[str, Any], + ) -> None: + """Test sync context manager.""" + httpx_mock.add_response(json=mock_health_response) + + with AxonFlow.sync(**config_dict) as client: + result = client.health_check() + assert result is True + + +class TestContextManager: + """Test async context manager.""" + + @pytest.mark.asyncio + async def test_async_context_manager( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + mock_health_response: dict[str, Any], + ) -> None: + """Test async context manager.""" + httpx_mock.add_response(json=mock_health_response) + + async with AxonFlow(**config_dict) as client: + result = await client.health_check() + assert result is True + + @pytest.mark.asyncio + async def test_explicit_close( + self, + config_dict: dict[str, Any], + ) -> None: + """Test explicit close.""" + client = AxonFlow(**config_dict) + await client.close() + # No exception should be raised diff --git a/sdk/python/tests/test_exceptions.py b/sdk/python/tests/test_exceptions.py new file mode 100644 index 00000000..a3f72670 --- /dev/null +++ b/sdk/python/tests/test_exceptions.py @@ -0,0 +1,214 @@ +"""Tests for exception classes.""" + +from __future__ import annotations + +import pytest + +from axonflow.exceptions import ( + AuthenticationError, + AxonFlowError, + ConfigurationError, + ConnectionError, + ConnectorError, + PlanExecutionError, + PolicyViolationError, + RateLimitError, + TimeoutError, +) + + +class TestAxonFlowError: + """Test base AxonFlowError class.""" + + def test_basic_error(self) -> None: + """Test basic error creation.""" + error = AxonFlowError("Something went wrong") + assert str(error) == "Something went wrong" + 
assert error.message == "Something went wrong" + assert error.details == {} + + def test_error_with_details(self) -> None: + """Test error with details.""" + error = AxonFlowError( + "Error occurred", + details={"code": "ERR001", "context": "testing"}, + ) + assert error.details["code"] == "ERR001" + assert error.details["context"] == "testing" + + def test_error_is_exception(self) -> None: + """Test error inherits from Exception.""" + error = AxonFlowError("Test") + assert isinstance(error, Exception) + + def test_error_can_be_raised(self) -> None: + """Test error can be raised and caught.""" + with pytest.raises(AxonFlowError) as exc_info: + raise AxonFlowError("Test error") + assert "Test error" in str(exc_info.value) + + +class TestConfigurationError: + """Test ConfigurationError class.""" + + def test_configuration_error(self) -> None: + """Test configuration error.""" + error = ConfigurationError("Invalid config") + assert isinstance(error, AxonFlowError) + assert error.message == "Invalid config" + + +class TestAuthenticationError: + """Test AuthenticationError class.""" + + def test_authentication_error(self) -> None: + """Test authentication error.""" + error = AuthenticationError("Invalid credentials") + assert isinstance(error, AxonFlowError) + assert error.message == "Invalid credentials" + + +class TestPolicyViolationError: + """Test PolicyViolationError class.""" + + def test_basic_policy_error(self) -> None: + """Test basic policy violation error.""" + error = PolicyViolationError("Request blocked") + assert error.message == "Request blocked" + assert error.policy is None + assert error.block_reason is None + + def test_policy_error_with_details(self) -> None: + """Test policy error with full details.""" + error = PolicyViolationError( + "Request blocked by policy", + policy="rate-limit", + block_reason="Too many requests", + ) + assert error.policy == "rate-limit" + assert error.block_reason == "Too many requests" + assert error.details["policy"] == 
"rate-limit" + + def test_policy_error_inheritance(self) -> None: + """Test policy error inherits from AxonFlowError.""" + error = PolicyViolationError("Test") + assert isinstance(error, AxonFlowError) + + +class TestRateLimitError: + """Test RateLimitError class.""" + + def test_rate_limit_error(self) -> None: + """Test rate limit error.""" + error = RateLimitError( + "Rate limit exceeded", + limit=100, + remaining=0, + reset_at="2025-12-04T13:00:00Z", + ) + assert error.limit == 100 + assert error.remaining == 0 + assert error.reset_at == "2025-12-04T13:00:00Z" + + def test_rate_limit_details(self) -> None: + """Test rate limit error details.""" + error = RateLimitError("Exceeded", limit=50, remaining=0) + assert error.details["limit"] == 50 + assert error.details["remaining"] == 0 + + +class TestConnectionError: + """Test ConnectionError class.""" + + def test_connection_error(self) -> None: + """Test connection error.""" + error = ConnectionError("Failed to connect") + assert isinstance(error, AxonFlowError) + assert error.message == "Failed to connect" + + +class TestTimeoutError: + """Test TimeoutError class.""" + + def test_timeout_error(self) -> None: + """Test timeout error.""" + error = TimeoutError("Request timed out after 30s") + assert isinstance(error, AxonFlowError) + assert "30s" in error.message + + +class TestConnectorError: + """Test ConnectorError class.""" + + def test_basic_connector_error(self) -> None: + """Test basic connector error.""" + error = ConnectorError("Connector failed") + assert error.connector is None + assert error.operation is None + + def test_connector_error_with_details(self) -> None: + """Test connector error with details.""" + error = ConnectorError( + "Query failed", + connector="postgres", + operation="SELECT", + ) + assert error.connector == "postgres" + assert error.operation == "SELECT" + assert error.details["connector"] == "postgres" + + +class TestPlanExecutionError: + """Test PlanExecutionError class.""" + + def 
test_basic_plan_error(self) -> None: + """Test basic plan execution error.""" + error = PlanExecutionError("Plan failed") + assert error.plan_id is None + assert error.step is None + + def test_plan_error_with_details(self) -> None: + """Test plan error with details.""" + error = PlanExecutionError( + "Step failed", + plan_id="plan-123", + step="step-2", + ) + assert error.plan_id == "plan-123" + assert error.step == "step-2" + assert error.details["plan_id"] == "plan-123" + + +class TestExceptionHierarchy: + """Test exception hierarchy and catching.""" + + def test_catch_all_axonflow_errors(self) -> None: + """Test all errors can be caught as AxonFlowError.""" + errors = [ + ConfigurationError("config"), + AuthenticationError("auth"), + PolicyViolationError("policy"), + RateLimitError("rate", limit=100, remaining=0), + ConnectionError("connection"), + TimeoutError("timeout"), + ConnectorError("connector"), + PlanExecutionError("plan"), + ] + + for error in errors: + with pytest.raises(AxonFlowError): + raise error + + def test_specific_catch(self) -> None: + """Test specific exception types can be caught.""" + with pytest.raises(PolicyViolationError): + raise PolicyViolationError("Test") + + # Should not catch other types + with pytest.raises(AuthenticationError): + try: + raise AuthenticationError("Auth failed") + except PolicyViolationError: + pytest.fail("Should not catch PolicyViolationError") + except AuthenticationError: + raise diff --git a/sdk/python/tests/test_gateway.py b/sdk/python/tests/test_gateway.py new file mode 100644 index 00000000..8fd8cc91 --- /dev/null +++ b/sdk/python/tests/test_gateway.py @@ -0,0 +1,304 @@ +"""Tests for Gateway Mode functionality.""" + +from __future__ import annotations + +from typing import Any + +import pytest +from pytest_httpx import HTTPXMock + +from axonflow import AxonFlow, TokenUsage + + +class TestGatewayPreCheck: + """Test Gateway Mode pre-check functionality.""" + + @pytest.mark.asyncio + async def 
test_pre_check_approved( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_pre_check_response: dict[str, Any], + ) -> None: + """Test successful pre-check approval.""" + httpx_mock.add_response(json=mock_pre_check_response) + + result = await client.get_policy_approved_context( + user_token="user-jwt", + query="Find patients with diabetes", + data_sources=["postgres"], + ) + + assert result.approved is True + assert result.context_id == "ctx-123" + assert "patients" in result.approved_data + assert "hipaa" in result.policies + + @pytest.mark.asyncio + async def test_pre_check_with_rate_limit( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_pre_check_response: dict[str, Any], + ) -> None: + """Test pre-check with rate limit info.""" + httpx_mock.add_response(json=mock_pre_check_response) + + result = await client.get_policy_approved_context( + user_token="user-jwt", + query="Find patients", + ) + + assert result.rate_limit_info is not None + assert result.rate_limit_info.limit == 100 + assert result.rate_limit_info.remaining == 99 + + @pytest.mark.asyncio + async def test_pre_check_blocked( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test pre-check when blocked.""" + httpx_mock.add_response( + json={ + "context_id": "ctx-blocked", + "approved": False, + "approved_data": {}, + "policies": ["hipaa"], + "expires_at": "2025-12-04T13:00:00Z", + "block_reason": "Sensitive patient data access requires approval", + } + ) + + result = await client.get_policy_approved_context( + user_token="user-jwt", + query="Show all SSN numbers", + ) + + assert result.approved is False + assert result.block_reason is not None + assert "approval" in result.block_reason.lower() + + @pytest.mark.asyncio + async def test_pre_check_with_context( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_pre_check_response: dict[str, Any], + ) -> None: + """Test pre-check with additional context.""" + 
httpx_mock.add_response(json=mock_pre_check_response) + + result = await client.get_policy_approved_context( + user_token="user-jwt", + query="Find patients", + data_sources=["postgres", "salesforce"], + context={"department": "cardiology", "urgency": "high"}, + ) + + request = httpx_mock.get_requests()[0] + assert request is not None + assert "/api/policy/pre-check" in str(request.url) + + @pytest.mark.asyncio + async def test_pre_check_no_rate_limit( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + ) -> None: + """Test pre-check without rate limit info.""" + httpx_mock.add_response( + json={ + "context_id": "ctx-123", + "approved": True, + "approved_data": {}, + "policies": [], + "expires_at": "2025-12-04T13:00:00Z", + } + ) + + result = await client.get_policy_approved_context( + user_token="user-jwt", + query="Simple query", + ) + + assert result.rate_limit_info is None + + +class TestGatewayAudit: + """Test Gateway Mode audit functionality.""" + + @pytest.mark.asyncio + async def test_audit_llm_call( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_audit_response: dict[str, Any], + ) -> None: + """Test successful audit logging.""" + httpx_mock.add_response(json=mock_audit_response) + + result = await client.audit_llm_call( + context_id="ctx-123", + response_summary="Found 5 patients with recent lab results", + provider="openai", + model="gpt-4", + token_usage=TokenUsage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ), + latency_ms=250, + ) + + assert result.success is True + assert result.audit_id == "audit-456" + + @pytest.mark.asyncio + async def test_audit_with_metadata( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_audit_response: dict[str, Any], + ) -> None: + """Test audit with additional metadata.""" + httpx_mock.add_response(json=mock_audit_response) + + result = await client.audit_llm_call( + context_id="ctx-123", + response_summary="Generated summary", + provider="anthropic", + 
model="claude-3-sonnet", + token_usage=TokenUsage( + prompt_tokens=200, + completion_tokens=100, + total_tokens=300, + ), + latency_ms=500, + metadata={ + "session_id": "session-123", + "request_id": "req-456", + "user_feedback": "helpful", + }, + ) + + assert result.success is True + + @pytest.mark.asyncio + async def test_audit_request_body( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_audit_response: dict[str, Any], + ) -> None: + """Test audit request body format.""" + httpx_mock.add_response(json=mock_audit_response) + + await client.audit_llm_call( + context_id="ctx-789", + response_summary="Test summary", + provider="bedrock", + model="claude-v2", + token_usage=TokenUsage( + prompt_tokens=50, + completion_tokens=25, + total_tokens=75, + ), + latency_ms=150, + ) + + request = httpx_mock.get_requests()[0] + assert "/api/audit/llm-call" in str(request.url) + + +class TestGatewayModeFlow: + """Test complete Gateway Mode flow.""" + + @pytest.mark.asyncio + async def test_complete_gateway_flow( + self, + client: AxonFlow, + httpx_mock: HTTPXMock, + mock_pre_check_response: dict[str, Any], + mock_audit_response: dict[str, Any], + ) -> None: + """Test complete pre-check -> LLM call -> audit flow.""" + # Mock pre-check + httpx_mock.add_response(json=mock_pre_check_response) + # Mock audit + httpx_mock.add_response(json=mock_audit_response) + + # Step 1: Pre-check + ctx = await client.get_policy_approved_context( + user_token="user-jwt", + query="Find patients with diabetes", + data_sources=["postgres"], + ) + + assert ctx.approved is True + context_id = ctx.context_id + + # Step 2: (Simulated) LLM call would happen here + # ... 
+ + # Step 3: Audit + audit_result = await client.audit_llm_call( + context_id=context_id, + response_summary="Found 5 patients", + provider="openai", + model="gpt-4", + token_usage=TokenUsage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ), + latency_ms=250, + ) + + assert audit_result.success is True + assert len(httpx_mock.get_requests()) == 2 + + +class TestSyncGatewayMode: + """Test Gateway Mode with sync client.""" + + def test_sync_pre_check( + self, + sync_client, + httpx_mock: HTTPXMock, + mock_pre_check_response: dict[str, Any], + ) -> None: + """Test sync pre-check.""" + httpx_mock.add_response(json=mock_pre_check_response) + + result = sync_client.get_policy_approved_context( + user_token="user-jwt", + query="Test query", + ) + + assert result.approved is True + + def test_sync_audit( + self, + sync_client, + httpx_mock: HTTPXMock, + mock_audit_response: dict[str, Any], + ) -> None: + """Test sync audit.""" + httpx_mock.add_response(json=mock_audit_response) + + result = sync_client.audit_llm_call( + context_id="ctx-123", + response_summary="Test", + provider="openai", + model="gpt-4", + token_usage=TokenUsage( + prompt_tokens=10, + completion_tokens=5, + total_tokens=15, + ), + latency_ms=100, + ) + + assert result.success is True diff --git a/sdk/python/tests/test_interceptors.py b/sdk/python/tests/test_interceptors.py new file mode 100644 index 00000000..848891f8 --- /dev/null +++ b/sdk/python/tests/test_interceptors.py @@ -0,0 +1,233 @@ +"""Tests for LLM provider interceptors.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pytest_httpx import HTTPXMock + +from axonflow import AxonFlow +from axonflow.exceptions import PolicyViolationError +from axonflow.interceptors.anthropic import ( + AnthropicInterceptor, + wrap_anthropic_client, +) +from axonflow.interceptors.openai import OpenAIInterceptor, wrap_openai_client + + +class 
TestOpenAIInterceptor: + """Test OpenAI interceptor.""" + + def test_get_provider_name(self) -> None: + """Test provider name.""" + client = MagicMock() + axonflow = MagicMock() + interceptor = OpenAIInterceptor(axonflow) + assert interceptor.get_provider_name() == "openai" + + def test_extract_prompt_from_messages(self) -> None: + """Test prompt extraction from messages.""" + axonflow = MagicMock() + interceptor = OpenAIInterceptor(axonflow) + + prompt = interceptor.extract_prompt( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello world"}, + ] + ) + + assert "You are helpful" in prompt + assert "Hello world" in prompt + + def test_extract_prompt_empty_messages(self) -> None: + """Test prompt extraction with empty messages.""" + axonflow = MagicMock() + interceptor = OpenAIInterceptor(axonflow) + + prompt = interceptor.extract_prompt(messages=[]) + assert prompt == "" + + @pytest.mark.asyncio + async def test_wrap_async_openai_client( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + ) -> None: + """Test wrapping async OpenAI client.""" + # Mock AxonFlow response + httpx_mock.add_response( + json={ + "success": True, + "blocked": False, + "data": None, + } + ) + + async with AxonFlow(**config_dict) as axonflow: + # Create mock OpenAI client + mock_openai = MagicMock() + mock_openai.chat.completions.create = AsyncMock( + return_value={"choices": [{"message": {"content": "Hello!"}}]} + ) + + # Wrap it + wrapped = wrap_openai_client(mock_openai, axonflow, user_token="test") + + # Call it + result = await wrapped.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert result is not None + + @pytest.mark.asyncio + async def test_openai_blocked_by_policy( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + ) -> None: + """Test OpenAI call blocked by policy.""" + httpx_mock.add_response( + json={ + "success": False, + "blocked": True, + 
"block_reason": "Sensitive content detected", + } + ) + + async with AxonFlow(**config_dict) as axonflow: + mock_openai = MagicMock() + mock_openai.chat.completions.create = AsyncMock() + + wrapped = wrap_openai_client(mock_openai, axonflow) + + with pytest.raises(PolicyViolationError) as exc_info: + await wrapped.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Test"}], + ) + + assert "Sensitive content" in str(exc_info.value) + + +class TestAnthropicInterceptor: + """Test Anthropic interceptor.""" + + def test_get_provider_name(self) -> None: + """Test provider name.""" + axonflow = MagicMock() + interceptor = AnthropicInterceptor(axonflow) + assert interceptor.get_provider_name() == "anthropic" + + def test_extract_prompt_string_content(self) -> None: + """Test prompt extraction with string content.""" + axonflow = MagicMock() + interceptor = AnthropicInterceptor(axonflow) + + prompt = interceptor.extract_prompt( + messages=[ + {"role": "user", "content": "Hello Claude"}, + ] + ) + + assert "Hello Claude" in prompt + + def test_extract_prompt_block_content(self) -> None: + """Test prompt extraction with content blocks.""" + axonflow = MagicMock() + interceptor = AnthropicInterceptor(axonflow) + + prompt = interceptor.extract_prompt( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image", "source": {"type": "url", "url": "..."}}, + ], + }, + ] + ) + + assert "What is this" in prompt + + @pytest.mark.asyncio + async def test_wrap_async_anthropic_client( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + ) -> None: + """Test wrapping async Anthropic client.""" + httpx_mock.add_response( + json={ + "success": True, + "blocked": False, + "data": None, + } + ) + + async with AxonFlow(**config_dict) as axonflow: + mock_anthropic = MagicMock() + mock_anthropic.messages.create = AsyncMock( + return_value={"content": [{"type": "text", "text": "Hello!"}]} + ) + + 
wrapped = wrap_anthropic_client(mock_anthropic, axonflow, user_token="test") + + result = await wrapped.messages.create( + model="claude-3-sonnet-20240229", + max_tokens=1024, + messages=[{"role": "user", "content": "Hi"}], + ) + + assert result is not None + + @pytest.mark.asyncio + async def test_anthropic_blocked_by_policy( + self, + config_dict: dict[str, Any], + httpx_mock: HTTPXMock, + ) -> None: + """Test Anthropic call blocked by policy.""" + httpx_mock.add_response( + json={ + "success": False, + "blocked": True, + "block_reason": "Rate limit exceeded", + } + ) + + async with AxonFlow(**config_dict) as axonflow: + mock_anthropic = MagicMock() + mock_anthropic.messages.create = AsyncMock() + + wrapped = wrap_anthropic_client(mock_anthropic, axonflow) + + with pytest.raises(PolicyViolationError) as exc_info: + await wrapped.messages.create( + model="claude-3-sonnet", + max_tokens=100, + messages=[{"role": "user", "content": "Test"}], + ) + + assert "Rate limit" in str(exc_info.value) + + +class TestInterceptorUserToken: + """Test user token handling in interceptors.""" + + def test_openai_user_token(self) -> None: + """Test OpenAI interceptor stores user token.""" + axonflow = MagicMock() + interceptor = OpenAIInterceptor(axonflow, user_token="user-123") + assert interceptor.user_token == "user-123" + + def test_anthropic_user_token(self) -> None: + """Test Anthropic interceptor stores user token.""" + axonflow = MagicMock() + interceptor = AnthropicInterceptor(axonflow, user_token="user-456") + assert interceptor.user_token == "user-456" diff --git a/sdk/python/tests/test_types.py b/sdk/python/tests/test_types.py new file mode 100644 index 00000000..00871a5e --- /dev/null +++ b/sdk/python/tests/test_types.py @@ -0,0 +1,335 @@ +"""Tests for type definitions and validation.""" + +from __future__ import annotations + +from datetime import datetime + +import pytest +from pydantic import ValidationError + +from axonflow.types import ( + AuditResult, + 
AxonFlowConfig, + CacheConfig, + ClientRequest, + ClientResponse, + ConnectorInstallRequest, + ConnectorMetadata, + ConnectorResponse, + Mode, + PlanExecutionResponse, + PlanResponse, + PlanStep, + PolicyApprovalResult, + PolicyEvaluationInfo, + RateLimitInfo, + RetryConfig, + TokenUsage, +) + + +class TestMode: + """Test Mode enum.""" + + def test_production_mode(self) -> None: + """Test production mode value.""" + assert Mode.PRODUCTION.value == "production" + + def test_sandbox_mode(self) -> None: + """Test sandbox mode value.""" + assert Mode.SANDBOX.value == "sandbox" + + def test_mode_from_string(self) -> None: + """Test creating mode from string.""" + assert Mode("production") == Mode.PRODUCTION + assert Mode("sandbox") == Mode.SANDBOX + + +class TestRetryConfig: + """Test RetryConfig model.""" + + def test_default_values(self) -> None: + """Test default configuration values.""" + config = RetryConfig() + assert config.enabled is True + assert config.max_attempts == 3 + assert config.initial_delay == 1.0 + assert config.max_delay == 30.0 + assert config.exponential_base == 2.0 + + def test_custom_values(self) -> None: + """Test custom configuration values.""" + config = RetryConfig( + enabled=False, + max_attempts=5, + initial_delay=2.0, + max_delay=60.0, + exponential_base=3.0, + ) + assert config.enabled is False + assert config.max_attempts == 5 + + def test_validation_max_attempts(self) -> None: + """Test max_attempts validation.""" + with pytest.raises(ValidationError): + RetryConfig(max_attempts=0) # min is 1 + + with pytest.raises(ValidationError): + RetryConfig(max_attempts=11) # max is 10 + + def test_frozen(self) -> None: + """Test config is immutable.""" + config = RetryConfig() + with pytest.raises(ValidationError): + config.enabled = False # type: ignore[misc] + + +class TestCacheConfig: + """Test CacheConfig model.""" + + def test_default_values(self) -> None: + """Test default configuration values.""" + config = CacheConfig() + assert 
config.enabled is True + assert config.ttl == 60.0 + assert config.max_size == 1000 + + def test_validation_ttl(self) -> None: + """Test TTL must be positive.""" + with pytest.raises(ValidationError): + CacheConfig(ttl=0) + + with pytest.raises(ValidationError): + CacheConfig(ttl=-1) + + +class TestAxonFlowConfig: + """Test AxonFlowConfig model.""" + + def test_required_fields(self) -> None: + """Test required fields validation.""" + with pytest.raises(ValidationError): + AxonFlowConfig() # type: ignore[call-arg] + + def test_minimal_config(self) -> None: + """Test minimal configuration.""" + config = AxonFlowConfig( + agent_url="https://test.com", + client_id="client", + client_secret="secret", + ) + assert config.agent_url == "https://test.com" + assert config.mode == Mode.PRODUCTION + assert config.debug is False + + def test_full_config(self) -> None: + """Test full configuration.""" + config = AxonFlowConfig( + agent_url="https://test.com", + client_id="client", + client_secret="secret", + license_key="license-123", + mode=Mode.SANDBOX, + debug=True, + timeout=30.0, + insecure_skip_verify=True, + ) + assert config.license_key == "license-123" + assert config.mode == Mode.SANDBOX + assert config.timeout == 30.0 + + def test_empty_url_validation(self) -> None: + """Test empty URL validation.""" + with pytest.raises(ValidationError): + AxonFlowConfig( + agent_url="", + client_id="client", + client_secret="secret", + ) + + +class TestClientRequest: + """Test ClientRequest model.""" + + def test_create_request(self) -> None: + """Test creating a client request.""" + request = ClientRequest( + query="What is AI?", + user_token="token-123", + client_id="client", + request_type="chat", + ) + assert request.query == "What is AI?" 
+ assert request.context == {} + + def test_request_with_context(self) -> None: + """Test request with context.""" + request = ClientRequest( + query="Query", + user_token="token", + client_id="client", + request_type="sql", + context={"database": "postgres"}, + ) + assert request.context["database"] == "postgres" + + +class TestClientResponse: + """Test ClientResponse model.""" + + def test_successful_response(self) -> None: + """Test successful response.""" + response = ClientResponse( + success=True, + data={"result": "test"}, + ) + assert response.success is True + assert response.blocked is False + assert response.data == {"result": "test"} + + def test_blocked_response(self) -> None: + """Test blocked response.""" + response = ClientResponse( + success=False, + blocked=True, + block_reason="Policy violation", + ) + assert response.success is False + assert response.blocked is True + assert response.block_reason == "Policy violation" + + def test_response_with_policy_info(self) -> None: + """Test response with policy evaluation info.""" + response = ClientResponse( + success=True, + policy_info=PolicyEvaluationInfo( + policies_evaluated=["policy1", "policy2"], + processing_time="10ms", + ), + ) + assert response.policy_info is not None + assert len(response.policy_info.policies_evaluated) == 2 + + +class TestConnectorTypes: + """Test connector-related types.""" + + def test_connector_metadata(self) -> None: + """Test ConnectorMetadata model.""" + metadata = ConnectorMetadata( + id="postgres", + name="PostgreSQL", + type="database", + version="1.0.0", + ) + assert metadata.id == "postgres" + assert metadata.installed is False + + def test_connector_install_request(self) -> None: + """Test ConnectorInstallRequest model.""" + request = ConnectorInstallRequest( + connector_id="postgres", + name="My DB", + tenant_id="tenant-123", + credentials={"password": "secret"}, + ) + assert request.credentials["password"] == "secret" + + def test_connector_response(self) -> 
None: + """Test ConnectorResponse model.""" + response = ConnectorResponse( + success=True, + data={"rows": []}, + ) + assert response.success is True + + +class TestPlanTypes: + """Test planning-related types.""" + + def test_plan_step(self) -> None: + """Test PlanStep model.""" + step = PlanStep( + id="step-1", + name="Fetch data", + type="data", + depends_on=["step-0"], + ) + assert step.id == "step-1" + assert "step-0" in step.depends_on + + def test_plan_response(self) -> None: + """Test PlanResponse model.""" + plan = PlanResponse( + plan_id="plan-123", + steps=[ + PlanStep(id="s1", name="Step 1", type="data"), + PlanStep(id="s2", name="Step 2", type="process"), + ], + domain="travel", + ) + assert plan.plan_id == "plan-123" + assert len(plan.steps) == 2 + assert plan.domain == "travel" + + def test_plan_execution_response(self) -> None: + """Test PlanExecutionResponse model.""" + result = PlanExecutionResponse( + plan_id="plan-123", + status="completed", + result="Success", + duration="5.2s", + ) + assert result.status == "completed" + + +class TestGatewayTypes: + """Test Gateway Mode types.""" + + def test_rate_limit_info(self) -> None: + """Test RateLimitInfo model.""" + info = RateLimitInfo( + limit=100, + remaining=50, + reset_at=datetime.fromisoformat("2025-12-04T12:00:00"), + ) + assert info.limit == 100 + assert info.remaining == 50 + + def test_policy_approval_result(self) -> None: + """Test PolicyApprovalResult model.""" + result = PolicyApprovalResult( + context_id="ctx-123", + approved=True, + expires_at=datetime.fromisoformat("2025-12-04T13:00:00"), + ) + assert result.context_id == "ctx-123" + assert result.approved is True + assert result.approved_data == {} + + def test_token_usage(self) -> None: + """Test TokenUsage model.""" + usage = TokenUsage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ) + assert usage.total_tokens == 150 + + def test_token_usage_validation(self) -> None: + """Test TokenUsage validation.""" + with 
pytest.raises(ValidationError): + TokenUsage( + prompt_tokens=-1, + completion_tokens=0, + total_tokens=0, + ) + + def test_audit_result(self) -> None: + """Test AuditResult model.""" + result = AuditResult( + success=True, + audit_id="audit-456", + ) + assert result.success is True + assert result.audit_id == "audit-456" diff --git a/sdk/python/tests/test_utils.py b/sdk/python/tests/test_utils.py new file mode 100644 index 00000000..d1a68036 --- /dev/null +++ b/sdk/python/tests/test_utils.py @@ -0,0 +1,191 @@ +"""Tests for utility modules.""" + +from __future__ import annotations + +from axonflow.types import RetryConfig +from axonflow.utils.cache import CacheManager +from axonflow.utils.retry import RetryHandler, create_retry_decorator, with_retry + + +class TestCacheManager: + """Test CacheManager utility.""" + + def test_create_cache(self) -> None: + """Test creating cache manager.""" + cache: CacheManager[str] = CacheManager(maxsize=100, ttl=30.0) + assert cache.maxsize == 100 + assert cache.ttl == 30.0 + + def test_set_and_get(self) -> None: + """Test setting and getting values.""" + cache: CacheManager[str] = CacheManager() + cache.set("key1", "value1") + assert cache.get("key1") == "value1" + + def test_get_missing_key(self) -> None: + """Test getting non-existent key.""" + cache: CacheManager[str] = CacheManager() + assert cache.get("missing") is None + + def test_delete(self) -> None: + """Test deleting a key.""" + cache: CacheManager[str] = CacheManager() + cache.set("key", "value") + cache.delete("key") + assert cache.get("key") is None + + def test_delete_missing_key(self) -> None: + """Test deleting non-existent key doesn't raise.""" + cache: CacheManager[str] = CacheManager() + cache.delete("missing") # Should not raise + + def test_clear(self) -> None: + """Test clearing all values.""" + cache: CacheManager[str] = CacheManager() + cache.set("key1", "value1") + cache.set("key2", "value2") + cache.clear() + assert cache.size == 0 + + def 
test_contains(self) -> None: + """Test checking if key exists.""" + cache: CacheManager[str] = CacheManager() + cache.set("key", "value") + assert cache.contains("key") is True + assert cache.contains("missing") is False + + def test_size(self) -> None: + """Test getting cache size.""" + cache: CacheManager[str] = CacheManager() + assert cache.size == 0 + cache.set("key1", "value1") + assert cache.size == 1 + cache.set("key2", "value2") + assert cache.size == 2 + + def test_get_or_set(self) -> None: + """Test get_or_set functionality.""" + cache: CacheManager[int] = CacheManager() + factory_called = [0] + + def factory() -> int: + factory_called[0] += 1 + return 42 + + # First call should invoke factory + result1 = cache.get_or_set("key", factory) + assert result1 == 42 + assert factory_called[0] == 1 + + # Second call should use cached value + result2 = cache.get_or_set("key", factory) + assert result2 == 42 + assert factory_called[0] == 1 # Factory not called again + + def test_generic_typing(self) -> None: + """Test generic typing works.""" + int_cache: CacheManager[int] = CacheManager() + int_cache.set("number", 42) + assert int_cache.get("number") == 42 + + dict_cache: CacheManager[dict[str, str]] = CacheManager() + dict_cache.set("data", {"key": "value"}) + assert dict_cache.get("data") == {"key": "value"} + + +class TestRetryHandler: + """Test RetryHandler utility.""" + + def test_create_handler(self) -> None: + """Test creating retry handler.""" + config = RetryConfig() + handler = RetryHandler(config) + assert handler.config == config + + def test_disabled_retry(self) -> None: + """Test disabled retry returns identity decorator.""" + config = RetryConfig(enabled=False) + handler = RetryHandler(config) + + def my_func() -> str: + return "hello" + + decorator = handler.create_decorator((Exception,)) + decorated = decorator(my_func) + + assert decorated() == "hello" + + def test_enabled_retry_decorator(self) -> None: + """Test enabled retry creates 
decorator.""" + config = RetryConfig(enabled=True, max_attempts=3) + handler = RetryHandler(config) + + call_count = [0] + + def flaky_func() -> str: + call_count[0] += 1 + if call_count[0] < 3: + msg = "Temporary error" + raise ValueError(msg) + return "success" + + decorator = handler.create_decorator((ValueError,)) + decorated = decorator(flaky_func) + + result = decorated() + assert result == "success" + assert call_count[0] == 3 + + +class TestWithRetryDecorator: + """Test with_retry decorator.""" + + def test_with_retry_basic(self) -> None: + """Test basic retry functionality.""" + call_count = [0] + + @with_retry(max_attempts=3, retry_on=(ValueError,)) + def my_func() -> str: + call_count[0] += 1 + if call_count[0] < 2: + msg = "Error" + raise ValueError(msg) + return "done" + + result = my_func() + assert result == "done" + assert call_count[0] == 2 + + def test_with_retry_exhausted(self) -> None: + """Test retry exhaustion.""" + import pytest + + @with_retry(max_attempts=2, retry_on=(ValueError,)) + def always_fails() -> None: + msg = "Always fails" + raise ValueError(msg) + + with pytest.raises(ValueError): + always_fails() + + +class TestCreateRetryDecorator: + """Test create_retry_decorator function.""" + + def test_create_decorator(self) -> None: + """Test creating decorator from config.""" + config = RetryConfig(enabled=True, max_attempts=2) + decorator = create_retry_decorator(config, (RuntimeError,)) + + call_count = [0] + + def my_func() -> str: + call_count[0] += 1 + if call_count[0] < 2: + msg = "Error" + raise RuntimeError(msg) + return "ok" + + decorated = decorator(my_func) + result = decorated() + assert result == "ok"