diff --git a/AGENTS.md b/AGENTS.md index d3f9f83..5928104 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -67,6 +67,28 @@ Three repos work together: After pushing to the automation repo, update both files in the deploy repo. +## Configuration + +Configuration is centralized in `config.py` using a composed `AppConfig` with typed sections: + +```python +from automation.config import get_config + +config = get_config() +config.service.db_host # ServiceSettings (AUTOMATION_ prefix) +config.storage.file_store # StorageSettings (no prefix, SDK conventions) +config.http.auth_cache_ttl # HttpSettings (AUTOMATION_ prefix) +config.sandbox.max_run_duration # SandboxSettings (AUTOMATION_ prefix) +config.kv.kv_secret # KVSettings (AUTOMATION_ prefix) +config.log.log_level # LogSettings (no prefix) +``` + +**Key principles:** +- Use `get_config().
` instead of deprecated `get_settings()` +- All environment variables documented in config class docstrings +- Protocol constants (WORK_DIR, TARBALL_PATH) in `constants.py` - these cannot be changed without breaking compatibility +- Shared logging context via `log_extra()` from `automation.utils` + ## Build & Test Commands ```bash diff --git a/docs/kv-store-client-guide.md b/docs/kv-store-client-guide.md new file mode 100644 index 0000000..5c600ec --- /dev/null +++ b/docs/kv-store-client-guide.md @@ -0,0 +1,451 @@ +# KV Store Client Guide + +This guide covers how to use the automation KV store API for state persistence between runs. + +## Overview + +The KV store provides a Redis-like key-value interface for automations to persist state between runs. It's designed for small, frequently-accessed data like: + +- Counters and cursors +- Configuration flags +- Small caches (< 64KB recommended) +- Run metadata and logs + +### When to Use + +✅ **Good use cases:** +- Tracking pagination cursors across runs +- Counting events or iterations +- Storing configuration that changes over time +- Caching small computed values + +❌ **Not designed for:** +- Large file storage (use object storage) +- High-throughput queues (use proper message queues) +- Relational data (use a database) +- Storing sensitive credentials (use secrets management) + +### Limitations + +| Limit | Value | Notes | +|-------|-------|-------| +| Max state size | 64 KB | Total size of all keys combined | +| Max key length | 255 chars | Keys are case-sensitive | +| Reserved keys | `$` prefix | System use (e.g., `$version`) | +| Max nesting depth | 32 levels | For nested objects/arrays | + +## Authentication + +All KV endpoints require a JWT token passed via the `Authorization` header: + +``` +Authorization: Bearer +``` + +The token is automatically provided to your automation via the `AUTOMATION_KV_TOKEN` environment variable whenever the service has a KV secret configured. The KV store is always available — there is no per-automation toggle. + +## Basic Operations + +### Get a Value + +```bash +curl -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + "$AUTOMATION_API_URL/v1/kv/mykey" +``` + +**Response:** +```json +{"key": "mykey", "value": {"foo": "bar"}} +``` + +### Set a Value + +```bash +curl -X PUT \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"foo": "bar"}' \ + "$AUTOMATION_API_URL/v1/kv/mykey" +``` + +**Response (201 Created for new key, 200 OK for update):** +```json +{"key": "mykey", "value": {"foo": "bar"}, "created": true, "updated_at": "2024-01-15T10:00:00Z"} +``` + +### Delete a Value + +```bash +curl -X DELETE \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + "$AUTOMATION_API_URL/v1/kv/mykey" +``` + +**Response:** +```json +{"key": "mykey", "deleted": true} +``` + +### List All Keys + +```bash +curl -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + "$AUTOMATION_API_URL/v1/kv" +``` + +**Response:** +```json +{"keys": ["config", "counter", "last_run"], "count": 3} +``` + +## Advanced Operations + +### Nested Paths with PATCH + +Update a nested field without replacing the entire value: + +```bash +curl -X PATCH \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"path": "settings.theme", "value": "dark"}' \ + "$AUTOMATION_API_URL/v1/kv/config" +``` + +### Atomic Counters + +**Increment:** +```bash +curl -X POST \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"by": 1}' \ + "$AUTOMATION_API_URL/v1/kv/counter/incr" +``` + +**Response:** +```json +{"key": "counter", "value": 42} +``` + +**Decrement:** +```bash +curl -X POST \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"by": 5}' \ + "$AUTOMATION_API_URL/v1/kv/counter/decr" +``` + +### List Operations + +**Push to front (LPUSH):** +```bash +curl -X POST \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"value": "new_item"}' \ + "$AUTOMATION_API_URL/v1/kv/queue/lpush" +``` + +**Push to back (RPUSH):** +```bash +curl -X POST \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"value": "new_item"}' \ + "$AUTOMATION_API_URL/v1/kv/queue/rpush" +``` + +**Pop from front (LPOP):** +```bash +curl -X POST \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + "$AUTOMATION_API_URL/v1/kv/queue/lpop" +``` + +**Pop from back (RPOP):** +```bash +curl -X POST \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + "$AUTOMATION_API_URL/v1/kv/queue/rpop" +``` + +### Batch Operations + +Execute multiple operations atomically: + +```bash +curl -X POST \ + -H "Authorization: Bearer $AUTOMATION_KV_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "operations": [ + {"op": "incr", "key": "counter"}, + {"op": "set", "key": "last_run", "value": "2024-01-15"}, + {"op": "rpush", "key": "log", "value": {"event": "complete"}} + ] + }' \ + "$AUTOMATION_API_URL/v1/kv/batch" +``` + +**Response:** +```json +{ + "version": 5, + "results": [ + {"op": "incr", "key": "counter", "success": true, "value": 42}, + {"op": "set", "key": "last_run", "success": true, "created": true}, + {"op": "rpush", "key": "log", "success": true, "length": 3} + ] +} +``` + +## Concurrency Patterns + +### For Scheduled Automations (max_concurrent_runs=1) + +When only one instance runs at a time, there's no contention: + +```python +import os +import httpx + +async def main(): + token = os.environ["AUTOMATION_KV_TOKEN"] + api_url = os.environ["AUTOMATION_API_URL"] + + async with httpx.AsyncClient() as client: + # Simple read-modify-write, no retry needed + resp = await client.get( + f"{api_url}/v1/kv/counter", + headers={"Authorization": f"Bearer {token}"} + ) + if resp.status_code == 200: + counter = resp.json()["value"] + else: + counter = 0 + + # Or just use atomic incr + resp = await client.post( + f"{api_url}/v1/kv/counter/incr", + headers={"Authorization": f"Bearer {token}"}, + json={"by": 1} + ) +``` + +### For Event Handlers (max_concurrent_runs > 1) + +When multiple instances run concurrently, **409 Conflicts are expected**. Always implement retry with exponential backoff: + +```python +import asyncio +import os +import random +import httpx + +async def kv_set_with_retry( + client: httpx.AsyncClient, + key: str, + value: any, + max_retries: int = 5 +) -> dict: + """Set a KV value with automatic retry on conflict.""" + token = os.environ["AUTOMATION_KV_TOKEN"] + api_url = os.environ["AUTOMATION_API_URL"] + + for attempt in range(max_retries): + resp = await client.put( + f"{api_url}/v1/kv/{key}", + headers={"Authorization": f"Bearer {token}"}, + json=value + ) + + if resp.status_code in (200, 201): + return resp.json() + + if resp.status_code == 409: + # Get suggested retry delay from header + retry_after = int(resp.headers.get("Retry-After", 1)) + # Exponential backoff with jitter + delay = retry_after * (2 ** attempt) + random.uniform(0, 0.5) + await asyncio.sleep(delay) + continue + + resp.raise_for_status() + + raise Exception(f"Failed to set {key} after {max_retries} retries") +``` + +### Using Optimistic Concurrency + +For read-modify-write patterns, use `if_version` to detect concurrent modifications: + +```python +async def safe_update(client: httpx.AsyncClient, key: str, transform_fn): + """Safely update a value using optimistic concurrency.""" + token = os.environ["AUTOMATION_KV_TOKEN"] + api_url = os.environ["AUTOMATION_API_URL"] + headers = {"Authorization": f"Bearer {token}"} + + for attempt in range(5): + # Read with version metadata + resp = await client.get( + f"{api_url}/v1/kv/{key}", + headers=headers, + params={"meta": "true"} + ) + + if resp.status_code == 404: + # Key doesn't exist, create it + initial_value = transform_fn(None) + resp = await client.put( + f"{api_url}/v1/kv/{key}", + headers=headers, + json=initial_value, + params={"nx": "true"} # Only if not exists + ) + if resp.status_code in (200, 201): + return resp.json() + continue # Retry if conflict + + data = resp.json() + version = data["version"] + old_value = data["value"] + + # Apply transformation locally + new_value = transform_fn(old_value) + + # Write with version check + resp = await client.put( + f"{api_url}/v1/kv/{key}", + headers=headers, + json=new_value, + params={"if_version": version} + ) + + if resp.status_code in (200, 201): + return resp.json() + + if resp.status_code == 409: + # Version changed, retry with backoff + await asyncio.sleep(0.1 * (2 ** attempt)) + continue + + resp.raise_for_status() + + raise Exception("Max retries exceeded") + + +# Usage example +async def increment_counter(): + async with httpx.AsyncClient() as client: + result = await safe_update( + client, + "counter", + lambda v: (v or 0) + 1 + ) + print(f"Counter is now: {result['value']}") +``` + +## Best Practices + +### DO ✅ + +- **Use atomic operations** (`incr`, `push`, `pop`) when possible - they're conflict-free +- **Keep state small** (< 64KB total, ideally < 8KB for best performance) +- **Design for idempotency** - operations may be retried +- **Use batch endpoint** for multiple updates in one operation +- **Implement proper retry logic** for concurrent event handlers + +### DON'T ❌ + +- **Read state, sleep, then write** - maximizes contention +- **Store large blobs** - use object storage instead +- **Ignore 409 errors** - always handle with retry +- **Use KV as a queue** - use proper message queues for high-throughput +- **Rely on ordering** across concurrent writes + +### Lock Timeout + +The KV store uses a single service-wide row-lock timeout (default: 5000ms), +configured via `AUTOMATION_KV_LOCK_TIMEOUT_MS` on the service. Operations +that can't acquire the row lock within this window return HTTP 409 with a +`Retry-After` header. Clients should always implement retry-with-backoff +on 409 — see the example handler below. + +## Error Handling + +| Status | Meaning | Action | +|--------|---------|--------| +| 200 | Success (update) | - | +| 201 | Success (create) | - | +| 400 | Bad request | Fix request (invalid key, etc.) | +| 401 | Unauthorized | Check token | +| 404 | Key not found | Handle missing key | +| 409 | Conflict | Retry with backoff (see `Retry-After` header) | +| 413 | Payload too large | Reduce state size | +| 503 | Service unavailable | KV store not configured | + +### 409 Conflict Types + +The 409 response can indicate: + +1. **Lock timeout** (`kv_store_busy`): Another operation is holding the lock + ```json + {"detail": "kv_store_busy: another operation is in progress, please retry"} + ``` + +2. **Version mismatch** (`version_mismatch`): State changed since your read + ```json + { + "detail": { + "error": "version_mismatch", + "expected_version": 5, + "actual_version": 6 + } + } + ``` + +Both include a `Retry-After: 1` header suggesting initial backoff. + +## Debugging + +### Common Issues + +**Frequent 409s:** +- Too much concurrent access +- Solutions: + - Reduce `max_concurrent_runs` + - Use atomic operations instead of read-modify-write + - Ask an operator to lower `AUTOMATION_KV_LOCK_TIMEOUT_MS` so contended + operations fail fast and retry sooner + +**Slow operations:** +- State document too large +- Solutions: + - Split into multiple keys + - Store large data externally + - Clean up old/unused keys + +**Version mismatches:** +- Concurrent modifications +- Solutions: + - Use atomic operations + - Implement proper retry loop + - Reconsider if you need concurrent access + +### Metrics + +If Prometheus metrics are enabled, monitor: + +- `kv_operation_duration_seconds`: Operation latency +- `kv_lock_wait_duration_seconds`: Time waiting for row lock +- `kv_conflict_total{reason="lock_timeout|version_mismatch"}`: Conflict rate +- `kv_state_size_bytes`: State document size + +High lock wait times or conflict rates indicate contention that may need architectural changes. + +## API Reference + +See the [KV Store Design Document](kv-store-design.md) for full API specification and implementation details. diff --git a/docs/kv-store-design.md b/docs/kv-store-design.md new file mode 100644 index 0000000..26c7454 --- /dev/null +++ b/docs/kv-store-design.md @@ -0,0 +1,910 @@ +# Automation KV Store - Design Document + +## Problem Statement + +One of the use cases for the automations system is implementing integrations. Some kinds of integrations—like many webhook responders—will have a stateless implementation. They receive an event, do some work, and complete. No memory of previous runs is needed. + +But other kinds of jobs require small amounts of data storage to work effectively. + +For example, consider an automation that summarizes data from Slack or another source. A common pattern would be for each run of the integration to store the last timestamp of the retrieved dataset, and then on the next scheduled run, look for items since that date. This avoids reprocessing the entire history on every run and enables efficient incremental sync patterns. + +**But where should the automation store this data?** + +Surely a GitHub repo wouldn't be a great fit—commits for every timestamp update would pollute the history and is simply the wrong tool for the job. + +We could have integration authors use custom solutions for persistence—JSONBin.io, Redis Cloud, a personal database, or some other external service. These work, but they require users to provision, configure, and manage external infrastructure. + +If external systems are required for such a prevalent use case, that erodes the simplicity of having a **batteries-included** solution. The promise of the automation platform is that you can build and deploy integrations without managing infrastructure. Requiring external storage for basic state persistence breaks that promise. + +## Solution + +Provide a built-in **key-value store API** scoped to each automation. Every automation has access to persistent storage that: + +- **Is easy to use** — simple GET/SET operations, familiar Redis-like semantics +- **Is flexible** — supports JSON values, counters, lists/queues, nested paths +- **Is secure** — application-level encryption, isolated per-automation (one automation cannot access another's data) + +We don't need massive storage capacity or high-performance operations. An automation might run once per hour and make 5-10 KV operations. **Simplicity and security matter more than raw speed.** + +## Goals + +1. Provide a simple key-value store API scoped to each automation +2. Ensure strict isolation — automation A cannot access automation B's data +3. Support atomic operations for safe concurrent access (Redis-like guarantees) +4. **Application-level encryption** for all stored values — customers can trust storing sensitive data (API keys, tokens, cursors) +5. Follow OpenHands encryption conventions + +## Non-Goals + +**We are not building Redis.** We borrow Redis's well-designed API semantics because they're familiar and battle-tested, but we have different requirements: + +| Aspect | Redis | Automation KV Store | +|--------|-------|---------------------| +| **Use case** | High-throughput cache, real-time apps | Occasional state persistence for scheduled agents | +| **Operations/sec** | Millions | Tens (at most) | +| **Storage** | In-memory | PostgreSQL (durable) | +| **Latency target** | Sub-millisecond | Hundreds of milliseconds is fine | +| **Encryption** | Optional, at-rest only | **Required, application-level** | + +The overhead of JWT verification, JWE encryption, and PostgreSQL round-trips is completely acceptable for our use case. **Correctness, security, and durability matter more than raw speed.** + +--- + +## Security Design + +### The Problem with User-Level Auth + +Initial idea: Use the existing `OPENHANDS_API_KEY` (user's temp API key) to authenticate KV requests, with `automation_id` in the URL path. + +**Flaw**: Two automations owned by the same user could access each other's data, since both run with the same user's credentials. + +### Solution: Per-Run JWT Tokens + +Generate a short-lived, signed JWT token for each automation run that embeds the `automation_id` as a trusted claim. + +**Flow:** +``` +Dispatcher creates run + ↓ +Generate JWT: {automation_id, run_id, exp} + ↓ +Sign with service's secret key + ↓ +Pass as AUTOMATION_KV_TOKEN env var to sandbox + ↓ +Agent includes token in KV API requests + ↓ +API verifies signature, extracts automation_id from trusted claim + ↓ +All KV operations scoped to that automation_id +``` + +**Why JWT over per-automation secrets:** +- Tokens are time-limited (expire with the run) +- Single signing key to manage (vs N secrets for N automations) +- Stateless verification (no DB lookup to identify automation) +- Can include additional context (run_id for audit) + +### Encryption at Rest (Required) + +**All KV values are encrypted at the application level before storage.** This ensures: + +- Database administrators cannot read sensitive values +- Database backups contain only encrypted data +- Customers can confidently store API keys, tokens, and credentials +- Compliance with security best practices + +Following OpenHands conventions from the parent project: + +| Component | Approach | +|-----------|----------| +| **Auth tokens** | JWS (JSON Web Signature) with HS256 | +| **KV values** | Fernet (AES-128-CBC + HMAC-SHA256), via SDK `Cipher` helper | +| **Key management** | Single master key from `AUTOMATION_KV_SECRET` env var | +| **Libraries** | `pyjwt` for tokens; `openhands.sdk.utils.cipher.Cipher` for values | + +**Pattern (mirrors the rest of the platform):** +```python +# openhands/automation/utils/kv.py +from openhands.sdk.utils.cipher import Cipher +from pydantic import SecretStr + +def encrypt_value(secret: str, value) -> str: + plaintext = strict_json(value) # validates + serializes + return Cipher(secret).encrypt(SecretStr(plaintext)) + +def decrypt_value(secret: str, encrypted: str): + return json.loads(Cipher(secret).decrypt(encrypted).get_secret_value()) +``` + +Using the SDK's `Cipher` keeps this module thin and shares a vetted +implementation with the rest of the OpenHands platform — we don't need to +maintain our own AES code or worry about IV management, padding, or +authentication tag handling. + +**What's stored in the database:** +``` +state_encrypted: "gAAAAABm..." +``` + +**What the application sees after decryption:** +```json +{"api_key": "sk-secret-123", "last_cursor": "abc"} +``` + +--- + +## API Design + +### Authentication + +All KV endpoints require the `AUTOMATION_KV_TOKEN` in the Authorization header: +``` +Authorization: Bearer +``` + +The token contains: +```json +{ + "automation_id": "uuid", + "run_id": "uuid", + "exp": 1234567890 +} +``` + +### Base Path + +``` +/api/automation/v1/kv +``` + +Note: No `automation_id` in URL - it comes from the verified JWT claim. + +--- + +## API Endpoints + +### Overview + +| Endpoint | Method | Redis Equivalent | Description | +|----------|--------|------------------|-------------| +| `/kv` | GET | `KEYS *` | List all keys | +| `/kv/{key}` | GET | `GET` | Get value | +| `/kv/{key}?path=x.y` | GET | `HGET` | Get nested field | +| `/kv/{key}?meta=true` | GET | - | Get value with metadata | +| `/kv/{key}` | PUT | `SET` | Set value | +| `/kv/{key}?nx=true` | PUT | `SET ... NX` / `SETNX` | Set if not exists | +| `/kv/{key}?xx=true` | PUT | `SET ... XX` | Set if exists | +| `/kv/{key}` | PATCH | `HSET` | Update nested path | +| `/kv/{key}` | DELETE | `DEL` | Delete key | +| `/kv/{key}/incr` | POST | `INCR` / `INCRBY` | Atomic increment | +| `/kv/{key}/decr` | POST | `DECR` / `DECRBY` | Atomic decrement | +| `/kv/{key}/lpush` | POST | `LPUSH` | Push to left of list | +| `/kv/{key}/rpush` | POST | `RPUSH` | Push to right of list | +| `/kv/{key}/lpop` | POST | `LPOP` | Pop from left | +| `/kv/{key}/rpop` | POST | `RPOP` | Pop from right | +| `/kv/{key}/len` | GET | `LLEN` | Get list length | + +### Redis Alignment + +The API uses familiar Redis conventions where applicable: + +| Redis Flag | Query Param | Meaning | +|------------|-------------|---------| +| `NX` | `?nx=true` | Only set if key does **not** exist | +| `XX` | `?xx=true` | Only set if key **does** exist | +| `EX` | `?ex=3600` | Set TTL in seconds (future) | + +List operation names (`lpush`, `rpush`, `lpop`, `rpop`) match Redis exactly: +- `l` = left (front of list) +- `r` = right (back of list) +- `push` = add item +- `pop` = remove and return item + +**Queue patterns:** +- FIFO queue: `rpush` to enqueue, `lpop` to dequeue +- LIFO stack: `rpush` to push, `rpop` to pop + +--- + +## Request/Response Format + +All responses are JSON objects for consistency and extensibility. + +### List Keys + +**Request:** +```http +GET /kv +``` + +**Response:** +```json +{ + "keys": ["config", "counter", "task-queue"], + "count": 3 +} +``` + +### Get Value + +**Request:** +```http +GET /kv/config +``` + +**Response:** +```json +{ + "key": "config", + "value": { + "database": {"host": "localhost", "port": 5432}, + "retries": 3 + } +} +``` + +### Get Nested Path + +**Request:** +```http +GET /kv/config?path=database.host +``` + +**Response:** +```json +{ + "key": "config", + "path": "database.host", + "value": "localhost" +} +``` + +### Get with Metadata + +**Request:** +```http +GET /kv/config?meta=true +``` + +**Response:** +```json +{ + "key": "config", + "value": { + "database": {"host": "localhost", "port": 5432}, + "retries": 3 + }, + "created_at": "2024-01-15T10:00:00Z", + "updated_at": "2024-01-15T12:30:00Z" +} +``` + +### Set Value + +**Request:** +```http +PUT /kv/config +Content-Type: application/json + +{ + "database": {"host": "localhost", "port": 5432}, + "retries": 3 +} +``` + +**Response:** +```json +{ + "key": "config", + "value": { + "database": {"host": "localhost", "port": 5432}, + "retries": 3 + }, + "created": true, + "updated_at": "2024-01-15T12:30:00Z" +} +``` + +### Set If Not Exists (SETNX) + +**Request:** +```http +PUT /kv/lock?nx=true +Content-Type: application/json + +{"owner": "run-123", "acquired_at": "2024-01-15T12:30:00Z"} +``` + +**Response (success - key was created):** +```json +{ + "key": "lock", + "value": {"owner": "run-123", "acquired_at": "2024-01-15T12:30:00Z"}, + "created": true, + "updated_at": "2024-01-15T12:30:00Z" +} +``` + +**Response (failure - key already exists):** +```json +{ + "key": "lock", + "created": false, + "error": "key_exists" +} +``` +HTTP Status: `409 Conflict` + +### Update Nested Path + +**Request:** +```http +PATCH /kv/config +Content-Type: application/json + +{ + "path": "database.port", + "value": 5433 +} +``` + +**Response:** +```json +{ + "key": "config", + "path": "database.port", + "value": 5433, + "updated_at": "2024-01-15T12:35:00Z" +} +``` + +### Delete Key + +**Request:** +```http +DELETE /kv/config +``` + +**Response:** +```json +{ + "key": "config", + "deleted": true +} +``` + +### Increment + +**Request:** +```http +POST /kv/counter/incr +Content-Type: application/json + +{"by": 1} +``` + +Note: `by` defaults to 1 if not provided. + +**Response:** +```json +{ + "key": "counter", + "value": 43 +} +``` + +### Decrement + +**Request:** +```http +POST /kv/counter/decr +Content-Type: application/json + +{"by": 5} +``` + +**Response:** +```json +{ + "key": "counter", + "value": 38 +} +``` + +### Push to List (Left) + +**Request:** +```http +POST /kv/task-queue/lpush +Content-Type: application/json + +{"value": {"task_id": "abc123", "action": "process"}} +``` + +**Response:** +```json +{ + "key": "task-queue", + "length": 5 +} +``` + +### Push to List (Right) + +**Request:** +```http +POST /kv/task-queue/rpush +Content-Type: application/json + +{"value": {"task_id": "def456", "action": "notify"}} +``` + +**Response:** +```json +{ + "key": "task-queue", + "length": 6 +} +``` + +### Pop from List (Left) + +**Request:** +```http +POST /kv/task-queue/lpop +``` + +**Response (item returned):** +```json +{ + "key": "task-queue", + "value": {"task_id": "abc123", "action": "process"} +} +``` + +**Response (list empty):** +```json +{ + "key": "task-queue", + "value": null +} +``` + +### Pop from List (Right) + +**Request:** +```http +POST /kv/task-queue/rpop +``` + +**Response:** +```json +{ + "key": "task-queue", + "value": {"task_id": "def456", "action": "notify"} +} +``` + +### Get List Length + +**Request:** +```http +GET /kv/task-queue/len +``` + +**Response:** +```json +{ + "key": "task-queue", + "length": 42 +} +``` + +--- + +## Error Responses + +All errors return JSON with consistent structure: + +```json +{ + "error": "error_code", + "message": "Human-readable description" +} +``` + +| HTTP Status | Error Code | Description | +|-------------|------------|-------------| +| 400 | `invalid_request` | Malformed request body | +| 400 | `invalid_path` | Invalid JSON path syntax | +| 400 | `type_mismatch` | Operation doesn't match value type (e.g., incr on object) | +| 401 | `unauthorized` | Missing or invalid token | +| 403 | `token_expired` | JWT token has expired | +| 404 | `key_not_found` | Key does not exist | +| 409 | `key_exists` | Key already exists (for `?nx=true`) | +| 409 | `key_not_exists` | Key doesn't exist (for `?xx=true`) | + +--- + +## Why Atomic Operations Matter + +**Scenario:** Two runs of the same automation overlap (previous run slow, next scheduled run starts): + +Without atomics: +``` +Run A: GET counter → 5 +Run B: GET counter → 5 +Run A: PUT counter → 6 +Run B: PUT counter → 6 # Lost update! +``` + +With INCR: +``` +Run A: INCR counter → 6 +Run B: INCR counter → 7 # Correct! +``` + +--- + +## Implementation Notes + +### Single-Document Storage Design + +Each automation has exactly **ONE row** in the database containing its entire state as an encrypted JSON document. API "keys" (e.g., `/kv/config`, `/kv/counter`) are top-level fields within this single document. + +**Why single-document?** +- **Eliminates deadlocks**: Only one row per automation to lock. All operations serialize through that single lock. No possibility of lock ordering issues. +- **Simpler model**: One encryption boundary, one row to manage per automation. +- **Acceptable trade-off**: Every operation reads/writes the entire state blob, but automation state is small and access is infrequent. + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Database Row (ONE per automation) │ +├─────────────────────────────────────────────────────────────┤ +│ automation_id: uuid-123 (UNIQUE) │ +│ state_encrypted: │ +│ │ +│ Decrypted contents: │ +│ { │ +│ "config": {"host": "localhost", "port": 5432}, │ +│ "counter": 42, │ +│ "task-queue": [{"task_id": "abc"}] │ +│ } │ +└─────────────────────────────────────────────────────────────┘ +``` + +### Atomic Operations with Encryption + +Since values are encrypted at the application level, we **cannot** use native PostgreSQL operations like `value = value + 1`. Instead, atomic operations lock the single state row and perform read-modify-write: + +```python +async def incr(self, automation_id: UUID, key: str, by: int = 1) -> int: + async with session.begin(): + # 1. Lock the automation's state row (ONE row per automation) + row = await session.execute( + select(AutomationKV) + .where(AutomationKV.automation_id == automation_id) + .with_for_update() + ) + kv = row.scalar_one_or_none() + + # 2. Decrypt entire state, modify target key, encrypt + if kv is None: + state = {key: by} + kv = AutomationKV(automation_id=automation_id) + session.add(kv) + else: + state = decrypt_value(kv.state_encrypted) + if key not in state: + state[key] = by + else: + value = state[key] + if not isinstance(value, int): + raise TypeError("Cannot increment non-integer value") + state[key] = value + by + + # 3. Update with encrypted state + kv.state_encrypted = encrypt_value(state) + + # 4. Commit releases lock + return state[key] +``` + +**Concurrency model:** +- Each automation has ONE row → all operations serialize through one lock +- No deadlock risk between keys (there's only one lock to acquire) +- Different automations → completely isolated (different rows) + +This is acceptable for our use case (automations doing 5-10 KV ops per run). The brief lock during decrypt-modify-encrypt is negligible. + +### SETNX (Set If Not Exists) + +For conditional set operations, we lock the state row, check if the key exists in the decrypted state, and proceed accordingly: + +```python +# Lock state row +state = decrypt_value(kv.state_encrypted) if kv else {} + +if nx and key in state: + return 409 # Key already exists + +state[key] = value +kv.state_encrypted = encrypt_value(state) +``` + +### Path Syntax + +Use dot notation for nested paths: `database.host` + +For keys containing dots, use bracket notation: `config["my.key.with.dots"]` + +--- + +## Data Model + +```python +class AutomationKV(Base): + """Single-document state store for automation persistence. + + Each automation has exactly ONE row containing its entire state as an + encrypted JSON document. The API presents a key-value interface, but + "keys" are top-level fields within this single document. + """ + __tablename__ = "automation_kv" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + automation_id: Mapped[uuid.UUID] = mapped_column( + Uuid, + ForeignKey("automations.id", ondelete="CASCADE"), + nullable=False, + unique=True, # ONE row per automation + ) + + # Encrypted JSON document containing all KV pairs, stored as a Fernet + # token (URL-safe base64 text) produced by the SDK's Cipher helper. + # Decrypted example: {"config": {...}, "counter": 42, "queue": [...]} + state_encrypted: Mapped[str] = mapped_column(Text, nullable=False) + + # Timestamps (foundation for future TTL support) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + onupdate=utcnow, + nullable=False, + ) + + __table_args__ = ( + Index("ix_automation_kv_automation_id", "automation_id", unique=True), + ) +``` + +### Future TTL Support + +The `created_at` and `updated_at` timestamps provide the foundation for TTL: + +```python +# Future addition for TTL +expires_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), + nullable=True, + index=True, # For efficient cleanup queries +) +``` + +TTL can be set via `?ex=3600` query param: +```http +PUT /kv/session?ex=3600 +``` + +--- + +## Dependencies to Add + +```toml +# pyproject.toml +dependencies = [ + # ... existing ... + "pyjwt>=2.8", + # Fernet encryption is provided by the SDK's Cipher helper, which is + # already pulled in via openhands-sdk — no extra crypto dependency needed. +] +``` + +--- + +## Environment Variables + +```bash +# Required: Master key for JWT signing and Fernet encryption of KV values. +# When this is set the KV store is enabled service-wide; every automation +# gets a token at dispatch time. When it's empty the feature is disabled. +AUTOMATION_KV_SECRET= + +# Optional: Row-lock timeout in milliseconds for KV operations (default: 5000). +# Applied via PostgreSQL `SET LOCAL lock_timeout` before each FOR UPDATE. +AUTOMATION_KV_LOCK_TIMEOUT_MS=5000 +``` + +--- + +## Agent Integration + +Agents need a simple way to interact with the KV store. We provide a client library and tools. + +### Package Structure + +**Recommended: Separate `openhands-kv` package** hosted in its own repo under the OpenHands org. + +``` +openhands-kv/ +├── openhands/ +│ └── kv/ +│ ├── __init__.py +│ ├── client.py ← KVClient class (HTTP wrapper) +│ └── tool.py ← KVStoreTool definition +├── pyproject.toml +└── README.md +``` + +This package is installed in the sandbox via `setup.sh`: + +```bash +# presets/prompt/setup.sh +pip install -q --no-cache-dir \ + "openhands-sdk==${SDK_VERSION}" \ + "openhands-workspace==${SDK_VERSION}" \ + "openhands-tools==${SDK_VERSION}" \ + "openhands-kv==0.1.0" +``` + +**Why a separate package?** +- Independent release cycle from SDK and automation service +- Clean separation of concerns +- Can be used outside automations if needed +- No changes required to agent-sdk or automation build systems + +> **Future consideration:** If automation grows more packages, consider converting +> the automation repo to a monorepo structure (like agent-sdk) to co-locate +> related packages while maintaining independent releases. + +### Client Library + +```python +from openhands.kv import KVClient + +# Auto-reads AUTOMATION_KV_TOKEN and OPENHANDS_CLOUD_API_URL from environment +kv = KVClient() + +# Basic operations +config = kv.get("config") +kv.set("config", {"database": {"host": "localhost"}}) +kv.delete("old-key") + +# Counters +kv.set("counter", 0) +new_value = kv.incr("counter") # Returns 1 +kv.incr("counter", by=5) # Returns 6 + +# Nested paths +host = kv.get("config", path="database.host") +kv.patch("config", path="database.port", value=5433) + +# Lists/Queues +kv.rpush("task-queue", {"task_id": "abc", "action": "process"}) +kv.rpush("task-queue", {"task_id": "def", "action": "notify"}) +task = kv.lpop("task-queue") # FIFO dequeue +length = kv.len("task-queue") + +# Conditional set (for locks, idempotency) +created = kv.set("lock", {"owner": "run-123"}, nx=True) +if not created: + print("Lock already held by another run") +``` + +### Agent Tool + +The KV store is always available, so the preset's `sdk_main.py` loads the +tool unconditionally whenever a KV token is present in the environment: + +```python +# In presets/prompt/sdk_main.py + +if os.environ.get("AUTOMATION_KV_TOKEN"): + from openhands.kv import KVStoreTool + # Register tool with agent +``` + +### Environment Variables + +The dispatcher injects a token for every run whenever the service has a KV +secret configured (i.e., whenever the feature is enabled service-wide): + +| Env Var | Purpose | +|---------|---------| +| `AUTOMATION_KV_TOKEN` | JWT token scoped to this automation | + +### Environment Detection + +The library auto-detects when running in an automation context: + +```python +class KVClient: + def __init__(self, token: str | None = None, base_url: str | None = None): + self.token = token or os.environ.get("AUTOMATION_KV_TOKEN") + self.base_url = base_url or os.environ.get("OPENHANDS_CLOUD_API_URL") + + if not self.token: + raise KVNotAvailableError( + "KV store is only available within automation runs. " + "AUTOMATION_KV_TOKEN environment variable not found." + ) +``` + +This gives a clear error if someone tries to use KV outside an automation context. + +--- + +## Open Questions / Limits + +| Topic | Question | Suggested Default | +|-------|----------|-------------------| +| **Key length** | Max characters for key names? | 255 characters | +| **State size** | Max total state size per automation? | 1 MB (encrypted) | +| **Retention** | What happens when automation is deleted? | Cascade delete all KV data | +| **TTL** | Support key expiration? | Deferred (timestamps in place for future) | + +These limits are generous for the intended use case (state persistence between automation runs). They can be adjusted based on usage patterns. + +--- + +## Next Steps + +### Design (Complete) +1. [x] Decide on MVP API scope - CRUD + counters + lists + paths +2. [x] Decide on value types - Any JSON value +3. [x] Define response format - Consistent JSON objects +4. [x] Define agent integration approach - Separate `openhands-kv` package +5. [x] Define security model - Per-run JWT tokens + JWE encryption + +### Implementation (TODO) + +**Automation Service (this repo):** +1. [x] Implement JWT signing for `AUTOMATION_KV_TOKEN` (`utils/kv.py`) +2. [x] Implement value encryption via the SDK's `Cipher` helper +3. [x] Create database migration for `automation_kv` table +4. [x] Implement KV API router (`/api/automation/v1/kv/...`) +5. [x] Update dispatcher to generate and pass `AUTOMATION_KV_TOKEN` whenever + the service has a KV secret configured (no per-automation toggle) +6. [ ] Update preset `sdk_main.py` to load the KV tool when + `AUTOMATION_KV_TOKEN` is set +7. [ ] Update preset `setup.sh` to install `openhands-kv` +8. [ ] Update the `openhands-automation` skill so agents know the KV store + is available out of the box (follow-up) + +**New `openhands-kv` Package (new repo):** +1. [ ] Create repo under OpenHands org +2. [ ] Implement `KVClient` (HTTP client library) +3. [ ] Implement `KVStoreTool` (agent tool definition) +4. [ ] Publish to PyPI +5. [ ] Documentation + +### Testing +1. [ ] Unit tests for KV API endpoints +2. [ ] Unit tests for JWT/encryption +3. [ ] Integration tests for full flow (automation → sandbox → KV API) +4. [ ] Test atomic operations (concurrent INCR, etc.) + +--- + +## References + +- [JSONBin.io API](https://jsonbin.io/api-reference) +- [Redis Commands](https://redis.io/commands/) +- OpenHands encryption: `OpenHands/OpenHands/enterprise/storage/encrypt_utils.py` +- OpenHands JWT service: `OpenHands/OpenHands/openhands/app_server/services/jwt_service.py` diff --git a/docs/kv-store-test-plan.md b/docs/kv-store-test-plan.md new file mode 100644 index 0000000..563cf29 --- /dev/null +++ b/docs/kv-store-test-plan.md @@ -0,0 +1,147 @@ +# KV Store Test Plan + +**PR:** [OpenHands/automation#69](https://github.com/OpenHands/automation/pull/69) + +> **Note:** Most test cases are now automated in `scripts/test_kv_e2e.py`. +> This document covers manual testing workflows, debugging, and tests not yet automated. + +--- + +## Token Injection + +The dispatcher injects `AUTOMATION_KV_TOKEN` into the sandbox for every run +whenever the service has `AUTOMATION_KV_SECRET` configured. The KV store is +available to every automation by default — there is no per-automation +toggle. When the service has no KV secret configured, no token is injected +and the KV API responds with HTTP 503. + +--- + +## Automated Tests + +Run the E2E test suite: + +```bash +# Quick smoke test (8 tests, ~30s) +python scripts/test_kv_e2e.py + +# Full test suite (26 tests, ~2min) +python scripts/test_kv_e2e.py --thorough +``` + +**Coverage:** Basic CRUD, INCR/DECR, list operations, nested paths, conditional SET, auth errors, type errors, edge cases. + +--- + +## Tests NOT Yet Automated + +The following require multi-run or multi-automation coordination: + +### Cross-Automation Isolation (TC-9.2) + +Verify automation A cannot access automation B's KV data: + +```bash +# Create Automation A - writes "shared-name" = "I am A" +curl -X POST "${BASE_URL}/api/automation/v1/preset/prompt" \ + -H "Authorization: Bearer ${API_KEY}" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Isolation Test A", + "prompt": "Set KV key \"shared-name\" to \"I am Automation A\". Then read and print it.", + "trigger": {"type": "cron", "schedule": "0 0 1 1 *"} + }' +# Enable KV, dispatch, note automation_id as A_ID + +# Create Automation B - writes "shared-name" = "I am B" +curl -X POST "${BASE_URL}/api/automation/v1/preset/prompt" \ + -H "Authorization: Bearer ${API_KEY}" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Isolation Test B", + "prompt": "Set KV key \"shared-name\" to \"I am Automation B\". Then read and print it.", + "trigger": {"type": "cron", "schedule": "0 0 1 1 *"} + }' +# Enable KV, dispatch, note automation_id as B_ID + +# Run A again - should still see "I am Automation A" (not B's value) +curl -X POST "${BASE_URL}/api/automation/v1/${A_ID}/dispatch" \ + -H "Authorization: Bearer ${API_KEY}" +``` + +### State Persistence Across Runs (TC-10.1) + +Verify KV data persists between automation runs: + +```bash +curl -X POST "${BASE_URL}/api/automation/v1/preset/prompt" \ + -H "Authorization: Bearer ${API_KEY}" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "KV Persistence Test", + "prompt": "Read KV key \"run_counter\". Print current value (or \"first run\" if missing). Increment it. Print new value.", + "trigger": {"type": "cron", "schedule": "0 0 1 1 *"} + }' + +# Run 1: Should print "first run", counter = 1 +# Run 2: Should print "1", counter = 2 +# Run 3: Should print "2", counter = 3 +``` + +--- + +## Debugging: View Automation Run Results + +Automation runs create conversations. To see what happened: + +```bash +# 1. Find conversation ID +curl -s "${BASE_URL}/api/v1/app-conversations/search?limit=10" \ + -H "Authorization: Bearer ${API_KEY}" \ + | jq '.items[] | {id, automation_name: .tags.automationname, status: .sandbox_status}' + +# 2. Get events for a conversation +CONV_ID="" +EVENT_IDS=$(curl -s "${BASE_URL}/api/v1/conversation/${CONV_ID}/events/search?limit=50" \ + -H "Authorization: Bearer ${API_KEY}" | jq -r '.items | map("id=" + .id) | join("&")') + +# 3. View command outputs +curl -s "${BASE_URL}/api/v1/conversation/${CONV_ID}/events?${EVENT_IDS}" \ + -H "Authorization: Bearer ${API_KEY}" \ + | jq '.[] | select(.kind == "ObservationEvent") | { + command: .observation.command, + output: .observation.content[0].text[0:500] + }' +``` + +--- + +## Quick Reference Commands + +```bash +# Create automation +curl -X POST "${BASE_URL}/api/automation/v1/preset/prompt" \ + -H "Authorization: Bearer ${API_KEY}" \ + -H "Content-Type: application/json" \ + -d '{"name": "Test", "prompt": "...", "trigger": {"type": "cron", "schedule": "0 0 1 1 *"}}' + +# Dispatch run +curl -X POST "${BASE_URL}/api/automation/v1/${ID}/dispatch" \ + -H "Authorization: Bearer ${API_KEY}" + +# List runs +curl "${BASE_URL}/api/automation/v1/${ID}/runs" \ + -H "Authorization: Bearer ${API_KEY}" + +# Delete automation +curl -X DELETE "${BASE_URL}/api/automation/v1/${ID}" \ + -H "Authorization: Bearer ${API_KEY}" +``` + +--- + +## Notes + +1. **Token is sandbox-only:** `AUTOMATION_KV_TOKEN` is injected at runtime. You cannot extract it externally. + +3. **Token scope:** Each token is scoped to a specific automation ID for strict isolation. diff --git a/migrations/versions/006_add_kv_store.py b/migrations/versions/006_add_kv_store.py new file mode 100644 index 0000000..68bec28 --- /dev/null +++ b/migrations/versions/006_add_kv_store.py @@ -0,0 +1,109 @@ +"""Add key-value store for automation state persistence. + +This migration adds the ``automation_kv`` table — one row per automation +holding the entire state document as an encrypted blob. + +Single-Document Design (Deadlock Prevention) +============================================ + +Each automation has exactly ONE row in automation_kv containing its entire +state as an encrypted JSON document. The API presents a key-value interface, +but "keys" are top-level fields within this single document. + +By storing all state in one row per automation, we eliminate multi-key +deadlock scenarios. All operations serialize through a single row lock. + +Storage Design +============== + +The encrypted state is stored as TEXT (a Fernet token, URL-safe base64). +We use the SDK's ``openhands.sdk.utils.cipher.Cipher`` (Fernet under the +hood) for encryption — see ``openhands/automation/utils/kv.py``. Fernet +emits a base64 string rather than raw bytes, so TEXT is the natural column +type. The ~33% base64 overhead is acceptable for the small JSON documents +typical of automation state (counters, cursors, configs) and keeps the +schema simple. + +Revision ID: 006 +Revises: 005 +Create Date: 2026-04-24 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + + +revision: str = "006" +down_revision: str = "005" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def _is_sqlite() -> bool: + """Check if we are running against SQLite (test/dev only).""" + return op.get_bind().dialect.name == "sqlite" + + +def upgrade() -> None: + # Create automation_kv table - ONE row per automation (single-document design) + # state_encrypted is a Fernet token (URL-safe base64 text) produced by the + # SDK Cipher. See module docstring for the design rationale. + op.create_table( + "automation_kv", + sa.Column("id", sa.Uuid, primary_key=True), + sa.Column( + "automation_id", + sa.Uuid, + sa.ForeignKey("automations.id", ondelete="CASCADE"), + nullable=False, + unique=True, # ONE row per automation - critical for deadlock prevention + ), + sa.Column("state_encrypted", sa.Text, nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + ) + + # Create unique index on automation_id (enforces one row per automation) + op.create_index( + "ix_automation_kv_automation_id", + "automation_kv", + ["automation_id"], + unique=True, + ) + + # Add schema-level documentation for DBAs and tools that inspect the + # schema directly without reading application source. SQLite doesn't + # support COMMENT, so skip these statements there. + if _is_sqlite(): + return + + op.execute( + "COMMENT ON TABLE automation_kv IS " + "'Single-document state store for automation persistence. " + "Each automation has ONE row containing its entire state as encrypted JSON. " + "The API presents a key-value interface where keys are top-level fields. " + "Single-row design eliminates multi-key deadlock scenarios. " + "See openhands/automation/utils/kv.py for encryption details.'" + ) + op.execute( + "COMMENT ON COLUMN automation_kv.state_encrypted IS " + "'Fernet token (URL-safe base64 text) containing the encrypted state " + "document as JSON. Produced by openhands.sdk.utils.cipher.Cipher.'" + ) + + +def downgrade() -> None: + op.drop_index("ix_automation_kv_automation_id", table_name="automation_kv") + op.drop_table("automation_kv") diff --git a/openhands/automation/app.py b/openhands/automation/app.py index da771d9..ec9c637 100644 --- a/openhands/automation/app.py +++ b/openhands/automation/app.py @@ -14,7 +14,7 @@ from sqlalchemy import text from openhands.automation.auth import create_http_client -from openhands.automation.config import get_settings +from openhands.automation.config import get_config, get_settings from openhands.automation.db import ( create_engine, create_session_factory, @@ -22,6 +22,7 @@ ) from openhands.automation.dispatcher import dispatcher_loop from openhands.automation.event_router import router as event_router +from openhands.automation.kv_router import router as kv_router from openhands.automation.logger import setup_all_loggers from openhands.automation.preset_router import router as preset_router from openhands.automation.router import router @@ -52,7 +53,10 @@ async def lifespan(app: FastAPI): ): logging.getLogger(noisy_logger).setLevel(logging.WARNING) - logger.info("Starting OpenHands Automations Service") + logger.info( + "Starting OpenHands Automations Service", + extra={"kv_store_configured": get_config().kv.enabled}, + ) # Create shared httpx client for auth (stored in app.state for DI) app.state.http_client = create_http_client() @@ -224,6 +228,7 @@ def _create_app() -> FastAPI: app.include_router(preset_router, prefix=_base_path) app.include_router(event_router, prefix=_base_path) app.include_router(webhook_router, prefix=_base_path) +app.include_router(kv_router, prefix=_base_path) app.include_router(router, prefix=_base_path) diff --git a/openhands/automation/config.py b/openhands/automation/config.py index b6118cb..71da7b4 100644 --- a/openhands/automation/config.py +++ b/openhands/automation/config.py @@ -8,7 +8,8 @@ ├── storage: StorageSettings # File storage (no prefix, SDK conventions) ├── log: LogSettings # Logging (no prefix) ├── http: HttpSettings # HTTP client (AUTOMATION_ prefix) - └── sandbox: SandboxSettings # Sandbox execution (AUTOMATION_ prefix) + ├── sandbox: SandboxSettings # Sandbox execution (AUTOMATION_ prefix) + └── kv: KVSettings # Key-value store (AUTOMATION_ prefix) Usage (preferred): from openhands.automation.config import get_config @@ -222,6 +223,64 @@ class SandboxSettings(BaseSettings): model_config = {"env_prefix": "AUTOMATION_"} +# --------------------------------------------------------------------------- +# KVSettings - Key-value store configuration +# --------------------------------------------------------------------------- + + +class KVSettings(BaseSettings): + """Key-value store configuration for automation state persistence. + + The KV store provides per-automation state storage with encryption and + JWT-based authentication. It is available to every automation whenever + AUTOMATION_KV_SECRET is configured at the service level. + + Environment variables (AUTOMATION_ prefix): + AUTOMATION_KV_SECRET: Secret for JWT signing and value encryption. + Must be set to enable KV store. Generate with: + python -c "import secrets; print(secrets.token_urlsafe(32))" + AUTOMATION_KV_MAX_VALUE_SIZE: Max value size in bytes (default: 64KB) + AUTOMATION_KV_LOCK_TIMEOUT_MS: Row-lock timeout in ms (default: 5000) + """ + + # Secret key for signing KV store JWT tokens and encrypting KV values. + # Must be set to enable the KV store feature. + kv_secret: str = "" + + # Row-lock timeout in milliseconds for KV operations. + # Applied via PostgreSQL `SET LOCAL lock_timeout` before FOR UPDATE. + # If the lock isn't acquired within this window we return 409 Conflict + # with Retry-After so clients can back off and retry. + kv_lock_timeout_ms: int = 5000 + + # Maximum size in bytes for KV store values (plaintext JSON, before encryption). + # + # Performance guidance - PostgreSQL TOAST behavior: + # + # Limit Stored Size TOAST Chunks Read Latency + # ------- ----------- ------------ ------------ + # < 2 KB inline 0 1x (optimal) + # 2-8 KB compressed 0 ~2x + # 64 KB ~65 KB ~33 ~5-10x + # 128 KB ~131 KB ~66 ~10-15x + # 256 KB ~262 KB ~131 ~15-25x + # 512 KB ~524 KB ~262 ~25-40x + # + # Values > 8KB are stored in a separate TOAST table, requiring index lookups + # for each ~2KB chunk. The default 64KB is generous for typical KV use cases + # (counters, flags, small configs). For larger blobs, consider object storage. + # + # Set to 0 to disable the limit (not recommended). + kv_max_value_size: int = 64 * 1024 # 64 KB + + model_config = {"env_prefix": "AUTOMATION_"} + + @property + def enabled(self) -> bool: + """Check if KV store is enabled (kv_secret is set).""" + return bool(self.kv_secret) + + # --------------------------------------------------------------------------- # ServiceSettings - Core service configuration (formerly "Settings") # --------------------------------------------------------------------------- @@ -307,6 +366,11 @@ class ServiceSettings(BaseSettings): gcp_project: str | None = None gcp_region: str | None = None + # Maximum seconds to wait for a connection from the pool. + # Prevents indefinite hangs when pool is exhausted due to slow operations. + # If pool exhaustion is frequent, increase pool_size rather than this timeout. + db_pool_timeout: float = 30 + # Local agent-server mode (self-hosted deployments) # When agent_server_url is set, the service operates in "local mode": # - Uses a persistent local agent server instead of cloud sandboxes @@ -448,6 +512,7 @@ class AppConfig: log: Logging settings http: HTTP client settings (timeouts, caching) sandbox: Sandbox execution settings (limits, retries) + kv: Key-value store settings (secrets, limits) Example: config = get_config() @@ -455,6 +520,7 @@ class AppConfig: print(config.storage.file_store) print(config.log.log_level) print(config.sandbox.max_run_duration) + print(config.kv.enabled) """ @cached_property @@ -482,6 +548,11 @@ def sandbox(self) -> SandboxSettings: """Sandbox execution configuration (AUTOMATION_ prefix).""" return SandboxSettings() + @cached_property + def kv(self) -> KVSettings: + """Key-value store configuration (AUTOMATION_ prefix).""" + return KVSettings() + @lru_cache def get_config() -> AppConfig: diff --git a/openhands/automation/db.py b/openhands/automation/db.py index 3046b06..6ff2936 100644 --- a/openhands/automation/db.py +++ b/openhands/automation/db.py @@ -107,6 +107,10 @@ async def create_engine(settings: ServiceSettings | None = None) -> EngineResult max_overflow=settings.db_max_overflow, pool_recycle=settings.db_pool_recycle, pool_pre_ping=True, + # Fail fast if pool is exhausted rather than waiting indefinitely. + # This surfaces pool exhaustion issues as errors instead of timeouts, + # making it easier to diagnose and fix (e.g., by increasing pool_size). + pool_timeout=settings.db_pool_timeout, ) return EngineResult(engine=engine, is_sqlite=False) @@ -165,6 +169,8 @@ async def getconn(): max_overflow=settings.db_max_overflow, pool_pre_ping=True, pool_recycle=settings.db_pool_recycle, + # Fail fast if pool is exhausted rather than waiting indefinitely. + pool_timeout=settings.db_pool_timeout, ) return EngineResult(engine=engine, connector=connector) diff --git a/openhands/automation/dispatcher.py b/openhands/automation/dispatcher.py index a294f86..f3d5955 100644 --- a/openhands/automation/dispatcher.py +++ b/openhands/automation/dispatcher.py @@ -39,6 +39,7 @@ ) from openhands.automation.utils import log_extra from openhands.automation.utils.api_key import APIKeyError +from openhands.automation.utils.kv import create_kv_token from openhands.automation.utils.run import ( disable_automation, mark_run_status, @@ -218,6 +219,18 @@ async def _fail(error: str, disable: bool = False) -> None: env_vars["SANDBOX_ID"] = ctx.sandbox_id env_vars["SESSION_API_KEY"] = ctx.session_key + # Inject a KV token whenever the service has a KV secret configured. + # The KV store is always available to automations — there is no per- + # automation toggle. If no secret is configured the feature is simply + # disabled service-wide. + kv_config = get_config().kv + if kv_config.kv_secret: + env_vars["AUTOMATION_KV_TOKEN"] = create_kv_token( + secret=kv_config.kv_secret, + automation_id=automation.id, + run_id=run.id, + ) + # 4. Prepare tarball source try: tarball_source: bytes | str diff --git a/openhands/automation/kv_helpers.py b/openhands/automation/kv_helpers.py new file mode 100644 index 0000000..5a51885 --- /dev/null +++ b/openhands/automation/kv_helpers.py @@ -0,0 +1,399 @@ +"""Helper functions for KV store operations. + +Provides utilities for: +- Parsing and manipulating nested paths in JSON values +- Safe encryption/decryption with proper HTTP error handling +- Type validation helpers for KV values +- Key name validation +""" + +import logging +from typing import Any + +from fastapi import HTTPException, status + +from openhands.automation.utils.kv import ( + KVEncryptionError, + KVValueError, + decrypt_value, + encrypt_value, +) + + +logger = logging.getLogger(__name__) + + +# Maximum key length (matches database column constraint) +_MAX_KEY_LENGTH = 255 + +# Maximum path depth (matches value nesting depth limit) +_MAX_PATH_DEPTH = 32 + + +# --- Key Validation --- + + +def validate_key(key: str) -> str: + """Validate a KV key name for safe storage and retrieval. + + Keys are validated to ensure they: + - Are not empty or whitespace-only + - Don't start with '$' (reserved for system keys like $version) + - Don't exceed the database column length limit (255 chars) + - Don't contain control characters (which could cause issues in logs, URLs, etc.) + + Args: + key: The key name to validate + + Returns: + The validated key (unmodified if valid) + + Raises: + HTTPException: 400 Bad Request with descriptive error if validation fails + """ + if not key: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="invalid_key: key cannot be empty", + ) + + if not key.strip(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="invalid_key: key cannot be whitespace-only", + ) + + # Reserve $ prefix for system keys ($version, future meta keys) + if key.startswith("$"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="invalid_key: keys starting with '$' are reserved for system use", + ) + + if len(key) > _MAX_KEY_LENGTH: + msg = f"invalid_key: key exceeds {_MAX_KEY_LENGTH} chars ({len(key)} given)" + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=msg) + + # Check for control characters (ASCII 0-31 and 127) + # These can cause issues in logging, URLs, and debugging + for i, char in enumerate(key): + code = ord(char) + if code < 32 or code == 127: + char_repr = f"\\x{code:02x}" if code < 32 else "\\x7f" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"invalid_key: control character {char_repr} at position {i}", + ) + + return key + + +# --- HTTP Error Helpers --- + + +def safe_encrypt(secret: str, value: Any) -> str: + """Encrypt a value with proper HTTP error handling. + + Wraps encrypt_value() to convert exceptions to appropriate HTTP errors: + - KVValueError (invalid JSON) → 400 Bad Request + - KVEncryptionError (encryption failure) → 500 Internal Server Error + + JSON Validation: + Values are validated before encryption to ensure they are strict JSON: + - NaN, Infinity, -Infinity are rejected (not valid JSON) + - Maximum nesting depth is enforced (prevents DoS) + - Non-serializable types are rejected + + Args: + secret: The encryption secret + value: Any JSON-serializable value + + Returns: + Encrypted Fernet token (URL-safe base64 string) + + Raises: + HTTPException: 400 for invalid values, 500 for encryption errors + """ + try: + return encrypt_value(secret, value) + except KVValueError as e: + # Client's fault: invalid JSON value (NaN, too deep, non-serializable) + logger.warning("Invalid KV value rejected: %s", e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"invalid_value: {e}", + ) + except KVEncryptionError as e: + # Our fault: encryption failed unexpectedly + logger.error("Failed to encrypt KV value: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to encrypt value", + ) + + +def safe_decrypt(secret: str, encrypted: str) -> Any: + """Decrypt a value with proper HTTP error handling. + + Wraps decrypt_value() to convert KVEncryptionError to HTTP 500. + + Args: + secret: The encryption secret + encrypted: Encrypted Fernet token from the database + + Returns: + The decrypted JSON value + + Raises: + HTTPException: 500 for decryption errors + """ + try: + return decrypt_value(secret, encrypted) + except KVEncryptionError as e: + # Our fault: decryption failed (corrupted data, wrong key, etc.) + logger.error("Failed to decrypt KV value: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to decrypt value", + ) + + +def require_dict(value: Any) -> dict: + """Validate that a value is a dict, raising HTTP 400 if not. + + Args: + value: The value to check + + Returns: + The value (for chaining) + + Raises: + HTTPException: 400 if value is not a dict + """ + if not isinstance(value, dict): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="type_mismatch: value is not an object", + ) + return value + + +def require_list(value: Any) -> list: + """Validate that a value is a list, raising HTTP 400 if not. + + Args: + value: The value to check + + Returns: + The value (for chaining) + + Raises: + HTTPException: 400 if value is not a list + """ + if not isinstance(value, list): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="type_mismatch: value is not a list", + ) + return value + + +def require_numeric(value: Any) -> int | float: + """Validate that a value is numeric (int or float), raising HTTP 400 if not. + + Note: Booleans are explicitly rejected even though bool is a subclass of int + in Python. This prevents confusing behavior where True becomes 2 after increment. + + Args: + value: The value to check + + Returns: + The value (for chaining) + + Raises: + HTTPException: 400 if value is not numeric (or is a boolean) + """ + # Explicitly reject booleans (bool is subclass of int in Python) + if isinstance(value, bool): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="type_mismatch: value is boolean, not numeric", + ) + if not isinstance(value, (int, float)): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="type_mismatch: value is not numeric", + ) + return value + + +def require_int(value: Any) -> int: + """Validate that a value is an integer, raising HTTP 400 if not. + + This is stricter than require_numeric - it rejects floats. + Used for operations like incr/decr where float arithmetic could + cause unexpected precision loss. + + Note: Booleans are explicitly rejected even though bool is a subclass of int + in Python. This prevents confusing behavior where True becomes 2 after increment. + + Args: + value: The value to check + + Returns: + The value (for chaining) + + Raises: + HTTPException: 400 if value is not an integer + """ + # Explicitly reject booleans (bool is subclass of int in Python) + if isinstance(value, bool): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="type_mismatch: value is boolean, not integer", + ) + if not isinstance(value, int): + if isinstance(value, float): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="type_mismatch: value is float, not integer (integer required)", + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="type_mismatch: value is not an integer", + ) + return value + + +# --- Path Operations --- + + +def parse_path(path: str) -> list[str]: + """Parse a path string into parts. + + Supports: + - Dot notation: database.host + - Bracket notation: config["my.key.with.dots"] + + Args: + path: A dot-notation or bracket-notation path string. + + Returns: + List of path segments. + + Raises: + ValueError: If path has invalid syntax (e.g., unclosed bracket) or + exceeds maximum depth (_MAX_PATH_DEPTH). + """ + parts: list[str] = [] + current = "" + i = 0 + + while i < len(path): + char = path[i] + + if char == ".": + if current: + parts.append(current) + current = "" + elif char == "[": + if current: + parts.append(current) + current = "" + # Find closing bracket + end = path.find("]", i) + if end == -1: + raise ValueError(f"Invalid path: unclosed bracket in '{path}'") + # Extract key (strip quotes if present) + key = path[i + 1 : end] + if key.startswith('"') and key.endswith('"'): + key = key[1:-1] + elif key.startswith("'") and key.endswith("'"): + key = key[1:-1] + parts.append(key) + i = end + else: + current += char + + i += 1 + + if current: + parts.append(current) + + # Enforce path depth limit to prevent DoS via deeply nested paths + if len(parts) > _MAX_PATH_DEPTH: + raise ValueError( + f"Path exceeds maximum depth of {_MAX_PATH_DEPTH} ({len(parts)} segments)" + ) + + return parts + + +def get_nested_value(obj: Any, path: str) -> Any: + """Get a value at a nested path using dot notation. + + Supports bracket notation for keys with dots: config["my.key"] + + Args: + obj: The object to traverse (dict or list). + path: Dot-notation or bracket-notation path. + + Returns: + The value at the specified path. + + Raises: + KeyError: If path does not exist in the object. + """ + if not path: + return obj + + parts = parse_path(path) + current = obj + + for part in parts: + if isinstance(current, dict): + if part not in current: + raise KeyError(f"Path '{path}' not found") + current = current[part] + elif isinstance(current, list): + try: + idx = int(part) + current = current[idx] + except (ValueError, IndexError): + raise KeyError(f"Path '{path}' not found") + else: + raise KeyError(f"Path '{path}' not found") + + return current + + +def set_nested_value(obj: dict, path: str, value: Any) -> dict: + """Set a value at a nested path using dot notation. + + Creates intermediate dicts as needed. + + Args: + obj: The dict to modify. + path: Dot-notation or bracket-notation path. + value: The value to set at the path. + + Returns: + The modified dict (same reference as input). + + Raises: + ValueError: If intermediate path element is not a dict. + """ + parts = parse_path(path) + current = obj + + for part in parts[:-1]: + if part not in current: + current[part] = {} + current = current[part] + if not isinstance(current, dict): + raise ValueError( + f"Cannot set path '{path}': intermediate value is not a dict" + ) + + current[parts[-1]] = value + return obj diff --git a/openhands/automation/kv_metrics.py b/openhands/automation/kv_metrics.py new file mode 100644 index 0000000..ba4ba66 --- /dev/null +++ b/openhands/automation/kv_metrics.py @@ -0,0 +1,114 @@ +"""Prometheus metrics for KV store operations. + +Provides observability into KV store performance and health: +- Operation latency by type +- Lock wait time +- Conflict (409) rate +- Version mismatch rate +- State document size + +Label Design: +- We use 'automation_name' instead of 'automation_id' to reduce cardinality. +- If name is unavailable, we use a truncated hash of the automation_id. +- This keeps the metric time series manageable at scale. + +Usage: + from openhands.automation.kv_metrics import ( + kv_operation_duration, + kv_conflict_total, + record_operation, + ) + + # Record operation duration + with record_operation("set", automation_id): + await do_operation() + + # Or manually: + with kv_operation_duration.labels(operation="get", automation="my-auto").time(): + await do_operation() +""" + +import time +from collections.abc import Generator +from contextlib import contextmanager + +from prometheus_client import Counter, Histogram + + +# --- Metrics Definitions --- + +# Operation latency histogram +# Buckets optimized for typical KV operation times (10ms to 5s) +kv_operation_duration = Histogram( + "kv_operation_duration_seconds", + "Duration of KV store operations", + ["operation"], + buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0], +) + +# Lock wait time histogram +# Separate from operation duration to isolate contention from processing time +kv_lock_wait_duration = Histogram( + "kv_lock_wait_duration_seconds", + "Time spent waiting for row lock in KV operations", + buckets=[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 5.0], +) + +# Lock/statement timeout conflicts (409 responses) +kv_conflict_total = Counter( + "kv_conflict_total", + "Number of KV store lock conflicts (409 responses)", + ["reason"], # "lock_timeout" or "version_mismatch" +) + +# State document size histogram +# Buckets aligned with PostgreSQL TOAST thresholds (see config.py) +kv_state_size_bytes = Histogram( + "kv_state_size_bytes", + "Size of encrypted state documents in bytes", + buckets=[100, 500, 1000, 2000, 8000, 16000, 32000, 64000, 128000, 256000], +) + + +# --- Helper Functions --- + + +@contextmanager +def record_operation(operation: str) -> Generator[None, None, None]: + """Context manager to record operation duration. + + Usage: + with record_operation("set"): + await do_set_operation() + """ + start = time.perf_counter() + try: + yield + finally: + duration = time.perf_counter() - start + kv_operation_duration.labels(operation=operation).observe(duration) + + +@contextmanager +def record_lock_wait() -> Generator[None, None, None]: + """Context manager to record time spent waiting for row lock.""" + start = time.perf_counter() + try: + yield + finally: + duration = time.perf_counter() - start + kv_lock_wait_duration.observe(duration) + + +def record_conflict(reason: str = "lock_timeout") -> None: + """Record a conflict (409) response. + + Args: + reason: Either "lock_timeout" or "version_mismatch" + """ + kv_conflict_total.labels(reason=reason).inc() + + +def record_state_size(size_bytes: int) -> None: + """Record the size of an encrypted state document.""" + kv_state_size_bytes.observe(size_bytes) diff --git a/openhands/automation/kv_router.py b/openhands/automation/kv_router.py new file mode 100644 index 0000000..afdabcf --- /dev/null +++ b/openhands/automation/kv_router.py @@ -0,0 +1,1135 @@ +"""FastAPI router for the automation KV store API. + +Provides a Redis-like key-value store scoped per-automation for state persistence. +Values are encrypted at the application level via the SDK's :class:`Cipher` +helper (Fernet: AES-128-CBC + HMAC-SHA256) before storage. Authentication is +via per-run JWT tokens (AUTOMATION_KV_TOKEN). + +Single-Document Backend Design +============================== + +While the API presents a multi-key interface (GET /kv/{key}, PUT /kv/{key}, etc.), +the backend stores all state in a SINGLE encrypted JSON document per automation. + + API "keys" → top-level fields in the state document + +Example: + PUT /kv/config → state["config"] = value + PUT /kv/counter → state["counter"] = value + GET /kv/config → return state["config"] + +This design eliminates deadlock risk: +- Only ONE row per automation to lock +- All operations serialize through that single lock +- No multi-key ordering issues possible + +Trade-off: Every operation reads/writes the entire state blob. This is acceptable +because automation state is intended to be small and access is infrequent. +""" + +import logging +import uuid +from typing import Annotated, Any + +from fastapi import ( + APIRouter, + Body, + Depends, + Header, + HTTPException, + Query, + Response, + status, +) +from sqlalchemy import select, text +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.automation.config import KVSettings, get_config +from openhands.automation.db import get_session +from openhands.automation.kv_helpers import ( + get_nested_value, + require_dict, + require_int, + require_list, + safe_decrypt, + safe_encrypt, + set_nested_value, + validate_key, +) +from openhands.automation.kv_metrics import ( + record_conflict, + record_lock_wait, + record_state_size, +) +from openhands.automation.kv_schemas import ( + KVBatchOperation, + KVBatchRequest, + KVBatchResponse, + KVConflictResponse, + KVDeleteResponse, + KVIncrRequest, + KVIncrResponse, + KVKeyMetaResponse, + KVKeyPathResponse, + KVKeyResponse, + KVListKeysResponse, + KVListLengthResponse, + KVListPushRequest, + KVPatchRequest, + KVSetResponse, +) +from openhands.automation.models import AutomationKV +from openhands.automation.utils.kv import KVTokenClaims, KVTokenError, verify_kv_token + + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/v1/kv", tags=["KV Store"]) + + +# --- Authentication --- + + +async def get_token_claims( + authorization: Annotated[str, Header()], +) -> KVTokenClaims: + """Extract and verify claims from the KV token. + + The token is passed via Authorization: Bearer header. + It contains the automation_id as a trusted claim. + """ + kv_config = get_config().kv + + if not kv_config.kv_secret: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="KV store not configured (missing AUTOMATION_KV_SECRET)", + ) + + if not authorization.startswith("Bearer "): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authorization header format", + ) + + token = authorization.removeprefix("Bearer ").strip() + if not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing token", + ) + + try: + return verify_kv_token(kv_config.kv_secret, token) + except KVTokenError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(e), + ) + + +# Backward-compatible alias for tests +async def get_automation_id_from_token( + authorization: Annotated[str, Header()], +) -> uuid.UUID: + """Extract automation_id from KV token (deprecated, use get_token_claims).""" + claims = await get_token_claims(authorization) + return claims.automation_id + + +# --- Validation Helpers --- + + +# Type alias for validated KV keys - ensures key validation is applied +# Use this as a FastAPI path parameter annotation: key: ValidatedKey +ValidatedKey = Annotated[str, Depends(lambda key: validate_key(key))] + + +def _check_state_size( + state: dict[str, Any], kv_config: KVSettings | None = None +) -> None: + """Validate that the entire state document doesn't exceed the configured size limit. + + Args: + state: The state dict to check (will be JSON-serialized to measure size) + kv_config: Optional KVSettings object (fetched if not provided) + + Raises: + HTTPException: 413 Payload Too Large if state exceeds limit + """ + import json + + if kv_config is None: + kv_config = get_config().kv + + max_size = kv_config.kv_max_value_size + if max_size <= 0: + return # Size limit disabled + + # Measure the JSON-serialized size (this is what gets encrypted/stored) + try: + serialized = json.dumps(state) + except (TypeError, ValueError): + # If we can't serialize it, the encrypt step will fail anyway + return + + actual_size = len(serialized.encode("utf-8")) + if actual_size > max_size: + raise HTTPException( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + detail=f"State size ({actual_size} bytes) exceeds limit ({max_size} bytes)", + ) + + +# --- Database Helpers --- + + +async def _get_state_row( + session: AsyncSession, + automation_id: uuid.UUID, +) -> AutomationKV | None: + """Get the state row for an automation (no lock).""" + result = await session.execute( + select(AutomationKV).where(AutomationKV.automation_id == automation_id) + ) + return result.scalars().first() + + +async def _get_state_row_for_update( + session: AsyncSession, + automation_id: uuid.UUID, + lock_timeout_ms: int = 5000, +) -> AutomationKV | None: + """Get the state row with FOR UPDATE lock and bounded wait time. + + Since there's only ONE row per automation, this is the single lock point. + All concurrent operations on this automation's state will serialize here. + + Timeout Strategy (Defense in Depth): + + 1. Statement Timeout (2x lock timeout): Safety net that kills any runaway + query, including slow encryption, network issues, or unexpected operations. + This catches problems AFTER the lock is acquired. + + 2. Lock Timeout (service-wide default): Fail fast if waiting too long for + another transaction to release the row lock. This catches contention + BEFORE the lock is acquired. Configured via AUTOMATION_KV_LOCK_TIMEOUT_MS + on the service (single global value — no per-automation knob). + + Statement timeout > lock timeout because: + - If we're waiting for a lock, lock_timeout triggers first + - If we have the lock but operation is slow, statement_timeout triggers + - The 2x ratio gives legitimate operations enough headroom + + SET LOCAL scopes both timeouts to this transaction only, so they don't + affect other queries in this session or pollute the connection pool. + + If either timeout fires, PostgreSQL raises an error which we catch and + convert to HTTP 409 Conflict, allowing clients to retry with backoff. + + Args: + session: Database session + automation_id: UUID of the automation + lock_timeout_ms: Lock timeout in milliseconds (from KVSettings) + """ + # Statement timeout: 2x lock timeout as safety net for runaway operations + statement_timeout_ms = lock_timeout_ms * 2 + stmt_sql = f"SET LOCAL statement_timeout = '{statement_timeout_ms}ms'" + await session.execute(text(stmt_sql)) + # Lock timeout: fail fast when waiting for lock (configurable per-automation) + lock_sql = f"SET LOCAL lock_timeout = '{lock_timeout_ms}ms'" + await session.execute(text(lock_sql)) + + # Record lock wait time + with record_lock_wait(): + result = await session.execute( + select(AutomationKV) + .where(AutomationKV.automation_id == automation_id) + .with_for_update() + ) + return result.scalars().first() + + +def _is_lock_timeout_error(exc: Exception) -> bool: + """Check if an exception is a PostgreSQL lock or statement timeout error. + + PostgreSQL error codes: + - 55P03 (lock_not_available): lock_timeout exceeded while waiting for lock + - 57014 (query_canceled): statement_timeout exceeded during query execution + + Both indicate the operation took too long and should be retried. + """ + error_str = str(exc).lower() + return ( + # Lock timeout errors (55P03) + "lock_not_available" in error_str + or "55p03" in error_str + or "could not obtain lock" in error_str + or "canceling statement due to lock timeout" in error_str + # Statement timeout errors (57014) + or "query_canceled" in error_str + or "57014" in error_str + or "canceling statement due to statement timeout" in error_str + ) + + +# Default retry delay in seconds for 409 responses +_RETRY_AFTER_SECONDS = "1" + + +def _raise_lock_conflict() -> None: + """Raise HTTP 409 for lock/statement timeout - signals client should retry. + + Includes Retry-After header suggesting initial backoff delay. + Clients should use exponential backoff with jitter on subsequent retries. + """ + record_conflict("lock_timeout") + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="kv_store_busy: another operation is in progress, please retry", + headers={"Retry-After": _RETRY_AFTER_SECONDS}, + ) + + +def _raise_version_conflict(expected: int, actual: int) -> None: + """Raise HTTP 409 for version mismatch - signals optimistic concurrency failure.""" + record_conflict("version_mismatch") + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "error": "version_mismatch", + "message": "State was modified by another process", + "expected_version": expected, + "actual_version": actual, + }, + headers={"Retry-After": _RETRY_AFTER_SECONDS}, + ) + + +def _decrypt_state(secret: str, row: AutomationKV | None) -> dict[str, Any]: + """Decrypt the state document from a row, returning empty dict if no row.""" + if row is None: + return {} + return safe_decrypt(secret, row.state_encrypted) + + +async def _save_state( + session: AsyncSession, + automation_id: uuid.UUID, + state: dict[str, Any], + secret: str, + existing_row: AutomationKV | None, + *, + bump_version: bool = True, +) -> AutomationKV: + """Save the state document, creating or updating the row as needed. + + Args: + session: Database session + automation_id: The automation's UUID + state: The state dict to save (will be encrypted) + secret: Encryption secret + existing_row: Existing row to update, or None to create new + bump_version: If True (default), auto-increment $version + + Returns: + The saved/updated AutomationKV row + """ + # Auto-increment $version on every write (unless explicitly disabled) + if bump_version: + state["$version"] = state.get("$version", 0) + 1 + + encrypted = safe_encrypt(secret, state) + + # Record state size metric (encrypted size includes crypto overhead) + record_state_size(len(encrypted)) + + if existing_row is None: + # Create new row + row = AutomationKV( + automation_id=automation_id, + state_encrypted=encrypted, + ) + session.add(row) + else: + # Update existing row + existing_row.state_encrypted = encrypted + row = existing_row + + await session.flush() + await session.refresh(row) + return row + + +def _get_version(state: dict[str, Any]) -> int: + """Get the current $version from state, defaulting to 0.""" + return state.get("$version", 0) + + +# --- Endpoints --- + + +@router.get("") +async def list_keys( + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVListKeysResponse: + """List all keys for this automation. + + Note: System keys (starting with $) are filtered from the response. + """ + kv_config = get_config().kv + + row = await _get_state_row(session, claims.automation_id) + state = _decrypt_state(kv_config.kv_secret, row) + + # Filter out system keys (e.g., $version) + keys = [k for k in state.keys() if not k.startswith("$")] + return KVListKeysResponse(keys=keys, count=len(keys)) + + +@router.get("/{key}") +async def get_value( + key: ValidatedKey, + path: str | None = Query(default=None, description="Nested path (dot notation)"), + meta: bool = Query(default=False, description="Include metadata and version"), + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVKeyResponse | KVKeyPathResponse | KVKeyMetaResponse: + """Get a value by key, optionally at a nested path. + + With meta=true, includes version for optimistic concurrency control. + """ + kv_config = get_config().kv + + row = await _get_state_row(session, claims.automation_id) + state = _decrypt_state(kv_config.kv_secret, row) + + if key not in state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="key_not_found", + ) + + value = state[key] + + if path: + try: + value = get_nested_value(value, path) + except KeyError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="invalid_path", + ) + return KVKeyPathResponse(key=key, path=path, value=value) + + if meta: + if row is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="key_not_found", + ) + return KVKeyMetaResponse( + key=key, + value=value, + version=_get_version(state), + created_at=row.created_at.isoformat(), + updated_at=row.updated_at.isoformat(), + ) + + return KVKeyResponse(key=key, value=value) + + +@router.put("/{key}") +async def set_value( + key: ValidatedKey, + body: Annotated[Any, Body()], # Accept any JSON body directly as the value + response: Response, + nx: bool = Query(default=False, description="Only set if key does not exist"), + xx: bool = Query(default=False, description="Only set if key exists"), + if_version: int | None = Query( + default=None, + description="Only set if current state version matches (optimistic lock)", + ), + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVSetResponse | KVConflictResponse: + """Set a value for a key. + + The entire request body is stored as the value. + + Query params: + - nx=true: Only set if key does NOT exist (like Redis SETNX) + - xx=true: Only set if key DOES exist + - if_version=N: Only set if current $version equals N (optimistic concurrency) + + Returns: + - 200: Key updated (existing key) + - 201: Key created (new key, or nx=true success) + - 409: Conflict (nx/xx/if_version check failed) + - 413: Payload too large (state exceeds size limit) + """ + kv_config = get_config().kv + + if nx and xx: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot use both nx and xx", + ) + + # Lock the state row for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + # Check version if specified (optimistic concurrency) + if if_version is not None: + current_version = _get_version(state) + if current_version != if_version: + _raise_version_conflict(if_version, current_version) + + key_exists = key in state + + if nx and key_exists: + response.status_code = status.HTTP_409_CONFLICT + return KVConflictResponse(key=key, created=False, error="key_exists") + + if xx and not key_exists: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="key_not_exists", + ) + + # Update state + state[key] = body + _check_state_size(state, kv_config) + + # Save + saved_row = await _save_state( + session, claims.automation_id, state, kv_config.kv_secret, row + ) + + created = not key_exists + if created: + response.status_code = status.HTTP_201_CREATED + + return KVSetResponse( + key=key, + value=body, + created=created, + updated_at=saved_row.updated_at.isoformat(), + ) + + +@router.patch("/{key}") +async def patch_value( + key: ValidatedKey, + body: KVPatchRequest, + if_version: int | None = Query( + default=None, + description="Only patch if current state version matches (optimistic lock)", + ), + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVKeyPathResponse: + """Update a nested path within an existing value. + + Query params: + - if_version=N: Only patch if current $version equals N (optimistic concurrency) + """ + kv_config = get_config().kv + + # Lock for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + # Check version if specified (optimistic concurrency) + if if_version is not None: + current_version = _get_version(state) + if current_version != if_version: + _raise_version_conflict(if_version, current_version) + + if key not in state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="key_not_found", + ) + + value = state[key] + require_dict(value) + + try: + set_nested_value(value, body.path, body.value) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"invalid_path: {e}", + ) + + state[key] = value + _check_state_size(state, kv_config) + + await _save_state(session, claims.automation_id, state, kv_config.kv_secret, row) + + return KVKeyPathResponse( + key=key, + path=body.path, + value=body.value, + ) + + +def _has_user_keys(state: dict[str, Any]) -> bool: + """Check if state has any user keys (excluding system keys like $version).""" + return any(not k.startswith("$") for k in state.keys()) + + +@router.delete("/{key}") +async def delete_key( + key: ValidatedKey, + if_version: int | None = Query( + default=None, + description="Only delete if current state version matches (optimistic lock)", + ), + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVDeleteResponse: + """Delete a key. + + Query params: + - if_version=N: Only delete if current $version equals N (optimistic concurrency) + """ + kv_config = get_config().kv + + # Lock for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + # Check version if specified (optimistic concurrency) + if if_version is not None: + current_version = _get_version(state) + if current_version != if_version: + _raise_version_conflict(if_version, current_version) + + if key not in state: + return KVDeleteResponse(key=key, deleted=False) + + del state[key] + + if row is not None: + if _has_user_keys(state): + # Still have user keys, update the row + await _save_state( + session, claims.automation_id, state, kv_config.kv_secret, row + ) + else: + # No user keys left, delete the row entirely + await session.delete(row) + await session.flush() + + return KVDeleteResponse(key=key, deleted=True) + + +@router.post("/{key}/incr") +async def increment( + key: ValidatedKey, + body: KVIncrRequest | None = None, + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVIncrResponse: + """Atomically increment an integer value. + + If the key doesn't exist, initializes it to `by` (default 1). + + Note: The stored value must be an integer. Float values are rejected + because integer arithmetic on floats can cause precision loss. + """ + kv_config = get_config().kv + by = body.by if body else 1 + + # Lock for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + if key not in state: + # Initialize with `by` + state[key] = by + new_value = by + else: + value = state[key] + require_int(value) + new_value = value + by + state[key] = new_value + + _check_state_size(state, kv_config) + await _save_state(session, claims.automation_id, state, kv_config.kv_secret, row) + + return KVIncrResponse(key=key, value=new_value) + + +@router.post("/{key}/decr") +async def decrement( + key: ValidatedKey, + body: KVIncrRequest | None = None, + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVIncrResponse: + """Atomically decrement an integer value. + + If the key doesn't exist, initializes it to `-by` (default -1). + + Note: The stored value must be an integer. Float values are rejected + because integer arithmetic on floats can cause precision loss. + """ + kv_config = get_config().kv + by = body.by if body else 1 + + # Lock for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + if key not in state: + # Initialize with `-by` + state[key] = -by + new_value = -by + else: + value = state[key] + require_int(value) + new_value = value - by + state[key] = new_value + + _check_state_size(state, kv_config) + await _save_state(session, claims.automation_id, state, kv_config.kv_secret, row) + + return KVIncrResponse(key=key, value=new_value) + + +@router.post("/{key}/lpush") +async def lpush( + key: ValidatedKey, + body: KVListPushRequest, + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVListLengthResponse: + """Push a value to the left (front) of a list. + + Creates the list if it doesn't exist. + """ + kv_config = get_config().kv + + # Lock for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + if key not in state: + # Initialize with single-element list + state[key] = [body.value] + else: + value = state[key] + require_list(value) + value.insert(0, body.value) + state[key] = value + + _check_state_size(state, kv_config) + await _save_state(session, claims.automation_id, state, kv_config.kv_secret, row) + + return KVListLengthResponse(key=key, length=len(state[key])) + + +@router.post("/{key}/rpush") +async def rpush( + key: ValidatedKey, + body: KVListPushRequest, + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVListLengthResponse: + """Push a value to the right (back) of a list. + + Creates the list if it doesn't exist. + """ + kv_config = get_config().kv + + # Lock for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + if key not in state: + # Initialize with single-element list + state[key] = [body.value] + else: + value = state[key] + require_list(value) + value.append(body.value) + state[key] = value + + _check_state_size(state, kv_config) + await _save_state(session, claims.automation_id, state, kv_config.kv_secret, row) + + return KVListLengthResponse(key=key, length=len(state[key])) + + +@router.post("/{key}/lpop") +async def lpop( + key: ValidatedKey, + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVKeyResponse: + """Pop a value from the left (front) of a list. + + Returns null if key doesn't exist or list is empty. + """ + kv_config = get_config().kv + + # Lock for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + if key not in state: + return KVKeyResponse(key=key, value=None) + + value = state[key] + require_list(value) + + if len(value) == 0: + return KVKeyResponse(key=key, value=None) + + popped = value.pop(0) + state[key] = value + + await _save_state(session, claims.automation_id, state, kv_config.kv_secret, row) + + return KVKeyResponse(key=key, value=popped) + + +@router.post("/{key}/rpop") +async def rpop( + key: ValidatedKey, + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVKeyResponse: + """Pop a value from the right (back) of a list. + + Returns null if key doesn't exist or list is empty. + """ + kv_config = get_config().kv + + # Lock for atomic read-modify-write + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + state = _decrypt_state(kv_config.kv_secret, row) + + if key not in state: + return KVKeyResponse(key=key, value=None) + + value = state[key] + require_list(value) + + if len(value) == 0: + return KVKeyResponse(key=key, value=None) + + popped = value.pop() + state[key] = value + + await _save_state(session, claims.automation_id, state, kv_config.kv_secret, row) + + return KVKeyResponse(key=key, value=popped) + + +@router.get("/{key}/len") +async def list_length( + key: ValidatedKey, + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVListLengthResponse: + """Get the length of a list.""" + kv_config = get_config().kv + + row = await _get_state_row(session, claims.automation_id) + state = _decrypt_state(kv_config.kv_secret, row) + + if key not in state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="key_not_found", + ) + + value = state[key] + require_list(value) + + return KVListLengthResponse(key=key, length=len(value)) + + +# --- Batch Operations --- + + +class KVOperationError(Exception): + """Raised when a batch operation fails validation.""" + + pass + + +def _validate_batch_key(key: str) -> None: + """Validate a key for batch operations (same rules as validate_key). + + Raises: + KVOperationError: If key is invalid + """ + if not key: + raise KVOperationError("key cannot be empty") + if not key.strip(): + raise KVOperationError("key cannot be whitespace-only") + if key.startswith("$"): + raise KVOperationError("keys starting with '$' are reserved for system use") + if len(key) > 255: + raise KVOperationError(f"key exceeds 255 chars ({len(key)} given)") + + +def _execute_batch_operation( + state: dict[str, Any], + op: KVBatchOperation, +) -> dict[str, Any]: + """Execute a single operation within a batch. + + Args: + state: The current state dict (modified in place) + op: The operation to execute + + Returns: + Result dict for this operation + + Raises: + KVOperationError: If operation fails validation + """ + _validate_batch_key(op.key) + key = op.key + + if op.op == "set": + key_existed = key in state + # Handle nx (set if not exists) + if op.nx and key_existed: + raise KVOperationError(f"key '{key}' already exists (nx=true)") + # Handle xx (set if exists) + if op.xx and not key_existed: + raise KVOperationError(f"key '{key}' does not exist (xx=true)") + state[key] = op.value + return {"op": "set", "key": key, "success": True, "created": not key_existed} + + elif op.op == "delete": + deleted = key in state + if deleted: + del state[key] + return {"op": "delete", "key": key, "success": True, "deleted": deleted} + + elif op.op == "incr": + by = op.by + if key not in state: + state[key] = by + new_value = by + else: + value = state[key] + if isinstance(value, bool): + raise KVOperationError(f"key '{key}' is boolean, not integer") + if not isinstance(value, int): + raise KVOperationError(f"key '{key}' is not an integer") + new_value = value + by + state[key] = new_value + return {"op": "incr", "key": key, "success": True, "value": new_value} + + elif op.op == "decr": + by = op.by + if key not in state: + state[key] = -by + new_value = -by + else: + value = state[key] + if isinstance(value, bool): + raise KVOperationError(f"key '{key}' is boolean, not integer") + if not isinstance(value, int): + raise KVOperationError(f"key '{key}' is not an integer") + new_value = value - by + state[key] = new_value + return {"op": "decr", "key": key, "success": True, "value": new_value} + + elif op.op == "lpush": + if key not in state: + state[key] = [op.value] + else: + value = state[key] + if not isinstance(value, list): + raise KVOperationError(f"key '{key}' is not a list") + value.insert(0, op.value) + return {"op": "lpush", "key": key, "success": True, "length": len(state[key])} + + elif op.op == "rpush": + if key not in state: + state[key] = [op.value] + else: + value = state[key] + if not isinstance(value, list): + raise KVOperationError(f"key '{key}' is not a list") + value.append(op.value) + return {"op": "rpush", "key": key, "success": True, "length": len(state[key])} + + elif op.op == "lpop": + if key not in state: + return {"op": "lpop", "key": key, "success": True, "value": None} + value = state[key] + if not isinstance(value, list): + raise KVOperationError(f"key '{key}' is not a list") + if len(value) == 0: + return {"op": "lpop", "key": key, "success": True, "value": None} + popped = value.pop(0) + return {"op": "lpop", "key": key, "success": True, "value": popped} + + elif op.op == "rpop": + if key not in state: + return {"op": "rpop", "key": key, "success": True, "value": None} + value = state[key] + if not isinstance(value, list): + raise KVOperationError(f"key '{key}' is not a list") + if len(value) == 0: + return {"op": "rpop", "key": key, "success": True, "value": None} + popped = value.pop() + return {"op": "rpop", "key": key, "success": True, "value": popped} + + elif op.op == "patch": + if key not in state: + state[key] = {} + value = state[key] + if not isinstance(value, dict): + raise KVOperationError(f"key '{key}' is not an object") + try: + set_nested_value(value, op.path, op.value) + except ValueError as e: + raise KVOperationError(str(e)) + return {"op": "patch", "key": key, "success": True} + + else: + raise KVOperationError(f"unknown operation: {op.op}") + + +@router.post("/batch") +async def batch( + body: KVBatchRequest, + claims: KVTokenClaims = Depends(get_token_claims), + session: AsyncSession = Depends(get_session), +) -> KVBatchResponse: + """Execute multiple KV operations atomically in a single transaction. + + All operations succeed or none do. Use `if_version` for optimistic + concurrency control - the batch will be rejected if the current state + version doesn't match. + + Operations are executed in order. The $version is incremented once + for the entire batch, not per operation. + + Returns: + - 200: All operations succeeded + - 400: An operation failed validation (e.g., incr on a list) + - 409: Version mismatch (if_version specified but doesn't match) + - 409: Lock timeout (another operation in progress) + - 413: Payload too large (state exceeds size limit) + """ + kv_config = get_config().kv + + # Acquire lock for atomic batch execution + try: + row = await _get_state_row_for_update( + session, claims.automation_id, kv_config.kv_lock_timeout_ms + ) + except Exception as e: + if _is_lock_timeout_error(e): + _raise_lock_conflict() + raise + + state = _decrypt_state(kv_config.kv_secret, row) + current_version = _get_version(state) + + # Check version if specified + if body.if_version is not None and current_version != body.if_version: + _raise_version_conflict(body.if_version, current_version) + + # Execute all operations + results = [] + for i, op in enumerate(body.operations): + try: + result = _execute_batch_operation(state, op) + results.append(result) + except KVOperationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "operation_failed", + "message": str(e), + "operation_index": i, + "operation": {"op": op.op, "key": op.key}, + }, + ) + + # Validate state size before saving + _check_state_size(state, kv_config) + + # Save state (auto-increments $version) + await _save_state(session, claims.automation_id, state, kv_config.kv_secret, row) + + return KVBatchResponse(version=_get_version(state), results=results) diff --git a/openhands/automation/kv_schemas.py b/openhands/automation/kv_schemas.py new file mode 100644 index 0000000..de218db --- /dev/null +++ b/openhands/automation/kv_schemas.py @@ -0,0 +1,249 @@ +"""Pydantic request/response schemas for the KV store API.""" + +from typing import Any, Literal + +from pydantic import BaseModel, Field, model_validator + + +# --- Batch Operation Schemas --- + + +class KVBatchOpSet(BaseModel): + """Set operation in a batch.""" + + op: Literal["set"] + key: str = Field(..., min_length=1, max_length=255) + value: Any = Field(..., description="Value to set") + nx: bool = Field(default=False, description="Only set if key does not exist") + xx: bool = Field(default=False, description="Only set if key exists") + + @model_validator(mode="after") + def validate_nx_xx(self): + if self.nx and self.xx: + raise ValueError("Cannot use both nx and xx") + return self + + +class KVBatchOpDelete(BaseModel): + """Delete operation in a batch.""" + + op: Literal["delete"] + key: str = Field(..., min_length=1, max_length=255) + + +class KVBatchOpIncr(BaseModel): + """Increment operation in a batch.""" + + op: Literal["incr"] + key: str = Field(..., min_length=1, max_length=255) + by: int = Field(default=1, description="Amount to increment by") + + +class KVBatchOpDecr(BaseModel): + """Decrement operation in a batch.""" + + op: Literal["decr"] + key: str = Field(..., min_length=1, max_length=255) + by: int = Field(default=1, description="Amount to decrement by") + + +class KVBatchOpLPush(BaseModel): + """Left push operation in a batch.""" + + op: Literal["lpush"] + key: str = Field(..., min_length=1, max_length=255) + value: Any = Field(..., description="Value to push") + + +class KVBatchOpRPush(BaseModel): + """Right push operation in a batch.""" + + op: Literal["rpush"] + key: str = Field(..., min_length=1, max_length=255) + value: Any = Field(..., description="Value to push") + + +class KVBatchOpLPop(BaseModel): + """Left pop operation in a batch.""" + + op: Literal["lpop"] + key: str = Field(..., min_length=1, max_length=255) + + +class KVBatchOpRPop(BaseModel): + """Right pop operation in a batch.""" + + op: Literal["rpop"] + key: str = Field(..., min_length=1, max_length=255) + + +class KVBatchOpPatch(BaseModel): + """Patch operation in a batch.""" + + op: Literal["patch"] + key: str = Field(..., min_length=1, max_length=255) + path: str = Field(..., description="Dot-notation path to update") + value: Any = Field(..., description="Value to set at the path") + + +# Union of all batch operation types +KVBatchOperation = ( + KVBatchOpSet + | KVBatchOpDelete + | KVBatchOpIncr + | KVBatchOpDecr + | KVBatchOpLPush + | KVBatchOpRPush + | KVBatchOpLPop + | KVBatchOpRPop + | KVBatchOpPatch +) + + +class KVBatchRequest(BaseModel): + """Request body for batch operations.""" + + if_version: int | None = Field( + default=None, + description="Only execute if current state version matches this value", + ) + operations: list[KVBatchOperation] = Field( + ..., + min_length=1, + max_length=100, + description="List of operations to execute atomically", + ) + + +# Batch operation results are returned as dicts with the following fields: +# - op: str - The operation type +# - key: str - The key operated on +# - success: bool - Always True (batch fails atomically if any op fails) +# - Additional fields depend on operation type: +# - set: created (bool) - True if key was newly created +# - delete: deleted (bool) - True if key existed and was deleted +# - incr/decr: value (int) - New value after increment/decrement +# - lpush/rpush: length (int) - New list length +# - lpop/rpop: value (Any) - Popped value, or null if list was empty +# - patch: (no additional fields) + + +class KVBatchResponse(BaseModel): + """Response for successful batch operation.""" + + version: int = Field(description="New state version after batch") + results: list[dict[str, Any]] = Field( + description="Results for each operation in order" + ) + + +class KVVersionMismatchResponse(BaseModel): + """Response when batch fails due to version mismatch.""" + + error: Literal["version_mismatch"] = "version_mismatch" + message: str = "State was modified by another process" + expected_version: int + actual_version: int + + +# --- Request Schemas --- + + +class KVSetRequest(BaseModel): + """Request body for setting a KV value (used when body is explicit).""" + + value: Any = Field(..., description="Any JSON-serializable value") + + +class KVPatchRequest(BaseModel): + """Request body for patching a nested path.""" + + path: str = Field( + ..., description="Dot-notation path to update (e.g., 'database.port')" + ) + value: Any = Field(..., description="Value to set at the path") + + +class KVIncrRequest(BaseModel): + """Request body for increment/decrement operations.""" + + by: int = Field(default=1, description="Amount to increment/decrement by") + + +class KVListPushRequest(BaseModel): + """Request body for list push operations.""" + + value: Any = Field(..., description="Value to push onto the list") + + +# --- Response Schemas --- + + +class KVKeyResponse(BaseModel): + """Response containing a key and its value.""" + + key: str + value: Any + + +class KVKeyPathResponse(BaseModel): + """Response containing a key, path, and value.""" + + key: str + path: str + value: Any + + +class KVKeyMetaResponse(BaseModel): + """Response containing a key, value, and metadata.""" + + key: str + value: Any + version: int + created_at: str + updated_at: str + + +class KVSetResponse(BaseModel): + """Response for set operations.""" + + key: str + value: Any + created: bool + updated_at: str + + +class KVDeleteResponse(BaseModel): + """Response for delete operations.""" + + key: str + deleted: bool + + +class KVListKeysResponse(BaseModel): + """Response for listing keys.""" + + keys: list[str] + count: int + + +class KVIncrResponse(BaseModel): + """Response for increment/decrement operations.""" + + key: str + value: int + + +class KVListLengthResponse(BaseModel): + """Response for list length operations.""" + + key: str + length: int + + +class KVConflictResponse(BaseModel): + """Response when a conditional operation fails.""" + + key: str + created: bool = False + error: str diff --git a/openhands/automation/models.py b/openhands/automation/models.py index 2b3bbd9..7ec585d 100644 --- a/openhands/automation/models.py +++ b/openhands/automation/models.py @@ -317,3 +317,66 @@ class CustomWebhook(Base): __table_args__ = ( Index("ix_custom_webhooks_org_source", "org_id", "source", unique=True), ) + + +class AutomationKV(Base): + """Single-document state store for automation persistence. + + Each automation has exactly ONE row containing its entire state as an + encrypted JSON document. The API presents a key-value interface, but + "keys" are top-level fields within this single document. + + Single-Document Design (Deadlock Prevention): + By storing all state in one row per automation, we eliminate multi-key + deadlock scenarios. All operations on an automation's state serialize + through a single row lock. There's no possibility of lock ordering + issues because there's only one lock to acquire. + + Trade-off: Every operation reads/writes the entire state blob. This is + acceptable because automation state is intended to be small (cursors, + counters, configs) and access is infrequent (scheduled runs). + + Storage Design: + We store encrypted state as a Fernet token (URL-safe base64 text) + produced by the SDK's :class:`Cipher`. See + ``openhands/automation/utils/kv.py`` for the full encryption rationale. + """ + + __tablename__ = "automation_kv" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + automation_id: Mapped[uuid.UUID] = mapped_column( + Uuid, + ForeignKey("automations.id", ondelete="CASCADE"), + nullable=False, + unique=True, # ONE row per automation + ) + + # Fernet token (URL-safe base64 text) containing the entire state document + # as JSON. Produced by openhands.sdk.utils.cipher.Cipher.encrypt and + # consumed by Cipher.decrypt. The decrypted JSON is a dict where keys are + # the "KV keys" exposed via the API. + # Example decrypted: {"config": {...}, "counter": 42, "queue": [...]} + state_encrypted: Mapped[str] = mapped_column(Text, nullable=False) + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + onupdate=utcnow, + nullable=False, + ) + + __table_args__ = ( + # Index for efficient lookup by automation_id (unique constraint + # is already defined on the column, this ensures index exists) + Index( + "ix_automation_kv_automation_id", + "automation_id", + unique=True, + ), + ) diff --git a/openhands/automation/utils/kv.py b/openhands/automation/utils/kv.py new file mode 100644 index 0000000..dd4aa26 --- /dev/null +++ b/openhands/automation/utils/kv.py @@ -0,0 +1,218 @@ +"""KV store utilities: JWT tokens, JSON validation, and value encryption. + +This module provides: +- JWT token generation/verification for KV store authentication +- Strict-JSON validation (rejects NaN/Infinity, caps nesting depth) +- Encryption/decryption of KV state via the SDK's :class:`Cipher` + +The :class:`Cipher` helper from ``openhands.sdk.utils.cipher`` wraps Fernet +(AES-128-CBC + HMAC-SHA256). It derives a 256-bit key from the configured +service secret, generates a fresh IV per encryption, and authenticates the +ciphertext, which is everything we need for protecting per-automation state +at rest. We deliberately use the SDK's Cipher instead of rolling our own AES +to keep this module small and to share a battle-tested implementation with +the rest of the platform. + +Fernet emits a URL-safe base64 string ("token") rather than raw bytes, so KV +state is stored in a text column. The ~33% base64 overhead is acceptable for +small automation state (counters, cursors, configs) and keeps the schema +simple. +""" + +import json +import uuid +from datetime import UTC, datetime, timedelta +from typing import Any + +import jwt +from pydantic import SecretStr + +from openhands.sdk.utils.cipher import Cipher + + +class KVTokenError(Exception): + """Error with KV store JWT token.""" + + +class KVEncryptionError(Exception): + """Error with KV value encryption/decryption.""" + + +class KVValueError(Exception): + """Error with KV value format or content.""" + + +# Maximum nesting depth for JSON values. +# Prevents stack overflow from deeply nested structures and limits complexity. +# 32 levels is generous (most real configs are <10 levels deep). +_MAX_NESTING_DEPTH = 32 + +# Token expiration: 24 hours +# +# Intentionally longer than the max automation run time (currently 2 hours) +# to provide margin for long-running automations, cleanup operations after +# run completion, and clock skew between services. The token only grants +# access to its specific automation's KV data, so a longer validity window +# has minimal security impact. +KV_TOKEN_EXPIRATION_HOURS = 24 + + +# --- JWT Token Functions --- + + +class KVTokenClaims: + """Verified claims from a KV store JWT token.""" + + __slots__ = ("automation_id",) + + def __init__(self, automation_id: uuid.UUID): + self.automation_id = automation_id + + +def create_kv_token( + secret: str, + automation_id: uuid.UUID, + run_id: uuid.UUID, +) -> str: + """Create a JWT token for KV store access. + + The token embeds the automation_id as a trusted claim, ensuring that + KV operations are scoped to the correct automation. + + Args: + secret: The signing secret (AUTOMATION_KV_SECRET) + automation_id: UUID of the automation + run_id: UUID of the current run (for audit) + + Returns: + Signed JWT token string + """ + now = datetime.now(UTC) + payload = { + "automation_id": str(automation_id), + "run_id": str(run_id), + "iat": now, + "exp": now + timedelta(hours=KV_TOKEN_EXPIRATION_HOURS), + } + return jwt.encode(payload, secret, algorithm="HS256") + + +def verify_kv_token(secret: str, token: str) -> KVTokenClaims: + """Verify a KV store JWT token and extract claims. + + Args: + secret: The signing secret (AUTOMATION_KV_SECRET) + token: The JWT token to verify + + Returns: + KVTokenClaims with automation_id + + Raises: + KVTokenError: If token is invalid, expired, or malformed + """ + try: + payload = jwt.decode(token, secret, algorithms=["HS256"]) + automation_id_str = payload.get("automation_id") + if not automation_id_str: + raise KVTokenError("Token missing automation_id claim") + return KVTokenClaims(automation_id=uuid.UUID(automation_id_str)) + except jwt.ExpiredSignatureError: + raise KVTokenError("Token has expired") + except jwt.InvalidTokenError as e: + raise KVTokenError(f"Invalid token: {e}") + except ValueError as e: + raise KVTokenError(f"Invalid automation_id format: {e}") + + +# --- JSON Validation --- + + +def _check_nesting_depth(value: Any, current_depth: int = 0) -> None: + """Check that a value doesn't exceed maximum nesting depth. + + Raises: + KVValueError: If nesting exceeds _MAX_NESTING_DEPTH + """ + if current_depth > _MAX_NESTING_DEPTH: + raise KVValueError( + f"Value exceeds maximum nesting depth of {_MAX_NESTING_DEPTH}" + ) + + if isinstance(value, dict): + for v in value.values(): + _check_nesting_depth(v, current_depth + 1) + elif isinstance(value, list): + for item in value: + _check_nesting_depth(item, current_depth + 1) + + +def _validate_json_value(value: Any) -> str: + """Validate and serialize a value to strict JSON. + + Ensures the value is JSON-serializable, contains only standard JSON + types (rejects NaN/Infinity), and doesn't exceed maximum nesting depth. + + Raises: + KVValueError: If value is not valid strict JSON + """ + try: + _check_nesting_depth(value) + except RecursionError: + raise KVValueError( + f"Value exceeds maximum nesting depth of {_MAX_NESTING_DEPTH}" + ) + + # Strict JSON: allow_nan=False rejects NaN/Infinity, ensure_ascii=False + # allows UTF-8 (more compact, widely supported). + try: + return json.dumps(value, allow_nan=False, ensure_ascii=False) + except ValueError as e: + raise KVValueError(f"Value contains non-JSON-compliant data: {e}") + except TypeError as e: + raise KVValueError(f"Value is not JSON-serializable: {e}") + + +# --- Encryption Functions --- + + +def encrypt_value(secret: str, value: Any) -> str: + """Encrypt a value for storage using the SDK Cipher (Fernet). + + Validates and JSON-serializes the value, then returns a base64-encoded + Fernet token suitable for storage in a TEXT column. + + Raises: + KVValueError: If value is not valid strict JSON + KVEncryptionError: If encryption fails + """ + plaintext_str = _validate_json_value(value) + try: + ciphertext = Cipher(secret).encrypt(SecretStr(plaintext_str)) + except Exception as e: + raise KVEncryptionError(f"Failed to encrypt value: {e}") + assert ciphertext is not None # SecretStr is non-None, so result is non-None + return ciphertext + + +def decrypt_value(secret: str, encrypted: str) -> Any: + """Decrypt a value previously produced by :func:`encrypt_value`. + + Returns the parsed JSON object. + + Raises: + KVEncryptionError: If decryption fails (wrong key, tampered data, etc.) + """ + try: + plaintext_secret = Cipher(secret).decrypt(encrypted) + except Exception as e: + raise KVEncryptionError(f"Failed to decrypt value: {e}") + + if plaintext_secret is None: + # Cipher.decrypt returns None on InvalidToken; surface as an explicit + # encryption error so callers can map it to HTTP 500. + raise KVEncryptionError("Failed to decrypt value: invalid token") + + try: + return json.loads(plaintext_secret.get_secret_value()) + except json.JSONDecodeError as e: + raise KVEncryptionError(f"Decrypted value is not valid JSON: {e}") diff --git a/openhands/automation/utils/log_context.py b/openhands/automation/utils/log_context.py index 326873d..01aab99 100644 --- a/openhands/automation/utils/log_context.py +++ b/openhands/automation/utils/log_context.py @@ -7,6 +7,7 @@ def log_extra( run_id: str | None = None, sandbox_id: str | None = None, automation_id: str | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Build extra dict for structured logging with contextual IDs. @@ -14,6 +15,7 @@ def log_extra( run_id: The automation run ID. sandbox_id: The sandbox ID. automation_id: The automation definition ID. + **kwargs: Additional context fields to include. Returns: Dict with non-None values for use as logger extra parameter. @@ -25,4 +27,5 @@ def log_extra( extra["sandbox_id"] = sandbox_id if automation_id: extra["automation_id"] = automation_id + extra.update(kwargs) return extra diff --git a/pyproject.toml b/pyproject.toml index af0b40f..6704752 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,10 @@ dependencies = [ "openhands-sdk==1.22.1", "openhands-workspace==1.22.1", "pg8000>=1.31", + "prometheus-client>=0.19", "pydantic>=2", "pydantic-settings>=2", + "pyjwt>=2.8", "python-json-logger>=3", "sqlalchemy[asyncio]>=2", "tenacity>=9.1.4", diff --git a/scripts/test_kv_e2e.py b/scripts/test_kv_e2e.py new file mode 100755 index 0000000..9732364 --- /dev/null +++ b/scripts/test_kv_e2e.py @@ -0,0 +1,955 @@ +#!/usr/bin/env python3 +"""End-to-end test for KV store functionality with full stdout/stderr capture. + +This script: +1. Creates a real automation via API (KV store is always available) +2. Generates a KV token for that automation +3. Uses run_automation() to execute a test script with full output capture +4. Cleans up the automation + +Usage: + export OPENHANDS_API_KEY="sk-oh-..." + export AUTOMATION_KV_SECRET="" # Required for token generation + python scripts/test_kv_e2e.py + + # Optional: specify staging URL + export OPENHANDS_API_URL="https://staging.all-hands.dev" +""" + +import asyncio +import os +import sys +import uuid +from pathlib import Path + +import httpx + + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from openhands.automation.execution import build_tarball, run_automation # noqa: E402 +from openhands.automation.utils.kv import create_kv_token # noqa: E402 + + +# --------------------------------------------------------------------------- +# Test script that runs inside the sandbox +# --------------------------------------------------------------------------- + +KV_TEST_SCRIPT = ''' +"""KV store test script - runs inside sandbox. + +Supports two modes: + --quick Smoke test: one test per operation type (~8 tests) + --thorough Full coverage: all edge cases and error paths (~40 tests) + +Default is --quick if no argument provided. +""" + +import json +import os +import sys + +# Use urllib since requests may not be installed +from urllib.request import Request, urlopen +from urllib.error import HTTPError + + +# Test registry +QUICK_TESTS = [] +THOROUGH_TESTS = [] + + +def quick(fn): + """Decorator to mark a test as part of quick suite.""" + QUICK_TESTS.append(fn) + THOROUGH_TESTS.append(fn) + return fn + + +def thorough(fn): + """Decorator to mark a test as thorough-only.""" + THOROUGH_TESTS.append(fn) + return fn + + +def api_call(method, path, body=None, headers=None): + """Make an HTTP request to the KV API.""" + url = f"{API_URL}/api/automation/v1/kv{path}" + req_headers = {"Authorization": f"Bearer {KV_TOKEN}"} + if headers: + req_headers.update(headers) + + data = None + if body is not None: + data = json.dumps(body).encode("utf-8") + req_headers["Content-Type"] = "application/json" + + req = Request(url, data=data, headers=req_headers, method=method) + + try: + with urlopen(req, timeout=30) as resp: + return resp.status, json.loads(resp.read().decode("utf-8")) + except HTTPError as e: + try: + body = json.loads(e.read().decode("utf-8")) + except Exception: + body = {"error": str(e)} + return e.code, body + + +def api_call_raw(method, path, body=None, headers=None, auth=True): + """Make HTTP request with optional auth control (for auth tests).""" + url = f"{API_URL}/api/automation/v1/kv{path}" + req_headers = {} + if auth: + req_headers["Authorization"] = f"Bearer {KV_TOKEN}" + if headers: + req_headers.update(headers) + + data = None + if body is not None: + if isinstance(body, bytes): + data = body + else: + data = json.dumps(body).encode("utf-8") + if "Content-Type" not in req_headers: + req_headers["Content-Type"] = "application/json" + + req = Request(url, data=data, headers=req_headers, method=method) + + try: + with urlopen(req, timeout=30) as resp: + return resp.status, resp.read().decode("utf-8") + except HTTPError as e: + return e.code, e.read().decode("utf-8") + + +# =========================================================================== +# QUICK TESTS - Core functionality smoke tests +# =========================================================================== + +@quick +def test_set_get(): + """[TC-3.1/3.3] Basic SET and GET operations.""" + print("\\n[TEST] SET and GET") + + # SET + status, resp = api_call("PUT", "/test_key", {"message": "hello", "count": 42}) + print(f" PUT /test_key: {status}") + if status not in (200, 201): + print(f" FAIL: {resp}") + return False + + # GET + status, resp = api_call("GET", "/test_key") + print(f" GET /test_key: {status}") + if status != 200: + print(f" FAIL: {resp}") + return False + + expected = {"message": "hello", "count": 42} + if resp.get("value") != expected: + print(f" FAIL: Expected {expected}, got {resp.get('value')}") + return False + + print(" PASS") + return True + + +@quick +def test_delete(): + """[TC-3.8] DELETE operation.""" + print("\\n[TEST] DELETE") + + api_call("PUT", "/to_delete", "bye") + + status, resp = api_call("DELETE", "/to_delete") + print(f" DELETE /to_delete: {status}") + if status != 200: + print(f" FAIL: Expected 200, got {status}") + return False + + # Verify gone + status, resp = api_call("GET", "/to_delete") + print(f" GET after delete: {status}") + if status != 404: + print(f" FAIL: Expected 404, got {status}") + return False + + print(" PASS") + return True + + +@quick +def test_incr_decr(): + """[TC-6.2/6.4] INCR and DECR on existing key.""" + print("\\n[TEST] INCR and DECR") + + api_call("PUT", "/counter", 10) + + status, resp = api_call("POST", "/counter/incr", {"by": 5}) + print(f" INCR by 5: {status}, value={resp.get('value')}") + if resp.get("value") != 15: + print(f" FAIL: Expected 15, got {resp.get('value')}") + return False + + status, resp = api_call("POST", "/counter/decr", {"by": 3}) + print(f" DECR by 3: {status}, value={resp.get('value')}") + if resp.get("value") != 12: + print(f" FAIL: Expected 12, got {resp.get('value')}") + return False + + print(" PASS") + return True + + +@quick +def test_list_operations(): + """[TC-7.1-7.6] List RPUSH, LPUSH, LPOP, RPOP, LEN.""" + print("\\n[TEST] List operations") + + api_call("DELETE", "/my_list") + + # RPUSH to create list + status, resp = api_call("POST", "/my_list/rpush", {"value": "a"}) + print(f" RPUSH 'a': {status}, length={resp.get('length')}") + if resp.get("length") != 1: + print(f" FAIL: Expected length 1") + return False + + api_call("POST", "/my_list/rpush", {"value": "b"}) + api_call("POST", "/my_list/rpush", {"value": "c"}) + + # LPUSH + status, resp = api_call("POST", "/my_list/lpush", {"value": "z"}) + print(f" LPUSH 'z': {status}, length={resp.get('length')}") + + # Verify order: [z, a, b, c] + status, resp = api_call("GET", "/my_list") + if resp.get("value") != ["z", "a", "b", "c"]: + print(f" FAIL: Expected ['z', 'a', 'b', 'c'], got {resp.get('value')}") + return False + + # LPOP + status, resp = api_call("POST", "/my_list/lpop") + print(f" LPOP: {status}, value={resp.get('value')}") + if resp.get("value") != "z": + print(f" FAIL: Expected 'z'") + return False + + # RPOP + status, resp = api_call("POST", "/my_list/rpop") + print(f" RPOP: {status}, value={resp.get('value')}") + if resp.get("value") != "c": + print(f" FAIL: Expected 'c'") + return False + + # LEN + status, resp = api_call("GET", "/my_list/len") + print(f" LEN: {status}, length={resp.get('length')}") + if resp.get("length") != 2: + print(f" FAIL: Expected 2") + return False + + print(" PASS") + return True + + +@quick +def test_nested_path(): + """[TC-3.5/5.1] Nested path GET and PATCH.""" + print("\\n[TEST] Nested path operations") + + config = { + "database": {"host": "localhost", "port": 5432}, + "cache": {"enabled": True}, + } + api_call("PUT", "/config", config) + + # PATCH nested value + patch_data = {"path": "database.port", "value": 5433} + status, resp = api_call("PATCH", "/config", patch_data) + print(f" PATCH database.port=5433: {status}") + if status != 200: + print(f" FAIL: {resp}") + return False + + # GET with path + status, resp = api_call("GET", "/config?path=database.port") + print(f" GET with path: {status}, value={resp.get('value')}") + if resp.get("value") != 5433: + print(f" FAIL: Expected 5433") + return False + + print(" PASS") + return True + + +@quick +def test_conditional_set(): + """[TC-4.1/4.2] Conditional SET with NX flag.""" + print("\\n[TEST] Conditional SET (nx)") + + api_call("DELETE", "/cond_key") + + # NX when key doesn't exist - should succeed + status, resp = api_call("PUT", "/cond_key?nx=true", "first") + print(f" PUT with nx=true (new): {status}") + if status != 201: + print(f" FAIL: Expected 201, got {status}") + return False + + # NX when key exists - should fail + status, resp = api_call("PUT", "/cond_key?nx=true", "second") + print(f" PUT with nx=true (exists): {status}") + if status != 409: + print(f" FAIL: Expected 409, got {status}") + return False + + # Verify value unchanged + status, resp = api_call("GET", "/cond_key") + if resp.get("value") != "first": + print(f" FAIL: Value should be 'first'") + return False + + print(" PASS") + return True + + +@quick +def test_list_keys(): + """[TC-3.10] List all keys.""" + print("\\n[TEST] List keys") + + api_call("PUT", "/list_test_a", "a") + api_call("PUT", "/list_test_b", "b") + + status, resp = api_call("GET", "") + print(f" GET /kv: {status}, count={resp.get('count')}") + if status != 200: + print(f" FAIL: {resp}") + return False + + keys = resp.get("keys", []) + if "list_test_a" not in keys or "list_test_b" not in keys: + print(f" FAIL: Expected keys to include list_test_a and list_test_b") + return False + + print(" PASS") + return True + + +@quick +def test_get_with_meta(): + """[TC-3.4] GET with metadata.""" + print("\\n[TEST] GET with metadata") + + api_call("PUT", "/meta_test", "value") + + status, resp = api_call("GET", "/meta_test?meta=true") + print(f" GET with meta=true: {status}") + + if "created_at" not in resp or "updated_at" not in resp: + print(f" FAIL: Missing timestamps") + return False + + print(f" created_at: {resp.get('created_at')}") + print(" PASS") + return True + + +# =========================================================================== +# THOROUGH TESTS - Edge cases, error paths, security +# =========================================================================== + +@thorough +def test_get_nonexistent_key(): + """[TC-3.6] GET non-existent key returns 404.""" + print("\\n[TEST] GET non-existent key") + + status, resp = api_call("GET", "/definitely_does_not_exist_12345") + print(f" GET /nonexistent: {status}") + if status != 404: + print(f" FAIL: Expected 404, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_get_nonexistent_path(): + """[TC-3.7] GET non-existent nested path.""" + print("\\n[TEST] GET non-existent nested path") + + api_call("PUT", "/path_test", {"a": {"b": 1}}) + + status, resp = api_call("GET", "/path_test?path=a.c.d") + print(f" GET with invalid path: {status}") + # Should return 404 or null value + if status not in (200, 404): + print(f" FAIL: Expected 200 or 404, got {status}") + return False + + if status == 200 and resp.get("value") is not None: + print(f" FAIL: Expected null value for missing path") + return False + + print(" PASS") + return True + + +@thorough +def test_delete_nonexistent(): + """[TC-3.9] DELETE non-existent key.""" + print("\\n[TEST] DELETE non-existent key") + + status, resp = api_call("DELETE", "/never_existed_xyz") + print(f" DELETE /nonexistent: {status}, deleted={resp.get('deleted')}") + + # Should succeed but indicate nothing was deleted + if status != 200: + print(f" FAIL: Expected 200, got {status}") + return False + + if resp.get("deleted") is not False: + print(f" FAIL: Expected deleted=false") + return False + + print(" PASS") + return True + + +@thorough +def test_conditional_set_xx(): + """[TC-4.3/4.4] Conditional SET with XX flag.""" + print("\\n[TEST] Conditional SET (xx)") + + api_call("DELETE", "/xx_test") + + # XX when key doesn't exist - should fail with 409 Conflict + status, resp = api_call("PUT", "/xx_test?xx=true", "value") + print(f" PUT with xx=true (missing): {status}") + if status != 409: + print(f" FAIL: Expected 409, got {status}") + return False + + # Create key first + api_call("PUT", "/xx_test", "original") + + # XX when key exists - should succeed + status, resp = api_call("PUT", "/xx_test?xx=true", "updated") + print(f" PUT with xx=true (exists): {status}") + if status != 200: + print(f" FAIL: Expected 200, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_patch_nonexistent(): + """[TC-5.3] PATCH non-existent key returns 404.""" + print("\\n[TEST] PATCH non-existent key") + + status, resp = api_call("PATCH", "/nonexistent_patch", {"path": "x", "value": 1}) + print(f" PATCH /nonexistent: {status}") + if status != 404: + print(f" FAIL: Expected 404, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_incr_new_key(): + """[TC-6.1] INCR on non-existent key initializes to 1.""" + print("\\n[TEST] INCR new key") + + api_call("DELETE", "/new_incr_counter") + + status, resp = api_call("POST", "/new_incr_counter/incr") + print(f" INCR new key: {status}, value={resp.get('value')}") + if resp.get("value") != 1: + print(f" FAIL: Expected 1, got {resp.get('value')}") + return False + + print(" PASS") + return True + + +@thorough +def test_decr_new_key(): + """[TC-6.5] DECR on non-existent key initializes to -1.""" + print("\\n[TEST] DECR new key") + + api_call("DELETE", "/new_decr_counter") + + status, resp = api_call("POST", "/new_decr_counter/decr") + print(f" DECR new key: {status}, value={resp.get('value')}") + if resp.get("value") != -1: + print(f" FAIL: Expected -1, got {resp.get('value')}") + return False + + print(" PASS") + return True + + +@thorough +def test_incr_non_numeric(): + """[TC-6.6] INCR on non-numeric value returns error.""" + print("\\n[TEST] INCR non-numeric") + + api_call("PUT", "/string_val", "hello") + + status, resp = api_call("POST", "/string_val/incr") + print(f" INCR string value: {status}") + if status != 400: + print(f" FAIL: Expected 400, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_lpop_empty_list(): + """[TC-7.7] LPOP from empty list returns null.""" + print("\\n[TEST] LPOP empty list") + + api_call("PUT", "/empty_list", []) + + status, resp = api_call("POST", "/empty_list/lpop") + print(f" LPOP empty: {status}, value={resp.get('value')}") + if resp.get("value") is not None: + print(f" FAIL: Expected null, got {resp.get('value')}") + return False + + print(" PASS") + return True + + +@thorough +def test_lpop_nonexistent(): + """[TC-7.7b] LPOP from non-existent key returns null.""" + print("\\n[TEST] LPOP non-existent key") + + api_call("DELETE", "/no_such_list") + + status, resp = api_call("POST", "/no_such_list/lpop") + print(f" LPOP nonexistent: {status}, value={resp.get('value')}") + if resp.get("value") is not None: + print(f" FAIL: Expected null") + return False + + print(" PASS") + return True + + +@thorough +def test_push_to_non_list(): + """[TC-7.8] RPUSH to non-list value returns error.""" + print("\\n[TEST] RPUSH to non-list") + + api_call("PUT", "/not_a_list", {"key": "value"}) + + status, resp = api_call("POST", "/not_a_list/rpush", {"value": "item"}) + print(f" RPUSH to dict: {status}") + if status != 400: + print(f" FAIL: Expected 400, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_len_nonexistent(): + """[TC-7.9] LEN on non-existent key returns 404.""" + print("\\n[TEST] LEN non-existent key") + + api_call("DELETE", "/no_such_list_len") + + status, resp = api_call("GET", "/no_such_list_len/len") + print(f" LEN nonexistent: {status}") + if status != 404: + print(f" FAIL: Expected 404, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_special_characters_in_key(): + """[TC-8.1] Key with special characters.""" + print("\\n[TEST] Special characters in key") + + # Test with dashes, underscores, numbers + key = "test-key_123" + status, resp = api_call("PUT", f"/{key}", "value") + print(f" PUT /{key}: {status}") + if status not in (200, 201): + print(f" FAIL: {resp}") + return False + + status, resp = api_call("GET", f"/{key}") + if resp.get("value") != "value": + print(f" FAIL: Value mismatch") + return False + + print(" PASS") + return True + + +@thorough +def test_null_value(): + """[TC-8.6] Store null value - rejected as empty body.""" + print("\\n[TEST] Store null value") + + # Null/empty body is rejected by FastAPI validation + status, resp = api_call("PUT", "/null_test", None) + print(f" PUT null: {status}") + if status != 422: + print(f" FAIL: Expected 422, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_various_json_types(): + """[TC-8.7] Store various JSON types.""" + print("\\n[TEST] Various JSON types") + + test_cases = [ + ("string_type", "hello"), + ("number_int", 42), + ("number_float", 3.14), + ("boolean_true", True), + ("boolean_false", False), + ("array_type", [1, 2, 3]), + ("nested_obj", {"a": {"b": {"c": 1}}}), + ] + + for key, value in test_cases: + status, _ = api_call("PUT", f"/type_{key}", value) + if status not in (200, 201): + print(f" FAIL: PUT {key} returned {status}") + return False + + status, resp = api_call("GET", f"/type_{key}") + if resp.get("value") != value: + print(f" FAIL: {key} value mismatch: {resp.get('value')} != {value}") + return False + print(f" {key}: OK") + + print(" PASS") + return True + + +@thorough +def test_auth_missing_token(): + """[TC-2.1] Access without token returns 401.""" + print("\\n[TEST] Auth - missing token") + + global KV_TOKEN + saved_token = KV_TOKEN + KV_TOKEN = "" + + # Missing Authorization header returns 422 (FastAPI validation error) + # before our auth middleware runs + status, _ = api_call_raw("GET", "/test", auth=False) + print(f" GET without token: {status}") + + KV_TOKEN = saved_token + + if status not in (401, 403, 422): + print(f" FAIL: Expected 401, 403, or 422, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_auth_invalid_token(): + """[TC-2.2] Access with invalid token returns 401.""" + print("\\n[TEST] Auth - invalid token") + + headers = {"Authorization": "Bearer invalid.token.here"} + status, _ = api_call_raw("GET", "/test", headers=headers) + print(f" GET with invalid token: {status}") + + if status not in (401, 403): + print(f" FAIL: Expected 401 or 403, got {status}") + return False + + print(" PASS") + return True + + +@thorough +def test_invalid_json_body(): + """[TC-11.1] Invalid JSON body returns 400.""" + print("\\n[TEST] Invalid JSON body") + + status, _ = api_call_raw( + "PUT", "/bad_json", + body=b"not valid json {", + headers={"Content-Type": "application/json"} + ) + print(f" PUT invalid JSON: {status}") + + if status != 400 and status != 422: + print(f" FAIL: Expected 400 or 422, got {status}") + return False + + print(" PASS") + return True + + +def main(): + global API_URL, KV_TOKEN + + API_URL = os.environ.get("OPENHANDS_CLOUD_API_URL", "").rstrip("/") + KV_TOKEN = os.environ.get("AUTOMATION_KV_TOKEN", "") + + # Parse mode from command line + mode = "quick" + if len(sys.argv) > 1: + if sys.argv[1] == "--thorough": + mode = "thorough" + elif sys.argv[1] == "--quick": + mode = "quick" + + tests = QUICK_TESTS if mode == "quick" else THOROUGH_TESTS + + print("=" * 60) + print(f"KV STORE E2E TEST ({mode.upper()} MODE)") + print(f"Running {len(tests)} tests") + print("=" * 60) + print(f"API URL: {API_URL}") + token_info = f"present ({len(KV_TOKEN)} chars)" if KV_TOKEN else "MISSING" + print(f"KV Token: {token_info}") + + if not API_URL: + print("\\nFAIL: OPENHANDS_CLOUD_API_URL not set") + sys.exit(1) + + if not KV_TOKEN: + print("\\nFAIL: AUTOMATION_KV_TOKEN not set") + sys.exit(1) + + passed = 0 + failed = 0 + + for test in tests: + try: + if test(): + passed += 1 + else: + failed += 1 + except Exception as e: + print(f" ERROR: {e}") + import traceback + traceback.print_exc() + failed += 1 + + print("\\n" + "=" * 60) + print(f"RESULTS ({mode.upper()}): {passed} passed, {failed} failed") + print("=" * 60) + + if failed == 0: + print("\\nKV_STORE_ALL_TESTS_PASSED") + sys.exit(0) + else: + print("\\nKV_STORE_TESTS_FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() +''' + + +# Legacy test script reference (keeping for backwards compatibility) +KV_TEST_SCRIPT_QUICK = KV_TEST_SCRIPT.replace('mode = "quick"', 'mode = "quick"') +KV_TEST_SCRIPT_THOROUGH = KV_TEST_SCRIPT.replace('mode = "quick"', 'mode = "thorough"') + + +async def create_automation( + client: httpx.AsyncClient, api_url: str, api_key: str +) -> str: + """Create a test automation (KV always available). Returns automation_id.""" + print("Creating automation (KV store is always available)...") + + resp = await client.post( + f"{api_url}/api/automation/v1/preset/prompt", + headers={"Authorization": f"Bearer {api_key}"}, + json={ + "name": f"KV Store Test {uuid.uuid4().hex[:8]}", + "prompt": "This is a test automation for KV store verification.", + "trigger": { + "type": "cron", + "schedule": "0 0 1 1 *", # Once a year (won't actually trigger) + "timezone": "UTC", + }, + }, + ) + + if resp.status_code != 201: + print(f"Failed to create automation: {resp.status_code}") + print(resp.text) + sys.exit(1) + + data = resp.json() + automation_id = data["id"] + print(f"Created automation: {automation_id}") + return automation_id + + +async def delete_automation( + client: httpx.AsyncClient, api_url: str, api_key: str, automation_id: str +): + """Delete the test automation (best-effort cleanup).""" + print(f"\nCleaning up automation {automation_id}...") + resp = await client.delete( + f"{api_url}/api/automation/v1/{automation_id}", + headers={"Authorization": f"Bearer {api_key}"}, + ) + if resp.status_code == 204: + print("Automation deleted.") + elif resp.status_code == 403: + print("Note: Cleanup skipped (API key lacks manage_automations permission)") + else: + print(f"Warning: Failed to delete automation: {resp.status_code}") + + +async def main(): + # --- Configuration --- + api_key = os.environ.get("OPENHANDS_API_KEY") + kv_secret = os.environ.get("AUTOMATION_KV_SECRET") + api_url = os.environ.get( + "OPENHANDS_API_URL", "https://staging.all-hands.dev" + ).rstrip("/") + + # Parse mode from command line + mode = "quick" + if "--thorough" in sys.argv: + mode = "thorough" + + print("=" * 70) + print(f"KV STORE E2E TEST RUNNER ({mode.upper()} MODE)") + print("=" * 70) + print(f"API URL: {api_url}") + print(f"API Key: {'present' if api_key else 'MISSING'}") + print(f"KV Secret: {'present' if kv_secret else 'MISSING'}") + print() + + if not api_key: + print("ERROR: Set OPENHANDS_API_KEY environment variable") + sys.exit(1) + + if not kv_secret: + print("ERROR: Set AUTOMATION_KV_SECRET environment variable") + print(" (Must match the secret configured in staging)") + sys.exit(1) + + # Select test script based on mode + test_script = KV_TEST_SCRIPT + entrypoint = f"python main.py --{mode}" + + # --- Create automation via API --- + automation_id = None + async with httpx.AsyncClient(timeout=60) as client: + try: + automation_id = await create_automation(client, api_url, api_key) + automation_uuid = uuid.UUID(automation_id) + + # --- Generate KV token --- + run_id = uuid.uuid4() + kv_token = create_kv_token( + secret=kv_secret, + automation_id=automation_uuid, + run_id=run_id, + ) + print(f"Generated KV token for run_id={run_id}") + + # --- Build tarball --- + print("\nBuilding test tarball...") + tarball = build_tarball( + { + "main.py": test_script, + } + ) + print(f"Tarball size: {len(tarball)} bytes") + + # --- Run automation --- + print("\n" + "-" * 70) + print(f"EXECUTING IN SANDBOX ({mode.upper()} MODE)") + print("-" * 70) + + result = await run_automation( + api_url=api_url, + api_key=api_key, + entrypoint=entrypoint, + tarball_source=tarball, + env_vars={ + "OPENHANDS_API_KEY": api_key, + "OPENHANDS_CLOUD_API_URL": api_url, + "AUTOMATION_KV_TOKEN": kv_token, + }, + timeout=600 if mode == "thorough" else 300, + keep_sandbox=False, + ) + + # --- Display results --- + print("\n" + "=" * 70) + print("EXECUTION RESULT") + print("=" * 70) + print(f"Success: {result.success}") + print(f"Exit code: {result.exit_code}") + print(f"Sandbox ID: {result.sandbox_id}") + + if result.stdout: + print("\n" + "-" * 70) + print("STDOUT") + print("-" * 70) + print(result.stdout) + + if result.stderr: + print("\n" + "-" * 70) + print("STDERR (last 3000 chars)") + print("-" * 70) + print(result.stderr[-3000:]) + + if result.error: + print("\n" + "-" * 70) + print("ERROR") + print("-" * 70) + print(result.error) + + # --- Final verdict --- + print("\n" + "=" * 70) + if result.success and "KV_STORE_ALL_TESTS_PASSED" in result.stdout: + print(f"✅ KV STORE E2E TEST PASSED ({mode.upper()} MODE)") + print("=" * 70) + return 0 + else: + print(f"❌ KV STORE E2E TEST FAILED ({mode.upper()} MODE)") + print("=" * 70) + return 1 + + finally: + # --- Cleanup --- + if automation_id: + await delete_automation(client, api_url, api_key, automation_id) + + +if __name__ == "__main__": + sys.exit(asyncio.run(main())) diff --git a/tests/test_kv_batch.py b/tests/test_kv_batch.py new file mode 100644 index 0000000..57c1197 --- /dev/null +++ b/tests/test_kv_batch.py @@ -0,0 +1,433 @@ +"""Unit tests for KV batch operations and $version functionality. + +These tests focus on the batch operation logic without requiring a database. +""" + +import pytest +from fastapi import HTTPException + +from openhands.automation.kv_helpers import validate_key +from openhands.automation.kv_router import ( + KVOperationError, + _execute_batch_operation, + _get_version, + _validate_batch_key, +) +from openhands.automation.kv_schemas import ( + KVBatchOpDecr, + KVBatchOpDelete, + KVBatchOpIncr, + KVBatchOpLPop, + KVBatchOpLPush, + KVBatchOpPatch, + KVBatchOpRPop, + KVBatchOpRPush, + KVBatchOpSet, + KVBatchRequest, +) + + +class TestValidateKeyReserved: + """Test that $ prefix keys are rejected.""" + + def test_dollar_prefix_rejected(self): + with pytest.raises(HTTPException) as exc: + validate_key("$version") + assert exc.value.status_code == 400 + assert "reserved" in exc.value.detail.lower() + + def test_dollar_prefix_any_name_rejected(self): + with pytest.raises(HTTPException) as exc: + validate_key("$anything") + assert exc.value.status_code == 400 + assert "reserved" in exc.value.detail.lower() + + def test_dollar_in_middle_allowed(self): + # $ in middle is fine, only prefix is reserved + result = validate_key("my$key") + assert result == "my$key" + + def test_dollar_at_end_allowed(self): + result = validate_key("key$") + assert result == "key$" + + +class TestValidateBatchKey: + """Test batch key validation.""" + + def test_valid_key(self): + _validate_batch_key("mykey") # Should not raise + + def test_empty_key_rejected(self): + with pytest.raises(KVOperationError, match="empty"): + _validate_batch_key("") + + def test_whitespace_key_rejected(self): + with pytest.raises(KVOperationError, match="whitespace"): + _validate_batch_key(" ") + + def test_dollar_prefix_rejected(self): + with pytest.raises(KVOperationError, match="reserved"): + _validate_batch_key("$version") + + def test_long_key_rejected(self): + with pytest.raises(KVOperationError, match="exceeds 255"): + _validate_batch_key("x" * 256) + + +class TestGetVersion: + """Test version extraction.""" + + def test_get_version_present(self): + assert _get_version({"$version": 5, "key": "value"}) == 5 + + def test_get_version_missing(self): + assert _get_version({"key": "value"}) == 0 + + def test_get_version_empty_state(self): + assert _get_version({}) == 0 + + +class TestBatchOpSet: + """Test set operation in batch.""" + + def test_set_new_key(self): + state = {} + op = KVBatchOpSet(op="set", key="foo", value="bar") + result = _execute_batch_operation(state, op) + + assert state["foo"] == "bar" + assert result == {"op": "set", "key": "foo", "success": True, "created": True} + + def test_set_existing_key(self): + state = {"foo": "old"} + op = KVBatchOpSet(op="set", key="foo", value="new") + result = _execute_batch_operation(state, op) + + assert state["foo"] == "new" + assert result == {"op": "set", "key": "foo", "success": True, "created": False} + + def test_set_nx_creates_new(self): + state = {} + op = KVBatchOpSet(op="set", key="foo", value="bar", nx=True) + result = _execute_batch_operation(state, op) + + assert state["foo"] == "bar" + assert result["created"] is True + + def test_set_nx_fails_if_exists(self): + state = {"foo": "old"} + op = KVBatchOpSet(op="set", key="foo", value="new", nx=True) + + with pytest.raises(KVOperationError, match="already exists"): + _execute_batch_operation(state, op) + + def test_set_xx_updates_existing(self): + state = {"foo": "old"} + op = KVBatchOpSet(op="set", key="foo", value="new", xx=True) + result = _execute_batch_operation(state, op) + + assert state["foo"] == "new" + assert result["created"] is False + + def test_set_xx_fails_if_not_exists(self): + state = {} + op = KVBatchOpSet(op="set", key="foo", value="bar", xx=True) + + with pytest.raises(KVOperationError, match="does not exist"): + _execute_batch_operation(state, op) + + def test_set_reserved_key_rejected(self): + state = {} + op = KVBatchOpSet(op="set", key="$version", value=100) + + with pytest.raises(KVOperationError, match="reserved"): + _execute_batch_operation(state, op) + + +class TestBatchOpDelete: + """Test delete operation in batch.""" + + def test_delete_existing(self): + state = {"foo": "bar"} + op = KVBatchOpDelete(op="delete", key="foo") + result = _execute_batch_operation(state, op) + + assert "foo" not in state + expected = {"op": "delete", "key": "foo", "success": True, "deleted": True} + assert result == expected + + def test_delete_nonexistent(self): + state = {} + op = KVBatchOpDelete(op="delete", key="foo") + result = _execute_batch_operation(state, op) + + expected = {"op": "delete", "key": "foo", "success": True, "deleted": False} + assert result == expected + + +class TestBatchOpIncr: + """Test incr operation in batch.""" + + def test_incr_creates_key(self): + state = {} + op = KVBatchOpIncr(op="incr", key="counter") + result = _execute_batch_operation(state, op) + + assert state["counter"] == 1 + assert result == {"op": "incr", "key": "counter", "success": True, "value": 1} + + def test_incr_increments_existing(self): + state = {"counter": 5} + op = KVBatchOpIncr(op="incr", key="counter") + result = _execute_batch_operation(state, op) + + assert state["counter"] == 6 + assert result["value"] == 6 + + def test_incr_by_custom_amount(self): + state = {"counter": 10} + op = KVBatchOpIncr(op="incr", key="counter", by=5) + result = _execute_batch_operation(state, op) + + assert state["counter"] == 15 + assert result["value"] == 15 + + def test_incr_rejects_non_integer(self): + state = {"counter": "not a number"} + op = KVBatchOpIncr(op="incr", key="counter") + + with pytest.raises(KVOperationError, match="not an integer"): + _execute_batch_operation(state, op) + + def test_incr_rejects_boolean(self): + state = {"flag": True} + op = KVBatchOpIncr(op="incr", key="flag") + + with pytest.raises(KVOperationError, match="boolean"): + _execute_batch_operation(state, op) + + +class TestBatchOpDecr: + """Test decr operation in batch.""" + + def test_decr_creates_negative(self): + state = {} + op = KVBatchOpDecr(op="decr", key="counter") + result = _execute_batch_operation(state, op) + + assert state["counter"] == -1 + assert result["value"] == -1 + + def test_decr_decrements_existing(self): + state = {"counter": 10} + op = KVBatchOpDecr(op="decr", key="counter", by=3) + result = _execute_batch_operation(state, op) + + assert state["counter"] == 7 + assert result["value"] == 7 + + +class TestBatchOpLPush: + """Test lpush operation in batch.""" + + def test_lpush_creates_list(self): + state = {} + op = KVBatchOpLPush(op="lpush", key="queue", value="item1") + result = _execute_batch_operation(state, op) + + assert state["queue"] == ["item1"] + assert result == {"op": "lpush", "key": "queue", "success": True, "length": 1} + + def test_lpush_prepends(self): + state = {"queue": ["b", "c"]} + op = KVBatchOpLPush(op="lpush", key="queue", value="a") + result = _execute_batch_operation(state, op) + + assert state["queue"] == ["a", "b", "c"] + assert result["length"] == 3 + + def test_lpush_rejects_non_list(self): + state = {"queue": "not a list"} + op = KVBatchOpLPush(op="lpush", key="queue", value="item") + + with pytest.raises(KVOperationError, match="not a list"): + _execute_batch_operation(state, op) + + +class TestBatchOpRPush: + """Test rpush operation in batch.""" + + def test_rpush_creates_list(self): + state = {} + op = KVBatchOpRPush(op="rpush", key="queue", value="item1") + result = _execute_batch_operation(state, op) + + assert state["queue"] == ["item1"] + assert result["length"] == 1 + + def test_rpush_appends(self): + state = {"queue": ["a", "b"]} + op = KVBatchOpRPush(op="rpush", key="queue", value="c") + result = _execute_batch_operation(state, op) + + assert state["queue"] == ["a", "b", "c"] + assert result["length"] == 3 + + +class TestBatchOpLPop: + """Test lpop operation in batch.""" + + def test_lpop_returns_first(self): + state = {"queue": ["a", "b", "c"]} + op = KVBatchOpLPop(op="lpop", key="queue") + result = _execute_batch_operation(state, op) + + assert state["queue"] == ["b", "c"] + assert result == {"op": "lpop", "key": "queue", "success": True, "value": "a"} + + def test_lpop_empty_returns_null(self): + state = {"queue": []} + op = KVBatchOpLPop(op="lpop", key="queue") + result = _execute_batch_operation(state, op) + + assert result["value"] is None + + def test_lpop_nonexistent_returns_null(self): + state = {} + op = KVBatchOpLPop(op="lpop", key="queue") + result = _execute_batch_operation(state, op) + + assert result["value"] is None + + +class TestBatchOpRPop: + """Test rpop operation in batch.""" + + def test_rpop_returns_last(self): + state = {"queue": ["a", "b", "c"]} + op = KVBatchOpRPop(op="rpop", key="queue") + result = _execute_batch_operation(state, op) + + assert state["queue"] == ["a", "b"] + assert result["value"] == "c" + + +class TestBatchOpPatch: + """Test patch operation in batch.""" + + def test_patch_updates_nested(self): + state = {"config": {"db": {"host": "localhost"}}} + op = KVBatchOpPatch(op="patch", key="config", path="db.port", value=5432) + result = _execute_batch_operation(state, op) + + assert state["config"]["db"]["port"] == 5432 + assert result == {"op": "patch", "key": "config", "success": True} + + def test_patch_creates_key_if_missing(self): + state = {} + op = KVBatchOpPatch(op="patch", key="config", path="db.host", value="localhost") + _execute_batch_operation(state, op) + + assert state["config"]["db"]["host"] == "localhost" + + def test_patch_rejects_non_dict(self): + state = {"config": "not a dict"} + op = KVBatchOpPatch(op="patch", key="config", path="db.host", value="localhost") + + with pytest.raises(KVOperationError, match="not an object"): + _execute_batch_operation(state, op) + + +class TestBatchRequest: + """Test batch request validation.""" + + def test_valid_batch(self): + req = KVBatchRequest( + operations=[ + KVBatchOpSet(op="set", key="a", value=1), + KVBatchOpIncr(op="incr", key="b"), + ] + ) + assert len(req.operations) == 2 + + def test_batch_with_version(self): + req = KVBatchRequest( + if_version=5, + operations=[KVBatchOpSet(op="set", key="a", value=1)], + ) + assert req.if_version == 5 + + def test_empty_operations_rejected(self): + with pytest.raises(ValueError): + KVBatchRequest(operations=[]) + + def test_too_many_operations_rejected(self): + ops: list = [KVBatchOpIncr(op="incr", key=f"k{i}") for i in range(101)] + with pytest.raises(ValueError): + KVBatchRequest(operations=ops) + + +class TestBatchMultipleOps: + """Test executing multiple operations in sequence.""" + + def test_multiple_ops_in_order(self): + state = {} + + ops = [ + KVBatchOpSet(op="set", key="counter", value=0), + KVBatchOpIncr(op="incr", key="counter", by=5), + KVBatchOpIncr(op="incr", key="counter", by=3), + KVBatchOpRPush(op="rpush", key="log", value="started"), + KVBatchOpRPush(op="rpush", key="log", value="finished"), + ] + + results = [] + for op in ops: + results.append(_execute_batch_operation(state, op)) + + assert state["counter"] == 8 + assert state["log"] == ["started", "finished"] + assert results[0]["created"] is True + assert results[1]["value"] == 5 + assert results[2]["value"] == 8 + assert results[3]["length"] == 1 + assert results[4]["length"] == 2 + + def test_early_failure_stops_batch(self): + """Simulate what happens when an operation fails mid-batch.""" + state = {"counter": "not a number"} + + ops = [ + KVBatchOpSet(op="set", key="before", value="ok"), + KVBatchOpIncr(op="incr", key="counter"), # This will fail + KVBatchOpSet(op="set", key="after", value="should not run"), + ] + + # First op succeeds + _execute_batch_operation(state, ops[0]) + assert state["before"] == "ok" + + # Second op fails + with pytest.raises(KVOperationError): + _execute_batch_operation(state, ops[1]) + + # In a real batch, the transaction would rollback, so "before" wouldn't persist + # But we're testing that the error is raised properly + + +class TestVersionBump: + """Test that $version is properly managed.""" + + def test_version_starts_at_zero_if_missing(self): + state = {"key": "value"} + assert _get_version(state) == 0 + + def test_version_preserved_across_reads(self): + state = {"$version": 5, "key": "value"} + assert _get_version(state) == 5 + # Operations don't touch $version directly + op = KVBatchOpSet(op="set", key="other", value="x") + _execute_batch_operation(state, op) + # $version unchanged by operation (bump happens in _save_state) + assert state["$version"] == 5 diff --git a/tests/test_kv_concurrency.py b/tests/test_kv_concurrency.py new file mode 100644 index 0000000..d438d0b --- /dev/null +++ b/tests/test_kv_concurrency.py @@ -0,0 +1,156 @@ +"""Tests for KV store concurrency controls. + +Tests cover: +- Statement timeout (safety net for runaway operations) +- Retry-After header on 409 responses +- Metrics recording +""" + +import uuid + +import pytest + +from openhands.automation.kv_metrics import ( + kv_conflict_total, + record_conflict, + record_lock_wait, + record_operation, + record_state_size, +) +from openhands.automation.kv_router import ( + _is_lock_timeout_error, + _raise_lock_conflict, + _raise_version_conflict, +) + + +# --- Test Constants --- +TEST_AUTOMATION_ID = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") +TEST_RUN_ID = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") + + +class TestStatementTimeoutDetection: + """Tests for statement timeout error detection.""" + + def test_detects_lock_timeout_55p03(self): + """Detects lock timeout error code 55P03.""" + exc = Exception("ERROR: canceling statement due to lock timeout (55P03)") + assert _is_lock_timeout_error(exc) is True + + def test_detects_lock_not_available(self): + """Detects lock_not_available error.""" + exc = Exception("asyncpg.exceptions.LockNotAvailableError: lock_not_available") + assert _is_lock_timeout_error(exc) is True + + def test_detects_statement_timeout_57014(self): + """Detects statement timeout error code 57014.""" + exc = Exception("ERROR: canceling statement due to statement timeout (57014)") + assert _is_lock_timeout_error(exc) is True + + def test_detects_query_canceled(self): + """Detects query_canceled error.""" + exc = Exception("asyncpg.exceptions.QueryCanceledError: query_canceled") + assert _is_lock_timeout_error(exc) is True + + def test_ignores_unrelated_errors(self): + """Ignores unrelated database errors.""" + exc = Exception("ERROR: duplicate key value violates unique constraint") + assert _is_lock_timeout_error(exc) is False + + def test_ignores_generic_errors(self): + """Ignores generic Python errors.""" + exc = ValueError("invalid value") + assert _is_lock_timeout_error(exc) is False + + +class TestRetryAfterHeader: + """Tests for Retry-After header on 409 responses.""" + + def test_lock_conflict_includes_retry_after(self): + """_raise_lock_conflict includes Retry-After header.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + _raise_lock_conflict() + + exc = exc_info.value + assert exc.status_code == 409 + assert exc.headers is not None + assert "Retry-After" in exc.headers + assert exc.headers["Retry-After"] == "1" + + def test_version_conflict_includes_retry_after(self): + """_raise_version_conflict includes Retry-After header.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + _raise_version_conflict(expected=5, actual=6) + + exc = exc_info.value + assert exc.status_code == 409 + assert exc.headers is not None + assert "Retry-After" in exc.headers + assert exc.headers["Retry-After"] == "1" + + def test_version_conflict_includes_versions(self): + """_raise_version_conflict includes version info in detail.""" + from typing import Any, cast + + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + _raise_version_conflict(expected=5, actual=6) + + exc = exc_info.value + detail = cast(dict[str, Any], exc.detail) + assert detail["error"] == "version_mismatch" + assert detail["expected_version"] == 5 + assert detail["actual_version"] == 6 + + +class TestKVMetrics: + """Tests for KV store Prometheus metrics.""" + + def test_record_operation_timing(self): + """record_operation measures duration.""" + import time + + # Use the context manager + with record_operation("test_op"): + time.sleep(0.01) # 10ms + + # Metric should have been recorded (we can't easily check exact value + # but we can verify no exceptions) + + def test_record_lock_wait_timing(self): + """record_lock_wait measures duration.""" + import time + + with record_lock_wait(): + time.sleep(0.001) # 1ms + + def test_record_conflict_lock_timeout(self): + """record_conflict increments counter for lock timeout.""" + # Get initial count (if any) + initial = kv_conflict_total.labels(reason="lock_timeout")._value.get() + + record_conflict("lock_timeout") + + # Should have incremented + new_value = kv_conflict_total.labels(reason="lock_timeout")._value.get() + assert new_value == initial + 1 + + def test_record_conflict_version_mismatch(self): + """record_conflict increments counter for version mismatch.""" + initial = kv_conflict_total.labels(reason="version_mismatch")._value.get() + + record_conflict("version_mismatch") + + new_value = kv_conflict_total.labels(reason="version_mismatch")._value.get() + assert new_value == initial + 1 + + def test_record_state_size(self): + """record_state_size records to histogram.""" + # Just verify it doesn't raise + record_state_size(1000) + record_state_size(50000) diff --git a/tests/test_kv_helpers.py b/tests/test_kv_helpers.py new file mode 100644 index 0000000..b8f2907 --- /dev/null +++ b/tests/test_kv_helpers.py @@ -0,0 +1,629 @@ +"""Tests for KV store helper functions. + +Tests cover: +- Path parsing (valid and edge cases) +- Key validation (robustness against malicious/accidental inputs) +- Type validation (numeric, integer, list, dict) +""" + +import pytest +from fastapi import HTTPException + +from openhands.automation.kv_helpers import ( + _MAX_KEY_LENGTH, + _MAX_PATH_DEPTH, + get_nested_value, + parse_path, + require_dict, + require_int, + require_list, + require_numeric, + set_nested_value, + validate_key, +) + + +class TestParsePath: + """Tests for parse_path() function.""" + + def test_simple_dot_notation(self): + """Simple dot-separated path.""" + assert parse_path("database.host") == ["database", "host"] + + def test_single_key(self): + """Single key with no dots.""" + assert parse_path("key") == ["key"] + + def test_empty_string(self): + """Empty string returns empty list.""" + assert parse_path("") == [] + + def test_bracket_notation_double_quotes(self): + """Bracket notation with double quotes.""" + assert parse_path('config["my.key"]') == ["config", "my.key"] + + def test_bracket_notation_single_quotes(self): + """Bracket notation with single quotes.""" + assert parse_path("config['my.key']") == ["config", "my.key"] + + def test_bracket_notation_no_quotes(self): + """Bracket notation without quotes.""" + assert parse_path("config[0]") == ["config", "0"] + + def test_mixed_notation(self): + """Mix of dot and bracket notation.""" + assert parse_path('data["items"][0].name') == ["data", "items", "0", "name"] + + def test_consecutive_brackets(self): + """Multiple consecutive brackets.""" + assert parse_path("arr[0][1]") == ["arr", "0", "1"] + + def test_numeric_keys(self): + """Numeric keys in dot notation.""" + assert parse_path("data.0.1") == ["data", "0", "1"] + + def test_trailing_dot(self): + """Trailing dot is ignored.""" + assert parse_path("foo.bar.") == ["foo", "bar"] + + def test_leading_dot(self): + """Leading dot is ignored.""" + assert parse_path(".foo.bar") == ["foo", "bar"] + + def test_unclosed_bracket_raises(self): + """Unclosed bracket raises ValueError.""" + with pytest.raises(ValueError, match="unclosed bracket"): + parse_path("config[key") + + def test_path_at_max_depth_succeeds(self): + """Path at exactly max depth succeeds.""" + path = ".".join(["a"] * _MAX_PATH_DEPTH) + parts = parse_path(path) + assert len(parts) == _MAX_PATH_DEPTH + + def test_path_exceeds_max_depth_raises(self): + """Path exceeding max depth raises ValueError.""" + path = ".".join(["a"] * (_MAX_PATH_DEPTH + 1)) + with pytest.raises(ValueError, match="exceeds maximum depth"): + parse_path(path) + + def test_very_deep_path_raises(self): + """Very deep path raises with helpful error message.""" + path = ".".join(["x"] * 100) + with pytest.raises(ValueError) as exc_info: + parse_path(path) + assert "100 segments" in str(exc_info.value) + + def test_empty_segments_ignored(self): + """Empty segments from consecutive dots are ignored.""" + # Two consecutive dots should not create empty segment + assert parse_path("foo..bar") == ["foo", "bar"] + + def test_bracket_at_end(self): + """Bracket notation at end of path.""" + assert parse_path('config.database["host"]') == ["config", "database", "host"] + + +class TestGetNestedValue: + """Tests for get_nested_value() function.""" + + def test_simple_dict_access(self): + """Access simple dict key.""" + obj = {"foo": "bar"} + assert get_nested_value(obj, "foo") == "bar" + + def test_nested_dict_access(self): + """Access nested dict.""" + obj = {"database": {"host": "localhost", "port": 5432}} + assert get_nested_value(obj, "database.host") == "localhost" + + def test_list_index_access(self): + """Access list by index.""" + obj = {"items": ["a", "b", "c"]} + assert get_nested_value(obj, "items.1") == "b" + + def test_nested_list_access(self): + """Access nested list.""" + obj = {"matrix": [[1, 2], [3, 4]]} + assert get_nested_value(obj, "matrix.0.1") == 2 + + def test_empty_path_returns_object(self): + """Empty path returns the object itself.""" + obj = {"foo": "bar"} + assert get_nested_value(obj, "") == obj + + def test_missing_key_raises(self): + """Missing key raises KeyError.""" + obj = {"foo": "bar"} + with pytest.raises(KeyError, match="not found"): + get_nested_value(obj, "missing") + + def test_missing_nested_key_raises(self): + """Missing nested key raises KeyError.""" + obj = {"foo": {"bar": "baz"}} + with pytest.raises(KeyError, match="not found"): + get_nested_value(obj, "foo.missing") + + def test_list_index_out_of_bounds_raises(self): + """List index out of bounds raises KeyError.""" + obj = {"items": ["a", "b"]} + with pytest.raises(KeyError, match="not found"): + get_nested_value(obj, "items.5") + + def test_invalid_list_index_raises(self): + """Non-numeric list index raises KeyError.""" + obj = {"items": ["a", "b"]} + with pytest.raises(KeyError, match="not found"): + get_nested_value(obj, "items.foo") + + def test_traverse_non_container_raises(self): + """Traversing through a non-dict/list raises KeyError.""" + obj = {"foo": "bar"} + with pytest.raises(KeyError, match="not found"): + get_nested_value(obj, "foo.baz") + + def test_bracket_notation_with_dots(self): + """Access key containing dots via bracket notation.""" + obj = {"config": {"my.key.with.dots": "value"}} + assert get_nested_value(obj, 'config["my.key.with.dots"]') == "value" + + +class TestSetNestedValue: + """Tests for set_nested_value() function.""" + + def test_set_simple_key(self): + """Set simple key.""" + obj: dict = {} + set_nested_value(obj, "foo", "bar") + assert obj == {"foo": "bar"} + + def test_set_nested_key(self): + """Set nested key.""" + obj = {"database": {}} + set_nested_value(obj, "database.host", "localhost") + assert obj == {"database": {"host": "localhost"}} + + def test_create_intermediate_dicts(self): + """Creates intermediate dicts as needed.""" + obj: dict = {} + set_nested_value(obj, "a.b.c", "value") + assert obj == {"a": {"b": {"c": "value"}}} + + def test_overwrite_existing_value(self): + """Overwrite existing value.""" + obj = {"foo": "old"} + set_nested_value(obj, "foo", "new") + assert obj == {"foo": "new"} + + def test_returns_same_object(self): + """Returns the same dict object (mutated in place).""" + obj = {"foo": "bar"} + result = set_nested_value(obj, "baz", "qux") + assert result is obj + + def test_intermediate_non_dict_raises(self): + """Setting through non-dict intermediate raises ValueError.""" + obj = {"foo": "bar"} + with pytest.raises(ValueError, match="intermediate value is not a dict"): + set_nested_value(obj, "foo.baz", "value") + + def test_bracket_notation_with_dots(self): + """Set key containing dots via bracket notation.""" + obj = {"config": {}} + set_nested_value(obj, 'config["my.key"]', "value") + assert obj == {"config": {"my.key": "value"}} + + +# ============================================================================= +# Key Validation Tests +# ============================================================================= + + +class TestValidateKey: + """Tests for validate_key() function. + + Validates that key names are safe for storage and retrieval. + Protects against accidental, malicious, and ignorant clients. + """ + + # --- Valid keys --- + + def test_simple_key(self): + """Simple alphanumeric key is valid.""" + assert validate_key("my_key") == "my_key" + + def test_key_with_dots(self): + """Key with dots is valid (dots are only special in paths).""" + assert validate_key("config.json") == "config.json" + + def test_key_with_hyphens(self): + """Key with hyphens is valid.""" + assert validate_key("my-key-name") == "my-key-name" + + def test_key_with_spaces(self): + """Key with internal spaces is valid.""" + assert validate_key("my key") == "my key" + + def test_unicode_key(self): + """Unicode characters in keys are valid.""" + assert validate_key("日本語キー") == "日本語キー" + assert validate_key("emoji_🔑") == "emoji_🔑" + + def test_max_length_key(self): + """Key at exactly max length is valid.""" + key = "a" * _MAX_KEY_LENGTH + assert validate_key(key) == key + + def test_numeric_key(self): + """Numeric string key is valid.""" + assert validate_key("12345") == "12345" + + # --- Invalid keys: Empty/Whitespace --- + + def test_empty_key_rejected(self): + """Empty string key is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("") + assert exc_info.value.status_code == 400 + assert "cannot be empty" in exc_info.value.detail + + def test_whitespace_only_key_rejected(self): + """Whitespace-only key is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key(" ") + assert exc_info.value.status_code == 400 + assert "whitespace-only" in exc_info.value.detail + + def test_tabs_only_key_rejected(self): + """Tab-only key is rejected as whitespace.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("\t\t") + assert exc_info.value.status_code == 400 + # Tabs are control characters, so this might fail on control char check first + # Either error message is acceptable + + # --- Invalid keys: Too long --- + + def test_key_exceeds_max_length_rejected(self): + """Key exceeding max length is rejected.""" + key = "a" * (_MAX_KEY_LENGTH + 1) + with pytest.raises(HTTPException) as exc_info: + validate_key(key) + assert exc_info.value.status_code == 400 + assert "exceeds" in exc_info.value.detail + assert str(_MAX_KEY_LENGTH) in exc_info.value.detail + + def test_very_long_key_rejected(self): + """Very long key is rejected with helpful error.""" + key = "x" * 1000 + with pytest.raises(HTTPException) as exc_info: + validate_key(key) + assert exc_info.value.status_code == 400 + assert "1000 given" in exc_info.value.detail + + # --- Invalid keys: Control characters --- + + def test_null_byte_in_key_rejected(self): + """Key containing null byte is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("key\x00value") + assert exc_info.value.status_code == 400 + assert "control character" in exc_info.value.detail + assert "\\x00" in exc_info.value.detail + + def test_newline_in_key_rejected(self): + """Key containing newline is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("key\nvalue") + assert exc_info.value.status_code == 400 + assert "control character" in exc_info.value.detail + assert "\\x0a" in exc_info.value.detail + + def test_carriage_return_in_key_rejected(self): + """Key containing carriage return is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("key\rvalue") + assert exc_info.value.status_code == 400 + assert "control character" in exc_info.value.detail + + def test_tab_in_key_rejected(self): + """Key containing tab is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("key\tvalue") + assert exc_info.value.status_code == 400 + assert "control character" in exc_info.value.detail + assert "\\x09" in exc_info.value.detail + + def test_delete_char_in_key_rejected(self): + """Key containing DEL character (0x7F) is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("key\x7fvalue") + assert exc_info.value.status_code == 400 + assert "control character" in exc_info.value.detail + assert "\\x7f" in exc_info.value.detail + + def test_bell_char_in_key_rejected(self): + """Key containing bell character is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("key\x07value") + assert exc_info.value.status_code == 400 + assert "control character" in exc_info.value.detail + + def test_control_char_position_reported(self): + """Error message includes position of control character.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("abc\x00def") + assert "position 3" in exc_info.value.detail + + # --- Edge cases that SHOULD be allowed --- + + def test_leading_space_allowed(self): + """Leading space is allowed (not whitespace-only).""" + # This might be surprising, but " key" has content + assert validate_key(" key") == " key" + + def test_trailing_space_allowed(self): + """Trailing space is allowed.""" + assert validate_key("key ") == "key " + + def test_path_traversal_string_allowed(self): + """Path traversal strings are allowed (no filesystem, just strings).""" + # These look suspicious but are harmless as KV keys + assert validate_key("../../../etc/passwd") == "../../../etc/passwd" + assert validate_key("..") == ".." + + def test_javascript_prototype_names_allowed(self): + """JavaScript prototype pollution names are allowed.""" + # These are valid key names, just be careful in JS clients + assert validate_key("__proto__") == "__proto__" + assert validate_key("constructor") == "constructor" + assert validate_key("toString") == "toString" + + def test_slashes_allowed(self): + """Slashes are allowed in keys.""" + assert validate_key("path/to/key") == "path/to/key" + + # --- Invalid keys: Reserved prefix --- + + def test_dollar_prefix_rejected(self): + """Key starting with $ is rejected (reserved for system use).""" + with pytest.raises(HTTPException) as exc_info: + validate_key("$version") + assert exc_info.value.status_code == 400 + assert "reserved" in exc_info.value.detail.lower() + + def test_dollar_prefix_any_name_rejected(self): + """Any key starting with $ is rejected.""" + with pytest.raises(HTTPException) as exc_info: + validate_key("$anything") + assert exc_info.value.status_code == 400 + assert "reserved" in exc_info.value.detail.lower() + + def test_dollar_in_middle_allowed(self): + """Dollar sign in middle of key is allowed.""" + assert validate_key("my$key") == "my$key" + + def test_dollar_at_end_allowed(self): + """Dollar sign at end of key is allowed.""" + assert validate_key("key$") == "key$" + + +# ============================================================================= +# Type Validation Tests +# ============================================================================= + + +class TestRequireNumeric: + """Tests for require_numeric() function. + + Protects against type confusion, especially the Python quirk + where bool is a subclass of int. + """ + + # --- Valid numeric values --- + + def test_integer_accepted(self): + """Integer values are accepted.""" + assert require_numeric(42) == 42 + assert require_numeric(0) == 0 + assert require_numeric(-1) == -1 + + def test_float_accepted(self): + """Float values are accepted.""" + assert require_numeric(3.14) == 3.14 + assert require_numeric(0.0) == 0.0 + assert require_numeric(-1.5) == -1.5 + + def test_large_integer_accepted(self): + """Large integers are accepted.""" + big = 10**100 + assert require_numeric(big) == big + + def test_scientific_notation_accepted(self): + """Scientific notation floats are accepted.""" + assert require_numeric(1e10) == 1e10 + assert require_numeric(1e-10) == 1e-10 + + # --- Boolean rejection (critical!) --- + + def test_true_rejected(self): + """Boolean True is rejected even though bool is subclass of int.""" + with pytest.raises(HTTPException) as exc_info: + require_numeric(True) + assert exc_info.value.status_code == 400 + assert "boolean" in exc_info.value.detail + + def test_false_rejected(self): + """Boolean False is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_numeric(False) + assert exc_info.value.status_code == 400 + assert "boolean" in exc_info.value.detail + + # --- Other non-numeric types --- + + def test_string_rejected(self): + """String values are rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_numeric("42") + assert exc_info.value.status_code == 400 + assert "not numeric" in exc_info.value.detail + + def test_numeric_string_rejected(self): + """Numeric-looking strings are rejected (no coercion).""" + with pytest.raises(HTTPException) as exc_info: + require_numeric("3.14") + assert exc_info.value.status_code == 400 + + def test_none_rejected(self): + """None is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_numeric(None) + assert exc_info.value.status_code == 400 + + def test_list_rejected(self): + """List is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_numeric([1, 2, 3]) + assert exc_info.value.status_code == 400 + + def test_dict_rejected(self): + """Dict is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_numeric({"value": 42}) + assert exc_info.value.status_code == 400 + + +class TestRequireInt: + """Tests for require_int() function. + + Stricter than require_numeric - used for incr/decr operations + where float arithmetic could cause precision loss. + """ + + # --- Valid integer values --- + + def test_integer_accepted(self): + """Integer values are accepted.""" + assert require_int(42) == 42 + assert require_int(0) == 0 + assert require_int(-1) == -1 + + def test_large_integer_accepted(self): + """Large integers are accepted.""" + big = 10**100 + assert require_int(big) == big + + # --- Float rejection (critical for incr/decr!) --- + + def test_float_rejected(self): + """Float values are rejected with helpful message.""" + with pytest.raises(HTTPException) as exc_info: + require_int(3.14) + assert exc_info.value.status_code == 400 + assert "float" in exc_info.value.detail + assert "integer" in exc_info.value.detail + + def test_whole_number_float_rejected(self): + """Even whole-number floats like 1.0 are rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_int(1.0) + assert exc_info.value.status_code == 400 + assert "float" in exc_info.value.detail + + def test_zero_float_rejected(self): + """0.0 is rejected (use 0 instead).""" + with pytest.raises(HTTPException) as exc_info: + require_int(0.0) + assert exc_info.value.status_code == 400 + + # --- Boolean rejection --- + + def test_true_rejected(self): + """Boolean True is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_int(True) + assert exc_info.value.status_code == 400 + assert "boolean" in exc_info.value.detail + + def test_false_rejected(self): + """Boolean False is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_int(False) + assert exc_info.value.status_code == 400 + assert "boolean" in exc_info.value.detail + + # --- Other types --- + + def test_string_rejected(self): + """String values are rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_int("42") + assert exc_info.value.status_code == 400 + assert "not an integer" in exc_info.value.detail + + def test_none_rejected(self): + """None is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_int(None) + assert exc_info.value.status_code == 400 + + +class TestRequireList: + """Tests for require_list() function.""" + + def test_empty_list_accepted(self): + """Empty list is accepted.""" + assert require_list([]) == [] + + def test_list_with_items_accepted(self): + """List with items is accepted.""" + assert require_list([1, 2, 3]) == [1, 2, 3] + + def test_nested_list_accepted(self): + """Nested list is accepted.""" + assert require_list([[1], [2]]) == [[1], [2]] + + def test_dict_rejected(self): + """Dict is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_list({}) + assert exc_info.value.status_code == 400 + assert "not a list" in exc_info.value.detail + + def test_string_rejected(self): + """String is rejected (even though iterable).""" + with pytest.raises(HTTPException) as exc_info: + require_list("hello") + assert exc_info.value.status_code == 400 + + def test_tuple_rejected(self): + """Tuple is rejected (we want explicit list type).""" + with pytest.raises(HTTPException) as exc_info: + require_list((1, 2, 3)) + assert exc_info.value.status_code == 400 + + +class TestRequireDict: + """Tests for require_dict() function.""" + + def test_empty_dict_accepted(self): + """Empty dict is accepted.""" + assert require_dict({}) == {} + + def test_dict_with_items_accepted(self): + """Dict with items is accepted.""" + assert require_dict({"key": "value"}) == {"key": "value"} + + def test_list_rejected(self): + """List is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_dict([]) + assert exc_info.value.status_code == 400 + assert "not an object" in exc_info.value.detail + + def test_string_rejected(self): + """String is rejected.""" + with pytest.raises(HTTPException) as exc_info: + require_dict("hello") + assert exc_info.value.status_code == 400 diff --git a/tests/test_kv_router.py b/tests/test_kv_router.py new file mode 100644 index 0000000..250f962 --- /dev/null +++ b/tests/test_kv_router.py @@ -0,0 +1,832 @@ +"""Tests for KV store API endpoints. + +Testing Strategy +================ + +This module uses two different test client fixtures depending on the test type: + +1. `kv_client` - For most tests (single-request tests) + - Overrides `get_session` to use a SHARED async_session + - All requests go through the same database session/connection + - Simpler setup, good for testing individual endpoint behavior + - ⚠️ NOT suitable for concurrent request tests (causes deadlocks) + +2. `concurrent_kv_client` - For concurrency tests ONLY + - Does NOT override `get_session` + - Each request gets its own session from the session factory + - Enables true concurrent database operations with separate connections + - Required for testing FOR UPDATE locking behavior + +Single-Document Backend +----------------------- +The KV store uses a single-document backend where each automation has exactly +ONE row containing all its state as an encrypted JSON document. API "keys" +are top-level fields in that document. + +This design eliminates deadlock risk: all operations serialize through a single +row lock per automation. +""" + +import uuid +from typing import Any + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.automation.app import app +from openhands.automation.db import get_session +from openhands.automation.kv_router import get_token_claims +from openhands.automation.models import Automation, AutomationKV +from openhands.automation.utils.kv import ( + KVTokenClaims, + decrypt_value, + encrypt_value, +) + + +# Test UUIDs +TEST_USER_ID = uuid.UUID("12345678-1234-5678-1234-567812345678") +TEST_ORG_ID = uuid.UUID("87654321-4321-8765-4321-876543218765") +TEST_AUTOMATION_ID = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") +TEST_RUN_ID = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") + +# Test secret for JWT and encryption +TEST_KV_SECRET = "test-kv-secret-key-for-testing-only" + + +# ============================================================================= +# Test Data Helpers (Single-Document Backend) +# ============================================================================= + + +async def create_test_state( + session: AsyncSession, + automation_id: uuid.UUID, + state: dict[str, Any], + secret: str = TEST_KV_SECRET, +) -> AutomationKV: + """Create a state row for an automation with the given state dict. + + In the single-document model, each automation has ONE row containing + all keys as top-level fields in the encrypted JSON document. + """ + encrypted = encrypt_value(secret, state) + row = AutomationKV( + automation_id=automation_id, + state_encrypted=encrypted, + ) + session.add(row) + await session.flush() + return row + + +async def get_test_state( + session: AsyncSession, + automation_id: uuid.UUID, + secret: str = TEST_KV_SECRET, +) -> dict[str, Any] | None: + """Get the current state dict for an automation. + + Returns None if no state row exists. + """ + result = await session.execute( + select(AutomationKV).where(AutomationKV.automation_id == automation_id) + ) + row = result.scalars().first() + if row is None: + return None + return decrypt_value(secret, row.state_encrypted) + + +# ============================================================================= +# Test Client Fixtures +# ============================================================================= + + +@pytest.fixture +async def kv_client(async_engine, async_session_factory, async_session, monkeypatch): + """Create an async test client with KV token auth (shared session).""" + monkeypatch.setenv("AUTOMATION_KV_SECRET", TEST_KV_SECRET) + + from openhands.automation.config import clear_config_cache + + clear_config_cache() + + async def override_get_session(): + yield async_session + + async def override_get_token_claims(): + return KVTokenClaims(automation_id=TEST_AUTOMATION_ID) + + app.dependency_overrides[get_session] = override_get_session + app.dependency_overrides[get_token_claims] = override_get_token_claims + + app.state.engine = async_engine + app.state.session_factory = async_session_factory + + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="http://test", + ) as client: + yield client + + app.dependency_overrides.clear() + clear_config_cache() + + +@pytest.fixture(autouse=True) +async def automation_with_kv(async_session): + """Create a test automation (KV store is always available).""" + automation = Automation( + id=TEST_AUTOMATION_ID, + user_id=TEST_USER_ID, + org_id=TEST_ORG_ID, + name="Test Automation with KV", + trigger={"type": "cron", "schedule": "0 9 * * *", "timezone": "UTC"}, + tarball_path="s3://bucket/code.tar.gz", + entrypoint="uv run script.py", + ) + async_session.add(automation) + await async_session.commit() + return automation + + +# ============================================================================= +# Token and Encryption Tests +# ============================================================================= + + +class TestKVTokenAuth: + """Tests for KV token authentication.""" + + def test_create_and_verify_token(self): + """Token can be created and verified.""" + from openhands.automation.utils.kv import create_kv_token, verify_kv_token + + token = create_kv_token( + secret=TEST_KV_SECRET, + automation_id=TEST_AUTOMATION_ID, + run_id=TEST_RUN_ID, + ) + + result = verify_kv_token(TEST_KV_SECRET, token) + assert result.automation_id == TEST_AUTOMATION_ID + + def test_invalid_token_raises_error(self): + """Invalid token raises KVTokenError.""" + from openhands.automation.utils.kv import KVTokenError, verify_kv_token + + with pytest.raises(KVTokenError): + verify_kv_token(TEST_KV_SECRET, "invalid-token") + + +class TestKVEncryption: + """Tests for KV value encryption.""" + + def test_encrypt_decrypt_dict(self): + """Dict values can be encrypted and decrypted.""" + original = {"key": "value", "nested": {"a": 1}} + encrypted = encrypt_value(TEST_KV_SECRET, original) + decrypted = decrypt_value(TEST_KV_SECRET, encrypted) + assert decrypted == original + + def test_encrypt_decrypt_list(self): + """List values can be encrypted and decrypted.""" + original = [1, 2, {"key": "value"}] + encrypted = encrypt_value(TEST_KV_SECRET, original) + decrypted = decrypt_value(TEST_KV_SECRET, encrypted) + assert decrypted == original + + +# ============================================================================= +# API Endpoint Tests +# ============================================================================= + + +class TestListKeys: + """Tests for GET /kv endpoint.""" + + async def test_list_keys_empty(self, kv_client): + """List keys returns empty when no keys exist.""" + response = await kv_client.get("/api/automation/v1/kv") + + assert response.status_code == 200 + data = response.json() + assert data["keys"] == [] + assert data["count"] == 0 + + async def test_list_keys_with_data(self, kv_client, async_session): + """List keys returns all keys for the automation.""" + await create_test_state( + async_session, + TEST_AUTOMATION_ID, + {"config": {"test": True}, "counter": 42, "queue": []}, + ) + await async_session.commit() + + response = await kv_client.get("/api/automation/v1/kv") + + assert response.status_code == 200 + data = response.json() + assert set(data["keys"]) == {"config", "counter", "queue"} + assert data["count"] == 3 + + +class TestGetValue: + """Tests for GET /kv/{key} endpoint.""" + + async def test_get_value_not_found(self, kv_client): + """Get non-existent key returns 404.""" + response = await kv_client.get("/api/automation/v1/kv/nonexistent") + + assert response.status_code == 404 + assert response.json()["detail"] == "key_not_found" + + async def test_get_value_success(self, kv_client, async_session): + """Get existing key returns value.""" + value = {"database": {"host": "localhost", "port": 5432}} + await create_test_state(async_session, TEST_AUTOMATION_ID, {"config": value}) + await async_session.commit() + + response = await kv_client.get("/api/automation/v1/kv/config") + + assert response.status_code == 200 + data = response.json() + assert data["key"] == "config" + assert data["value"] == value + + async def test_get_value_with_path(self, kv_client, async_session): + """Get nested path returns specific value.""" + value = {"database": {"host": "localhost", "port": 5432}} + await create_test_state(async_session, TEST_AUTOMATION_ID, {"config": value}) + await async_session.commit() + + response = await kv_client.get( + "/api/automation/v1/kv/config?path=database.host" + ) + + assert response.status_code == 200 + data = response.json() + assert data["key"] == "config" + assert data["path"] == "database.host" + assert data["value"] == "localhost" + + async def test_get_value_with_meta(self, kv_client, async_session): + """Get with meta=true returns timestamps.""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"config": "test"}) + await async_session.commit() + + response = await kv_client.get("/api/automation/v1/kv/config?meta=true") + + assert response.status_code == 200 + data = response.json() + assert "created_at" in data + assert "updated_at" in data + + +class TestSetValue: + """Tests for PUT /kv/{key} endpoint.""" + + async def test_set_new_value(self, kv_client): + """Set creates new key (returns 201 Created).""" + response = await kv_client.put( + "/api/automation/v1/kv/config", + json={"setting": "value"}, + ) + + assert response.status_code == 201 + data = response.json() + assert data["key"] == "config" + assert data["value"] == {"setting": "value"} + assert data["created"] is True + + async def test_set_updates_existing_value(self, kv_client, async_session): + """Set updates existing key (returns 200 OK).""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"config": "old"}) + await async_session.commit() + + response = await kv_client.put( + "/api/automation/v1/kv/config", + json="new", + ) + + assert response.status_code == 200 + data = response.json() + assert data["value"] == "new" + assert data["created"] is False + + async def test_set_with_nx_creates_new(self, kv_client): + """Set with nx=true creates new key.""" + response = await kv_client.put( + "/api/automation/v1/kv/lock?nx=true", + json={"owner": "run-123"}, + ) + + assert response.status_code == 201 + assert response.json()["created"] is True + + async def test_set_with_nx_fails_if_exists(self, kv_client, async_session): + """Set with nx=true fails if key exists.""" + await create_test_state( + async_session, TEST_AUTOMATION_ID, {"lock": {"owner": "other"}} + ) + await async_session.commit() + + response = await kv_client.put( + "/api/automation/v1/kv/lock?nx=true", + json={"owner": "run-123"}, + ) + + assert response.status_code == 409 + assert response.json()["error"] == "key_exists" + + async def test_set_with_xx_updates_existing(self, kv_client, async_session): + """Set with xx=true updates existing key.""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"config": "old"}) + await async_session.commit() + + response = await kv_client.put( + "/api/automation/v1/kv/config?xx=true", + json="new", + ) + + assert response.status_code == 200 + assert response.json()["value"] == "new" + + async def test_set_with_xx_fails_if_not_exists(self, kv_client): + """Set with xx=true fails if key doesn't exist.""" + response = await kv_client.put( + "/api/automation/v1/kv/nonexistent?xx=true", + json="value", + ) + + assert response.status_code == 409 + + +class TestPatchValue: + """Tests for PATCH /kv/{key} endpoint.""" + + async def test_patch_updates_nested_path(self, kv_client, async_session): + """Patch updates a nested path.""" + await create_test_state( + async_session, + TEST_AUTOMATION_ID, + {"config": {"database": {"host": "old", "port": 5432}}}, + ) + await async_session.commit() + + response = await kv_client.patch( + "/api/automation/v1/kv/config", + json={"path": "database.host", "value": "new"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["path"] == "database.host" + assert data["value"] == "new" + + async def test_patch_not_found(self, kv_client): + """Patch returns 404 for non-existent key.""" + response = await kv_client.patch( + "/api/automation/v1/kv/nonexistent", + json={"path": "some.path", "value": "value"}, + ) + + assert response.status_code == 404 + + +class TestDeleteKey: + """Tests for DELETE /kv/{key} endpoint.""" + + async def test_delete_existing_key(self, kv_client, async_session): + """Delete removes existing key.""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"config": "test"}) + await async_session.commit() + + response = await kv_client.delete("/api/automation/v1/kv/config") + + assert response.status_code == 200 + assert response.json()["deleted"] is True + + async def test_delete_nonexistent_key(self, kv_client): + """Delete returns deleted=false for non-existent key.""" + response = await kv_client.delete("/api/automation/v1/kv/nonexistent") + + assert response.status_code == 200 + assert response.json()["deleted"] is False + + +class TestIncrement: + """Tests for POST /kv/{key}/incr endpoint.""" + + async def test_incr_creates_key(self, kv_client, async_session): + """Incr creates key with value if it doesn't exist.""" + response = await kv_client.post("/api/automation/v1/kv/counter/incr") + + assert response.status_code == 200 + assert response.json()["value"] == 1 + + async def test_incr_increments_existing(self, kv_client, async_session): + """Incr increments existing integer value.""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"counter": 5}) + await async_session.commit() + + response = await kv_client.post("/api/automation/v1/kv/counter/incr") + + assert response.status_code == 200 + assert response.json()["value"] == 6 + + async def test_incr_by_custom_amount(self, kv_client, async_session): + """Incr with by parameter increments by that amount.""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"counter": 10}) + await async_session.commit() + + response = await kv_client.post( + "/api/automation/v1/kv/counter/incr", + json={"by": 5}, + ) + + assert response.status_code == 200 + assert response.json()["value"] == 15 + + async def test_incr_rejects_non_integer(self, kv_client, async_session): + """Incr rejects non-integer values.""" + await create_test_state( + async_session, TEST_AUTOMATION_ID, {"counter": {"not": "numeric"}} + ) + await async_session.commit() + + response = await kv_client.post("/api/automation/v1/kv/counter/incr") + + assert response.status_code == 400 + assert "type_mismatch" in response.json()["detail"] + + +class TestDecrement: + """Tests for POST /kv/{key}/decr endpoint.""" + + async def test_decr_creates_key(self, kv_client): + """Decr creates key with negative value if it doesn't exist.""" + response = await kv_client.post("/api/automation/v1/kv/counter/decr") + + assert response.status_code == 200 + assert response.json()["value"] == -1 + + async def test_decr_decrements_existing(self, kv_client, async_session): + """Decr decrements existing integer value.""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"counter": 5}) + await async_session.commit() + + response = await kv_client.post("/api/automation/v1/kv/counter/decr") + + assert response.status_code == 200 + assert response.json()["value"] == 4 + + +class TestListOperations: + """Tests for list push/pop operations.""" + + async def test_lpush_creates_list(self, kv_client): + """LPUSH creates list if it doesn't exist.""" + response = await kv_client.post( + "/api/automation/v1/kv/queue/lpush", + json={"value": "first"}, + ) + + assert response.status_code == 200 + assert response.json()["length"] == 1 + + async def test_rpush_creates_list(self, kv_client): + """RPUSH creates list if it doesn't exist.""" + response = await kv_client.post( + "/api/automation/v1/kv/queue/rpush", + json={"value": "first"}, + ) + + assert response.status_code == 200 + assert response.json()["length"] == 1 + + async def test_lpush_prepends(self, kv_client, async_session): + """LPUSH prepends to existing list.""" + await create_test_state( + async_session, TEST_AUTOMATION_ID, {"queue": ["second"]} + ) + await async_session.commit() + + response = await kv_client.post( + "/api/automation/v1/kv/queue/lpush", + json={"value": "first"}, + ) + + assert response.status_code == 200 + assert response.json()["length"] == 2 + + # Verify order + state = await get_test_state(async_session, TEST_AUTOMATION_ID) + assert state is not None + assert state["queue"] == ["first", "second"] + + async def test_rpush_appends(self, kv_client, async_session): + """RPUSH appends to existing list.""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"queue": ["first"]}) + await async_session.commit() + + response = await kv_client.post( + "/api/automation/v1/kv/queue/rpush", + json={"value": "second"}, + ) + + assert response.status_code == 200 + assert response.json()["length"] == 2 + + # Verify order + state = await get_test_state(async_session, TEST_AUTOMATION_ID) + assert state is not None + assert state["queue"] == ["first", "second"] + + async def test_lpop_returns_first(self, kv_client, async_session): + """LPOP returns and removes first element.""" + await create_test_state( + async_session, TEST_AUTOMATION_ID, {"queue": ["first", "second", "third"]} + ) + await async_session.commit() + + response = await kv_client.post("/api/automation/v1/kv/queue/lpop") + + assert response.status_code == 200 + assert response.json()["value"] == "first" + + # Verify remaining + state = await get_test_state(async_session, TEST_AUTOMATION_ID) + assert state is not None + assert state["queue"] == ["second", "third"] + + async def test_rpop_returns_last(self, kv_client, async_session): + """RPOP returns and removes last element.""" + await create_test_state( + async_session, TEST_AUTOMATION_ID, {"queue": ["first", "second", "third"]} + ) + await async_session.commit() + + response = await kv_client.post("/api/automation/v1/kv/queue/rpop") + + assert response.status_code == 200 + assert response.json()["value"] == "third" + + # Verify remaining + state = await get_test_state(async_session, TEST_AUTOMATION_ID) + assert state is not None + assert state["queue"] == ["first", "second"] + + async def test_lpop_empty_returns_null(self, kv_client, async_session): + """LPOP on empty list returns null.""" + await create_test_state(async_session, TEST_AUTOMATION_ID, {"queue": []}) + await async_session.commit() + + response = await kv_client.post("/api/automation/v1/kv/queue/lpop") + + assert response.status_code == 200 + assert response.json()["value"] is None + + async def test_rpop_nonexistent_returns_null(self, kv_client): + """RPOP on non-existent key returns null.""" + response = await kv_client.post("/api/automation/v1/kv/nonexistent/rpop") + + assert response.status_code == 200 + assert response.json()["value"] is None + + async def test_len_returns_length(self, kv_client, async_session): + """LEN returns list length.""" + await create_test_state( + async_session, TEST_AUTOMATION_ID, {"queue": [1, 2, 3, 4, 5]} + ) + await async_session.commit() + + response = await kv_client.get("/api/automation/v1/kv/queue/len") + + assert response.status_code == 200 + assert response.json()["length"] == 5 + + async def test_len_not_found(self, kv_client): + """LEN returns 404 for non-existent key.""" + response = await kv_client.get("/api/automation/v1/kv/nonexistent/len") + + assert response.status_code == 404 + + async def test_lpush_rejects_non_list(self, kv_client, async_session): + """LPUSH rejects non-list values.""" + await create_test_state( + async_session, TEST_AUTOMATION_ID, {"notlist": {"not": "a list"}} + ) + await async_session.commit() + + response = await kv_client.post( + "/api/automation/v1/kv/notlist/lpush", + json={"value": "item"}, + ) + + assert response.status_code == 400 + assert "type_mismatch" in response.json()["detail"] + + +class TestKeyValidation: + """Tests for key name validation.""" + + async def test_empty_key_rejected(self, kv_client): + """Empty key is rejected.""" + # FastAPI will return 404 for empty path segment, not 400 + response = await kv_client.get("/api/automation/v1/kv/") + assert response.status_code in (404, 307) # Redirect or not found + + async def test_very_long_key_rejected(self, kv_client): + """Key exceeding 255 chars is rejected.""" + long_key = "x" * 300 + response = await kv_client.get(f"/api/automation/v1/kv/{long_key}") + + assert response.status_code == 400 + assert "invalid_key" in response.json()["detail"] + + +class TestSingleDocumentIsolation: + """Tests verifying single-document design properties.""" + + async def test_multiple_keys_in_one_doc(self, kv_client, async_session): + """Multiple keys are stored in a single document.""" + # Create first key + await kv_client.put("/api/automation/v1/kv/key1", json="value1") + # Create second key + await kv_client.put("/api/automation/v1/kv/key2", json="value2") + # Create third key + await kv_client.put("/api/automation/v1/kv/key3", json="value3") + + # Verify all keys are in one state document + state = await get_test_state(async_session, TEST_AUTOMATION_ID) + assert state is not None + # Filter out system keys ($version) for comparison + user_keys = {k: v for k, v in state.items() if not k.startswith("$")} + assert user_keys == {"key1": "value1", "key2": "value2", "key3": "value3"} + # $version should be present and incremented (3 writes) + assert state.get("$version") == 3 + + # Verify only ONE row exists in the database + result = await async_session.execute( + select(AutomationKV).where(AutomationKV.automation_id == TEST_AUTOMATION_ID) + ) + rows = result.scalars().all() + assert len(rows) == 1 + + async def test_delete_last_key_removes_row(self, kv_client, async_session): + """Deleting the last user key removes the state row entirely.""" + # Create a key + await kv_client.put("/api/automation/v1/kv/onlykey", json="value") + + # Delete it + await kv_client.delete("/api/automation/v1/kv/onlykey") + + # Verify row is gone (no user keys remain, so row is deleted) + result = await async_session.execute( + select(AutomationKV).where(AutomationKV.automation_id == TEST_AUTOMATION_ID) + ) + rows = result.scalars().all() + assert len(rows) == 0 + + async def test_operations_preserve_other_keys(self, kv_client, async_session): + """Operations on one key don't affect other keys.""" + # Create state with multiple keys + await create_test_state( + async_session, + TEST_AUTOMATION_ID, + {"counter": 10, "config": {"setting": True}, "queue": ["item"]}, + ) + await async_session.commit() + + # Increment counter + await kv_client.post("/api/automation/v1/kv/counter/incr") + + # Verify other keys are unchanged + state = await get_test_state(async_session, TEST_AUTOMATION_ID) + assert state is not None + assert state["counter"] == 11 + assert state["config"] == {"setting": True} + assert state["queue"] == ["item"] + + +# ============================================================================= +# Tests for if_version on individual endpoints +# ============================================================================= + + +class TestIfVersionOnIndividualEndpoints: + """Test if_version query parameter for optimistic concurrency.""" + + async def test_set_with_matching_version_succeeds(self, kv_client, async_session): + """PUT with matching if_version succeeds.""" + # Create initial key (version becomes 1) + resp = await kv_client.put("/api/automation/v1/kv/foo", json="bar") + assert resp.status_code == 201 + + # Update with correct version + resp = await kv_client.put("/api/automation/v1/kv/foo?if_version=1", json="baz") + assert resp.status_code == 200 + assert resp.json()["value"] == "baz" + + async def test_set_with_mismatched_version_fails(self, kv_client, async_session): + """PUT with wrong if_version returns 409.""" + # Create initial key (version becomes 1) + resp = await kv_client.put("/api/automation/v1/kv/foo", json="bar") + assert resp.status_code == 201 + + # Try to update with wrong version + resp = await kv_client.put( + "/api/automation/v1/kv/foo?if_version=99", json="baz" + ) + assert resp.status_code == 409 + data = resp.json()["detail"] + assert data["error"] == "version_mismatch" + assert data["expected_version"] == 99 + assert data["actual_version"] == 1 + + async def test_patch_with_matching_version_succeeds(self, kv_client, async_session): + """PATCH with matching if_version succeeds.""" + # Create initial key with dict value (version becomes 1) + resp = await kv_client.put( + "/api/automation/v1/kv/config", json={"host": "localhost"} + ) + assert resp.status_code == 201 + + # Patch with correct version + resp = await kv_client.patch( + "/api/automation/v1/kv/config?if_version=1", + json={"path": "port", "value": 5432}, + ) + assert resp.status_code == 200 + + async def test_patch_with_mismatched_version_fails(self, kv_client, async_session): + """PATCH with wrong if_version returns 409.""" + # Create initial key (version becomes 1) + resp = await kv_client.put( + "/api/automation/v1/kv/config", json={"host": "localhost"} + ) + assert resp.status_code == 201 + + # Try to patch with wrong version + resp = await kv_client.patch( + "/api/automation/v1/kv/config?if_version=99", + json={"path": "port", "value": 5432}, + ) + assert resp.status_code == 409 + data = resp.json()["detail"] + assert data["error"] == "version_mismatch" + + async def test_delete_with_matching_version_succeeds( + self, kv_client, async_session + ): + """DELETE with matching if_version succeeds.""" + # Create initial key (version becomes 1) + resp = await kv_client.put("/api/automation/v1/kv/foo", json="bar") + assert resp.status_code == 201 + + # Delete with correct version + resp = await kv_client.delete("/api/automation/v1/kv/foo?if_version=1") + assert resp.status_code == 200 + assert resp.json()["deleted"] is True + + async def test_delete_with_mismatched_version_fails(self, kv_client, async_session): + """DELETE with wrong if_version returns 409.""" + # Create initial key (version becomes 1) + resp = await kv_client.put("/api/automation/v1/kv/foo", json="bar") + assert resp.status_code == 201 + + # Try to delete with wrong version + resp = await kv_client.delete("/api/automation/v1/kv/foo?if_version=99") + assert resp.status_code == 409 + data = resp.json()["detail"] + assert data["error"] == "version_mismatch" + assert data["expected_version"] == 99 + assert data["actual_version"] == 1 + + async def test_version_increments_across_operations(self, kv_client, async_session): + """Version increments consistently across different operations.""" + # Create (v=1) + resp = await kv_client.put("/api/automation/v1/kv/foo", json="bar") + assert resp.status_code == 201 + + # Update (v=2) + resp = await kv_client.put("/api/automation/v1/kv/foo", json="baz") + assert resp.status_code == 200 + + # Incr new key (v=3) + resp = await kv_client.post("/api/automation/v1/kv/counter/incr") + assert resp.status_code == 200 + + # Get with meta to check version + resp = await kv_client.get("/api/automation/v1/kv/foo?meta=true") + assert resp.status_code == 200 + assert resp.json()["version"] == 3 + + # Delete with version check should work + resp = await kv_client.delete("/api/automation/v1/kv/foo?if_version=3") + assert resp.status_code == 200 + assert resp.json()["deleted"] is True diff --git a/uv.lock b/uv.lock index 8cc2233..e5cd550 100644 --- a/uv.lock +++ b/uv.lock @@ -2175,8 +2175,10 @@ dependencies = [ { name = "openhands-sdk" }, { name = "openhands-workspace" }, { name = "pg8000" }, + { name = "prometheus-client" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyjwt" }, { name = "python-json-logger" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "tenacity" }, @@ -2213,8 +2215,10 @@ requires-dist = [ { name = "openhands-sdk", specifier = "==1.22.1" }, { name = "openhands-workspace", specifier = "==1.22.1" }, { name = "pg8000", specifier = ">=1.31" }, + { name = "prometheus-client", specifier = ">=0.19" }, { name = "pydantic", specifier = ">=2" }, { name = "pydantic-settings", specifier = ">=2" }, + { name = "pyjwt", specifier = ">=2.8" }, { name = "python-json-logger", specifier = ">=3" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2" }, { name = "tenacity", specifier = ">=9.1.4" }, @@ -2600,6 +2604,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/80/6e/4b28b62ecb6aae56769c34a8ff1d661473ec1e9519e2d5f8b2c150086b26/pre_commit-4.6.0-py2.py3-none-any.whl", hash = "sha256:e2cf246f7299edcabcf15f9b0571fdce06058527f0a06535068a86d38089f29b", size = 226472, upload-time = "2026-04-21T20:31:40.092Z" }, ] +[[package]] +name = "prometheus-client" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/fb/d9aa83ffe43ce1f19e557c0971d04b90561b0cfd50762aafb01968285553/prometheus_client-0.25.0.tar.gz", hash = "sha256:5e373b75c31afb3c86f1a52fa1ad470c9aace18082d39ec0d2f918d11cc9ba28", size = 86035, upload-time = "2026-04-09T19:53:42.359Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl", hash = "sha256:d5aec89e349a6ec230805d0df882f3807f74fd6c1a2fa86864e3c2279059fed1", size = 64154, upload-time = "2026-04-09T19:53:41.324Z" }, +] + [[package]] name = "propcache" version = "0.4.1"