From d0c567a7777db505bc650595e49d93bbbf194ca7 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Mon, 22 Jun 2026 15:50:10 +0800 Subject: [PATCH 1/6] feat(device): add callback public URL handling --- docker/.env.example | 2 + docker/docker-compose.dev.yml | 1 + docker/docker-compose.test.yml | 1 + docs/designs/callback-public-base-url.zh.md | 301 ++++++++++++++++++ docs/designs/device-registration-api.md | 198 +++++++++++- internal/api/handlers/axon_rpc.go | 33 +- internal/api/handlers/callback_urls.go | 49 +++ internal/api/handlers/device_registration.go | 36 ++- .../api/handlers/device_registration_test.go | 10 +- .../recorder_axon_interaction_test.go | 9 +- internal/api/handlers/task.go | 13 +- .../api/handlers/task_callback_config_test.go | 136 ++++++++ internal/config/config.go | 56 +++- internal/config/config_test.go | 105 ++++-- internal/server/server.go | 4 +- 15 files changed, 884 insertions(+), 70 deletions(-) create mode 100644 docs/designs/callback-public-base-url.zh.md create mode 100644 internal/api/handlers/callback_urls.go create mode 100644 internal/api/handlers/task_callback_config_test.go diff --git a/docker/.env.example b/docker/.env.example index 34ec451..77f60c6 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -11,6 +11,8 @@ KEYSTONE_FACTORY_ID=factory-archebase # ----------------------------------------------------------------------------- KEYSTONE_MODE=edge KEYSTONE_BIND_ADDR=:9999 +# URL that Axon recorder can use to call Keystone callback APIs. +KEYSTONE_CALLBACK_PUBLIC_BASE_URL=http://ip:9999 KEYSTONE_READ_TIMEOUT=30 KEYSTONE_WRITE_TIMEOUT=30 KEYSTONE_SHUTDOWN_TIMEOUT=10 diff --git a/docker/docker-compose.dev.yml b/docker/docker-compose.dev.yml index 4ef6d7a..49d17a3 100644 --- a/docker/docker-compose.dev.yml +++ b/docker/docker-compose.dev.yml @@ -19,6 +19,7 @@ services: - go-modules:/root/go/pkg/mod environment: - KEYSTONE_BIND_ADDR=:8080 + - KEYSTONE_CALLBACK_PUBLIC_BASE_URL=http://localhost:8080 - KEYSTONE_MINIO_ENDPOINT=http://minio:9000 - KEYSTONE_MINIO_ACCESS_KEY=minioadmin - KEYSTONE_MINIO_SECRET_KEY=minioadmin diff --git a/docker/docker-compose.test.yml b/docker/docker-compose.test.yml index 96b9c5e..dcf1ecd 100644 --- a/docker/docker-compose.test.yml +++ b/docker/docker-compose.test.yml @@ -15,6 +15,7 @@ services: - "8080:8080" environment: - KEYSTONE_BIND_ADDR=:8080 + - KEYSTONE_CALLBACK_PUBLIC_BASE_URL=http://localhost:8080 - KEYSTONE_SYNC_ENABLED=false - KEYSTONE_MINIO_ENDPOINT=http://minio:9000 - KEYSTONE_MINIO_ACCESS_KEY=minioadmin diff --git a/docs/designs/callback-public-base-url.zh.md b/docs/designs/callback-public-base-url.zh.md new file mode 100644 index 0000000..2e3a22d --- /dev/null +++ b/docs/designs/callback-public-base-url.zh.md @@ -0,0 +1,301 @@ + + +# Keystone Callback Public Base URL 设计 + +**状态:** proposed + +**范围:** Keystone 配置、设备注册响应、任务配置 callback URL、Synapse 下发任务逻辑。 + +## 1. 背景 + +Axon recorder 在机器人侧执行任务。任务开始和结束时,recorder 会通过 HTTP POST +回调 Keystone: + +- `POST /api/v1/callbacks/start` +- `POST /api/v1/callbacks/finish` + +当前 Synapse 在下发任务配置时,用浏览器当前 origin 拼 callback URL: + +```js +new URL('/api/v1/callbacks/start', window.location.origin) +new URL('/api/v1/callbacks/finish', window.location.origin) +``` + +开发环境中,如果浏览器打开的是 `http://localhost:5174`,那么下发给 recorder 的 +callback URL 就会变成: + +```text +http://localhost:5174/api/v1/callbacks/start +http://localhost:5174/api/v1/callbacks/finish +``` + +但 Keystone 后端实际服务端口可能是 `9999`,`5174` 只是 Vite 前端开发服务器端口。 +Vite 可以把浏览器请求代理到 Keystone,但机器人侧 recorder 不应该依赖前端开发服务器。 + +因此 Keystone 需要显式配置“机器人应该用哪个地址访问 Keystone callback API”。 + +## 2. 新增配置 + +新增环境变量: + +```text +KEYSTONE_CALLBACK_PUBLIC_BASE_URL +``` + +含义:Keystone 对 Axon recorder 可访问的 callback 公共基地址。 + +示例: + +```text +KEYSTONE_CALLBACK_PUBLIC_BASE_URL=http://192.168.1.20:9999 +KEYSTONE_CALLBACK_PUBLIC_BASE_URL=https://keystone.factory.internal +``` + +该配置放在 Keystone 的 `ServerConfig` 中: + +```go +type ServerConfig struct { + Mode string + BindAddr string + CallbackPublicBaseURL string + ReadTimeout int + WriteTimeout int + ShutdownTimeout int +} +``` + +## 3. 配置校验规则 + +`KEYSTONE_CALLBACK_PUBLIC_BASE_URL` 必须显式配置。未配置时,Keystone 启动失败。 + +校验规则: + +| 规则 | 要求 | +| --- | --- | +| 非空 | 必须填写 | +| URL 类型 | 必须是绝对 URL | +| scheme | 只允许 `http` 或 `https` | +| host | 必须非空 | +| path | 必须为空或 `/` | +| query | 必须为空 | +| fragment | 必须为空 | + +允许: + +```text +http://192.168.1.20:9999 +https://keystone.factory.internal +http://keystone.factory.internal/ +``` + +不允许: + +```text +192.168.1.20:9999 +ftp://192.168.1.20:9999 +http:///api +http://gateway.local/keystone +http://gateway.local?x=1 +http://gateway.local#abc +``` + +校验通过后,Keystone 应规范化该值,去掉末尾 `/`: + +```text +http://keystone.factory.internal/ +-> http://keystone.factory.internal +``` + +## 4. 为什么不允许 base path + +本轮设计不允许: + +```text +http://gateway.local/keystone +``` + +原因是 Axon 的 callback allowlist 只返回固定路径前缀: + +```text +/api/v1/callbacks/ +``` + +如果 base URL 带 `/keystone`,recorder 实际访问路径会变成: + +```text +/keystone/api/v1/callbacks/start +``` + +这会和 allowlist 的 `/api/v1/callbacks/` 对不上。 + +所以本轮约定:`KEYSTONE_CALLBACK_PUBLIC_BASE_URL` 只表达 scheme、host、port,不表达路径前缀。 + +## 5. Keystone 派生出来的值 + +Keystone 从 `KEYSTONE_CALLBACK_PUBLIC_BASE_URL` 派生任务 callback URL: + +```text +start_callback_url = KEYSTONE_CALLBACK_PUBLIC_BASE_URL + /api/v1/callbacks/start +finish_callback_url = KEYSTONE_CALLBACK_PUBLIC_BASE_URL + /api/v1/callbacks/finish +``` + +示例: + +```text +KEYSTONE_CALLBACK_PUBLIC_BASE_URL=http://192.168.1.20:9999 +``` + +派生结果: + +```json +{ + "start_callback_url": "http://192.168.1.20:9999/api/v1/callbacks/start", + "finish_callback_url": "http://192.168.1.20:9999/api/v1/callbacks/finish" +} +``` + +Keystone 从同一个配置派生注册响应中的 callback allowlist: + +```text +allowed_host = URL(KEYSTONE_CALLBACK_PUBLIC_BASE_URL).Host +allowed_path_prefix = "/api/v1/callbacks/" +``` + +示例: + +```json +{ + "callback_allowlist": { + "allowed_host": "192.168.1.20:9999", + "allowed_path_prefix": "/api/v1/callbacks/" + } +} +``` + +如果 URL 使用默认端口,不额外补 `:80` 或 `:443`: + +```text +https://keystone.factory.internal +-> allowed_host = keystone.factory.internal +``` + +## 6. 设备注册响应 + +`POST /api/v1/devices/register` 注册成功后,应返回 callback allowlist。注册响应也会返回 +一次性明文 `ws_client_auth_token`,但该 token 的签发、存储和 WebSocket 校验规则由 +`device-registration-api.md` 定义。 + +```json +{ + "device_id": "factory01-type02-0007", + "factory": "上海一厂", + "factory_id": "1", + "robot_type": "搬运机器人", + "robot_type_id": "2", + "robot_id": "42", + "ws_client_auth_token": "kws_v1_example", + "callback_allowlist": { + "allowed_host": "192.168.1.20:9999", + "allowed_path_prefix": "/api/v1/callbacks/" + } +} +``` + +注册响应只返回“允许访问范围”,不返回具体的 `start_callback_url` / +`finish_callback_url`。 + +## 7. 任务配置响应 + +`GET /api/v1/tasks/:id/config` 应返回 Keystone 服务端生成的 callback URL: + +```json +{ + "task_id": "task_20260622_001", + "device_id": "factory01-type02-0007", + "start_callback_url": "http://192.168.1.20:9999/api/v1/callbacks/start", + "finish_callback_url": "http://192.168.1.20:9999/api/v1/callbacks/finish" +} +``` + +任务配置只返回“本次任务具体回调地址”,不返回 `callback_allowlist`。 + +## 8. Keystone 下发安全边界 + +前端请求体里即使带了 callback URL,Keystone 也不应该信任它。 + +服务端在发送 recorder config RPC 前,应统一覆盖: + +```text +start_callback_url = 服务端生成值 +finish_callback_url = 服务端生成值 +``` + +这样可以避免 Synapse、脚本或其他调用方把 `localhost:5174`、错误 IP、或任意第三方地址塞进任务配置。 + +## 9. Synapse 行为 + +Synapse 不再使用 `window.location.origin` 拼 callback URL。 + +推荐行为: + +- 打开任务配置时,从 Keystone 的 `GET /tasks/:id/config` 获取 `start_callback_url` + 和 `finish_callback_url`。 +- 下发 recorder config 前,检查这两个字段是否非空。 +- 如果缺失,阻止下发并提示 Keystone 配置错误。 + +错误提示建议: + +```text +后端未返回 callback URL,请检查 Keystone 的 KEYSTONE_CALLBACK_PUBLIC_BASE_URL 配置 +``` + +Synapse 不做 `5174` 兜底,也不从浏览器 origin 推导 callback 地址。 + +## 10. 开发环境 + +代码默认值仍然为空,强制显式配置。 + +Docker 开发环境可以预填: + +```text +KEYSTONE_CALLBACK_PUBLIC_BASE_URL=http://localhost:8080 +``` + +如果本机开发实际后端端口是 `9999`,应在本地 `.env` 或启动环境中配置: + +```text +KEYSTONE_CALLBACK_PUBLIC_BASE_URL=http://localhost:9999 +``` + +## 11. 本轮实现范围 + +本轮实现包含: + +- Keystone 新增 `KEYSTONE_CALLBACK_PUBLIC_BASE_URL` 加载和强校验。 +- Keystone 统一生成 callback URL。 +- Keystone `devices/register` 返回 `callback_allowlist`。 +- Keystone `tasks/:id/config` 返回服务端生成的 callback URL。 +- Keystone recorder config 下发入口覆盖前端传入的 callback URL。 +- Synapse 移除 `window.location.origin` 拼 callback URL 的逻辑。 +- Synapse 下发前校验 callback URL 缺失并阻止。 +- Docker dev compose 补默认开发值。 +- 单元测试覆盖配置校验、callback URL 生成、注册响应字段。 + +本轮不包含: + +- Axon transfer WebSocket token 鉴权。 +- token 轮换 API。 +- Axon token file 写入。 +- HTML 动效或可视化文档。 + +## 12. 关键原则 + +- Keystone 不能可靠地自动知道机器人应该用哪个地址访问自己。 +- callback 地址必须由 Keystone 显式配置生成,不能由 Synapse 浏览器 origin 推导。 +- 注册响应给 Axon “允许访问哪里”。 +- 任务配置给 recorder “这次具体回调哪里”。 +- Keystone 是最终安全边界,不能信任前端传入的 callback URL。 diff --git a/docs/designs/device-registration-api.md b/docs/designs/device-registration-api.md index 145c626..2699255 100644 --- a/docs/designs/device-registration-api.md +++ b/docs/designs/device-registration-api.md @@ -11,12 +11,13 @@ SPDX-License-Identifier: MulanPSL-2.0 ## 1. Overview `POST /api/v1/devices/register` registers one robot device for an installation script and -returns a Keystone-generated `device_id`. +returns a Keystone-generated `device_id` and a one-time plaintext WebSocket client token. The API is intentionally separate from `POST /api/v1/robots`. The caller provides only a factory display name and a robot type model. Keystone validates both values against existing master data, allocates a human-readable ASCII device ID, inserts a row into -`robots`, and returns the generated ID. +`robots`, inserts a hashed recorder WebSocket client token, and returns the generated ID +plus plaintext token once. This API does not require authentication in the first version. Deployment-level network access control is expected to protect the endpoint. @@ -54,7 +55,12 @@ Success returns `201 Created`. "factory_id": "1", "robot_type": "SynGloves", "robot_type_id": "3", - "robot_id": "9" + "robot_id": "9", + "ws_client_auth_token": "kws_v1_3Z2iX5lFh7mYxLQd9P0sAqzF2Z3w4R5t6U7v8W9x0Y", + "callback_allowlist": { + "allowed_host": "192.168.1.20:9999", + "allowed_path_prefix": "/api/v1/callbacks/" + } } ``` @@ -66,6 +72,14 @@ Success returns `201 Created`. | `robot_type` | Resolved `robot_types.model` | | `robot_type_id` | Resolved `robot_types.id`, encoded as a string for existing API style | | `robot_id` | Inserted `robots.id`, encoded as a string for existing API style | +| `ws_client_auth_token` | One-time plaintext token for Axon recorder WebSocket client authentication | +| `callback_allowlist.allowed_host` | Host and optional port that Axon recorder is allowed to call for Keystone callbacks | +| `callback_allowlist.allowed_path_prefix` | Callback path prefix allowed for recorder HTTP callbacks | + +Keystone does not return `ws_client_auth_token_file`. Axon owns local file path policy. If +Keystone returns `ws_client_auth_token` without a file field, `axon_config register` writes +the token to its default `/var/lib/axon/secrets/ws_client.token`, or to the path supplied by +the `--ws-client-token-file` option. ## 5. Error Responses @@ -84,7 +98,7 @@ Errors follow Keystone's usual JSON shape: | `400` | Empty or missing `robot_type` | `robot_type is required` | | `404` | No active factory with matching `factories.name` | `factory not found` | | `404` | No active robot type with matching `robot_types.model` | `robot_type not found` | -| `500` | Allocation, insert, or transaction failure | `failed to register device` | +| `500` | Allocation, robot insert, token insert, or transaction failure | `failed to register device` | ## 6. Device ID Format @@ -129,11 +143,141 @@ Every successful call creates one new `robots` row: | `status` | `active` | | `metadata` | `{}` | +Every successful call also creates one active `ws_client_auth_tokens` row for the inserted +robot. If token generation or token insertion fails, the whole registration transaction +rolls back and no robot is created. + The API is non-idempotent. Repeating the same request successfully creates another robot and returns another `device_id`. The install script should call this endpoint only when no local `device_id` already exists. -## 8. Concurrency +## 8. WebSocket Client Token + +Keystone signs no long-lived JWT here. It generates an opaque random token: + +```text +kws_v1_ +``` + +Generation rules: + +- Generate 32 cryptographically random bytes. +- Encode with URL-safe base64 without padding. +- Prefix with `kws_v1_`. +- Hash the complete token string with SHA-256. +- Store only the SHA-256 hex digest. +- Return plaintext only in the successful registration response. + +Keystone must not log the plaintext token, persist the plaintext token, or echo it in error +responses. Swagger examples should use placeholders, not generated secrets. + +The token table stores only `robot_id`, not `device_id`. `robots.device_id` remains the +single source of truth. If a robot's `device_id` is later changed and Axon is updated to use +the new device ID, the existing token can still authenticate that robot. If the old token +should stop working, a future revoke or rotate flow must revoke it explicitly. + +Migration files: + +```text +internal/storage/database/migrations/000007_ws_client_auth_tokens.up.sql +internal/storage/database/migrations/000007_ws_client_auth_tokens.down.sql +``` + +Table shape: + +```sql +CREATE TABLE IF NOT EXISTS ws_client_auth_tokens ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + robot_id BIGINT NOT NULL, + token_hash CHAR(64) NOT NULL, + token_version VARCHAR(16) NOT NULL DEFAULT 'kws_v1', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_rotated_at TIMESTAMP NULL, + last_used_at TIMESTAMP NULL, + revoked_at TIMESTAMP NULL, + UNIQUE INDEX idx_ws_client_token_hash (token_hash), + INDEX idx_ws_client_robot_active (robot_id, revoked_at) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; +``` + +The table intentionally does not define a database foreign key. Existing Keystone schema +style keeps these relationships application-managed, and this avoids introducing migration +ordering and SQLite fixture complexity. + +This version does not provide a rotate endpoint. `last_rotated_at` and `revoked_at` are +reserved for a future explicit token rotation or revocation flow. + +## 9. Recorder WebSocket Authentication + +This token is required only for the Axon recorder WebSocket in this implementation. Axon +transfer WebSocket remains unchanged because Axon transfer does not currently send a +Bearer token during its WebSocket handshake. + +Keystone validates the recorder WebSocket before `websocket.Accept`. + +Accepted token transport: + +```http +Authorization: Bearer kws_v1_... +``` + +Unsupported token transports: + +```text +?token=... +X-API-Key: ... +Sec-WebSocket-Protocol: ... +``` + +Validation query shape: + +```sql +SELECT t.id +FROM ws_client_auth_tokens t +JOIN robots r ON r.id = t.robot_id +WHERE r.device_id = ? + AND t.token_hash = ? + AND t.revoked_at IS NULL + AND r.status = 'active' + AND r.deleted_at IS NULL +LIMIT 1; +``` + +The `?` values are: + +1. `device_id` from the recorder WebSocket URL. +2. SHA-256 hex digest of the Bearer token. + +Successful validation updates `last_used_at` best-effort once during the handshake. Failure +to update `last_used_at` is logged but does not reject an otherwise valid connection. + +If `RecorderHandler` has no database handle, Keystone rejects recorder WebSocket +connections with `503 Service Unavailable`; it must not bypass authentication. + +Authentication failure handling: + +| Condition | Status | Response | +|------|------|------| +| Missing `Authorization` | `401` | `{"error":"unauthorized"}` | +| Non-Bearer authorization | `401` | `{"error":"unauthorized"}` | +| Invalid token format | `401` | `{"error":"unauthorized"}` | +| Token hash not found | `401` | `{"error":"unauthorized"}` | +| Token belongs to another robot | `401` | `{"error":"unauthorized"}` | +| Token revoked | `401` | `{"error":"unauthorized"}` | +| Robot deleted or not active | `401` | `{"error":"unauthorized"}` | +| Database unavailable in handler | `503` | `{"error":"service unavailable"}` | + +`401` responses include: + +```http +WWW-Authenticate: Bearer +``` + +Logs may include `device_id` and a broad reason such as `missing bearer token` or +`invalid token`, but must not include token plaintext or distinguish "not found" from +"belongs to another device". + +## 10. Concurrency Concurrent requests with the same `factory` and `robot_type` are supported. Keystone uses `device_id_sequences` to serialize allocation for each `(factory_id, robot_type_id)` pair. @@ -158,7 +302,10 @@ FOR UPDATE; The selected `next_sequence` is used in `device_id`, then Keystone increments `next_sequence` in the same transaction before inserting the robot. -## 9. Install Script Usage +Token insertion is part of the same transaction. If inserting the token row fails, the +robot insert and device sequence increment are rolled back with the transaction. + +## 11. Install Script Usage Example: @@ -175,19 +322,47 @@ Expected script behavior: 2. Skip registration if a local `device_id` already exists. 3. Call `POST /api/v1/devices/register`. 4. Persist `device_id` locally before starting Axon services. -5. Use the persisted `device_id` for later Keystone and Axon connections. +5. If `ws_client_auth_token` is present, write it to the Axon ws client token file. +6. Use the persisted `device_id` and token file for later Keystone recorder WebSocket + connections. + +`axon_config register` already implements the token-file behavior. If the response omits +`ws_client_auth_token_file`, it writes the token to `/var/lib/axon/secrets/ws_client.token` +or the path supplied by `--ws-client-token-file`. -## 10. Implementation Notes +## 12. Implementation Notes Implementation files: | File | Purpose | |------|------| | `internal/api/handlers/device_registration.go` | Request validation, transaction, sequence allocation, robot insertion | +| `internal/api/handlers/ws_client_auth.go` | Token generation, hashing, storage, and recorder WebSocket validation | | `internal/server/server.go` | Handler construction and route registration | | `internal/storage/database/migrations/000002_device_id_sequences.up.sql` | Sequence table migration | | `internal/storage/database/migrations/000002_device_id_sequences.down.sql` | Sequence table rollback | +| `internal/storage/database/migrations/000007_ws_client_auth_tokens.up.sql` | WebSocket client token table migration | +| `internal/storage/database/migrations/000007_ws_client_auth_tokens.down.sql` | WebSocket client token table rollback | | `internal/api/handlers/device_registration_test.go` | Focused handler and route tests | +| `internal/api/handlers/recorder_ws_auth_test.go` | Recorder WebSocket token authentication tests | + +TDD coverage should include: + +1. Register success returns `ws_client_auth_token` with `kws_v1_` prefix. +2. Register success stores only token hash, not plaintext. +3. Token table insert failure rolls back robot creation. +4. Recorder WebSocket without Authorization returns `401`. +5. Recorder WebSocket with wrong Bearer token returns `401`. +6. Recorder WebSocket with correct token and device ID connects. +7. Token for robot A cannot connect as robot B. +8. Deleted or non-active robot cannot authenticate. + +Out of scope for this implementation: + +- Axon transfer WebSocket token authentication. +- Token query parameters, `X-API-Key`, or `Sec-WebSocket-Protocol`. +- Token rotation API. +- Register endpoint authentication. Validation performed during implementation: @@ -205,6 +380,11 @@ Manual API verification was performed against a local Keystone instance: "factory_id": "1", "robot_type": "SynGloves", "robot_type_id": "3", - "robot_id": "9" + "robot_id": "9", + "ws_client_auth_token": "kws_v1_example", + "callback_allowlist": { + "allowed_host": "192.168.1.20:9999", + "allowed_path_prefix": "/api/v1/callbacks/" + } } ``` diff --git a/internal/api/handlers/axon_rpc.go b/internal/api/handlers/axon_rpc.go index 8cc50de..53fa3c3 100644 --- a/internal/api/handlers/axon_rpc.go +++ b/internal/api/handlers/axon_rpc.go @@ -27,11 +27,12 @@ import ( // RecorderHandler handles REST and WebSocket traffic for Axon Recorder RPC. type RecorderHandler struct { - hub *services.RecorderHub - transferHub *services.TransferHub - stateBroker *services.DeviceStateBroker - cfg *config.RecorderConfig - db *sqlx.DB + hub *services.RecorderHub + transferHub *services.TransferHub + stateBroker *services.DeviceStateBroker + cfg *config.RecorderConfig + db *sqlx.DB + callbackURLs callbackURLs } // NewRecorderHandler creates a new RecorderHandler. @@ -39,6 +40,14 @@ func NewRecorderHandler(hub *services.RecorderHub, cfg *config.RecorderConfig, d return &RecorderHandler{hub: hub, cfg: cfg, db: db} } +// SetCallbackPublicBaseURL configures callback URLs sent in recorder task config RPCs. +func (h *RecorderHandler) SetCallbackPublicBaseURL(callbackPublicBaseURL string) { + if h == nil { + return + } + h.callbackURLs = newCallbackURLs(callbackPublicBaseURL) +} + // SetDeviceStateDeps enables device connection/state event publishing. func (h *RecorderHandler) SetDeviceStateDeps(transferHub *services.TransferHub, broker *services.DeviceStateBroker) { if h == nil { @@ -256,6 +265,8 @@ func (h *RecorderHandler) Config(c *gin.Context) { return } + h.overrideTaskConfigCallbackURLs(params) + if !h.callRPC(c, "config", params) { return } @@ -263,6 +274,18 @@ func (h *RecorderHandler) Config(c *gin.Context) { advanceTaskPendingToReady(h.db, c.Param("device_id"), taskID, "config") } +func (h *RecorderHandler) overrideTaskConfigCallbackURLs(params map[string]interface{}) { + if h == nil || !h.callbackURLs.configured() || params == nil { + return + } + taskConfig, ok := params["task_config"].(map[string]interface{}) + if !ok || taskConfig == nil { + return + } + taskConfig["start_callback_url"] = h.callbackURLs.startURL() + taskConfig["finish_callback_url"] = h.callbackURLs.finishURL() +} + func (h *RecorderHandler) requireTaskConfigurable(c *gin.Context, taskID string) bool { if h == nil || h.db == nil { return true diff --git a/internal/api/handlers/callback_urls.go b/internal/api/handlers/callback_urls.go new file mode 100644 index 0000000..2dff3d9 --- /dev/null +++ b/internal/api/handlers/callback_urls.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2026 ArcheBase +// +// SPDX-License-Identifier: MulanPSL-2.0 + +package handlers + +import ( + "net/url" + "strings" +) + +const callbackPathPrefix = "/api/v1/callbacks/" + +// CallbackAllowlist describes the callback URL scope Axon is allowed to call. +type CallbackAllowlist struct { + AllowedHost string `json:"allowed_host"` + AllowedPathPrefix string `json:"allowed_path_prefix"` +} + +type callbackURLs struct { + baseURL string +} + +func newCallbackURLs(baseURL string) callbackURLs { + return callbackURLs{baseURL: strings.TrimRight(strings.TrimSpace(baseURL), "/")} +} + +func (u callbackURLs) configured() bool { + return u.baseURL != "" +} + +func (u callbackURLs) allowlist() CallbackAllowlist { + parsed, err := url.Parse(u.baseURL) + if err != nil { + return CallbackAllowlist{AllowedPathPrefix: callbackPathPrefix} + } + return CallbackAllowlist{ + AllowedHost: parsed.Host, + AllowedPathPrefix: callbackPathPrefix, + } +} + +func (u callbackURLs) startURL() string { + return u.baseURL + callbackPathPrefix + "start" +} + +func (u callbackURLs) finishURL() string { + return u.baseURL + callbackPathPrefix + "finish" +} diff --git a/internal/api/handlers/device_registration.go b/internal/api/handlers/device_registration.go index c9741b8..06b2be4 100644 --- a/internal/api/handlers/device_registration.go +++ b/internal/api/handlers/device_registration.go @@ -26,12 +26,16 @@ var ( // DeviceRegistrationHandler handles install-time device registration requests. type DeviceRegistrationHandler struct { - db *sqlx.DB + db *sqlx.DB + callbackURLs callbackURLs } // NewDeviceRegistrationHandler creates a new DeviceRegistrationHandler. -func NewDeviceRegistrationHandler(db *sqlx.DB) *DeviceRegistrationHandler { - return &DeviceRegistrationHandler{db: db} +func NewDeviceRegistrationHandler(db *sqlx.DB, callbackPublicBaseURL string) *DeviceRegistrationHandler { + return &DeviceRegistrationHandler{ + db: db, + callbackURLs: newCallbackURLs(callbackPublicBaseURL), + } } // DeviceRegistrationRequest represents the request body for device registration. @@ -42,12 +46,13 @@ type DeviceRegistrationRequest struct { // DeviceRegistrationResponse represents a successful device registration. type DeviceRegistrationResponse struct { - DeviceID string `json:"device_id"` - Factory string `json:"factory"` - FactoryID string `json:"factory_id"` - RobotType string `json:"robot_type"` - RobotTypeID string `json:"robot_type_id"` - RobotID string `json:"robot_id"` + DeviceID string `json:"device_id"` + Factory string `json:"factory"` + FactoryID string `json:"factory_id"` + RobotType string `json:"robot_type"` + RobotTypeID string `json:"robot_type_id"` + RobotID string `json:"robot_id"` + CallbackAllowlist CallbackAllowlist `json:"callback_allowlist"` } type deviceRegistrationFactoryRow struct { @@ -187,12 +192,13 @@ func (h *DeviceRegistrationHandler) registerDevice(factoryName, robotTypeModel s } return DeviceRegistrationResponse{ - DeviceID: deviceID, - Factory: factory.Name, - FactoryID: strconv.FormatInt(factory.ID, 10), - RobotType: robotType.Model, - RobotTypeID: strconv.FormatInt(robotType.ID, 10), - RobotID: strconv.FormatInt(robotID, 10), + DeviceID: deviceID, + Factory: factory.Name, + FactoryID: strconv.FormatInt(factory.ID, 10), + RobotType: robotType.Model, + RobotTypeID: strconv.FormatInt(robotType.ID, 10), + RobotID: strconv.FormatInt(robotID, 10), + CallbackAllowlist: h.callbackURLs.allowlist(), }, nil } diff --git a/internal/api/handlers/device_registration_test.go b/internal/api/handlers/device_registration_test.go index d4271df..d7c3ae9 100644 --- a/internal/api/handlers/device_registration_test.go +++ b/internal/api/handlers/device_registration_test.go @@ -102,6 +102,12 @@ func TestDeviceRegistrationHandlerRegisterDevice_Success(t *testing.T) { if !isASCII(resp.DeviceID) { t.Fatalf("device_id is not ASCII: %q", resp.DeviceID) } + if resp.CallbackAllowlist.AllowedHost != "192.168.1.20:9999" { + t.Fatalf("allowed_host=%q want 192.168.1.20:9999", resp.CallbackAllowlist.AllowedHost) + } + if resp.CallbackAllowlist.AllowedPathPrefix != "/api/v1/callbacks/" { + t.Fatalf("allowed_path_prefix=%q want /api/v1/callbacks/", resp.CallbackAllowlist.AllowedPathPrefix) + } var robotCount int if err := db.Get(&robotCount, "SELECT COUNT(*) FROM robots WHERE device_id = ?", resp.DeviceID); err != nil { @@ -162,7 +168,7 @@ func TestDeviceRegistrationRoutes_DoNotConflictWithRobotDeviceRoutes(t *testing. v1 := router.Group("/api/v1") NewRobotHandler(nil, nil, nil).RegisterRoutes(v1) - NewDeviceRegistrationHandler(nil).RegisterRoutes(v1) + NewDeviceRegistrationHandler(nil, "http://192.168.1.20:9999").RegisterRoutes(v1) } func registerTestDevice(t *testing.T, router *gin.Engine) DeviceRegistrationResponse { @@ -188,7 +194,7 @@ func newTestDeviceRegistrationRouter(t *testing.T, db *sqlx.DB) *gin.Engine { gin.SetMode(gin.TestMode) router := gin.New() - handler := NewDeviceRegistrationHandler(db) + handler := NewDeviceRegistrationHandler(db, "http://192.168.1.20:9999") v1 := router.Group("/api/v1") handler.RegisterRoutes(v1) diff --git a/internal/api/handlers/recorder_axon_interaction_test.go b/internal/api/handlers/recorder_axon_interaction_test.go index b629476..2a78660 100644 --- a/internal/api/handlers/recorder_axon_interaction_test.go +++ b/internal/api/handlers/recorder_axon_interaction_test.go @@ -1245,6 +1245,7 @@ func TestRecorderWebSocketRPCActionProtocol(t *testing.T) { hub := services.NewRecorderHub() handler := NewRecorderHandler(hub, &config.RecorderConfig{ResponseTimeout: 1}, db) + handler.SetCallbackPublicBaseURL("http://192.168.1.20:9999") wsURL := newRecorderWebSocketTestServer(t, handler, "robot-001") axon := connectFakeRecorderAxon(t, wsURL) getState := axon.receiveRPC(t, "get_state") @@ -1264,7 +1265,7 @@ func TestRecorderWebSocketRPCActionProtocol(t *testing.T) { name: "config forwards task_config payload", method: http.MethodPost, path: "/recorder/robot-001/config", - body: `{"task_config":{"task_id":"task-protocol","device_id":"robot-001"}}`, + body: `{"task_config":{"task_id":"task-protocol","device_id":"robot-001","start_callback_url":"http://localhost:5174/api/v1/callbacks/start","finish_callback_url":"http://localhost:5174/api/v1/callbacks/finish"}}`, wantAction: "config", check: func(t *testing.T, req services.RPCRequest) { t.Helper() @@ -1272,6 +1273,12 @@ func TestRecorderWebSocketRPCActionProtocol(t *testing.T) { if !ok || tc["task_id"] != "task-protocol" || tc["device_id"] != "robot-001" { t.Fatalf("config params=%#v missing task_config", req.Params) } + if tc["start_callback_url"] != "http://192.168.1.20:9999/api/v1/callbacks/start" { + t.Fatalf("start_callback_url=%#v", tc["start_callback_url"]) + } + if tc["finish_callback_url"] != "http://192.168.1.20:9999/api/v1/callbacks/finish" { + t.Fatalf("finish_callback_url=%#v", tc["finish_callback_url"]) + } }, }, { diff --git a/internal/api/handlers/task.go b/internal/api/handlers/task.go index 2759648..bc65af1 100644 --- a/internal/api/handlers/task.go +++ b/internal/api/handlers/task.go @@ -70,6 +70,7 @@ type TaskHandler struct { recorderHub *services.RecorderHub recorderRPCTimeout time.Duration transferWriteTimeout time.Duration + callbackURLs callbackURLs } // NewTaskHandler creates a new TaskHandler. @@ -87,6 +88,14 @@ func NewTaskHandler(db *sqlx.DB, hub *services.TransferHub, recorderHub *service } } +// SetCallbackPublicBaseURL configures Keystone callback URLs returned in task configs. +func (h *TaskHandler) SetCallbackPublicBaseURL(callbackPublicBaseURL string) { + if h == nil { + return + } + h.callbackURLs = newCallbackURLs(callbackPublicBaseURL) +} + func (h *TaskHandler) axonTransferWriteTimeout() time.Duration { if h == nil || h.transferWriteTimeout <= 0 { return services.DefaultTransferWriteTimeout @@ -1190,8 +1199,8 @@ func (h *TaskHandler) GetTaskConfig(c *gin.Context) { Skills: skills, SOPID: strings.TrimSpace(row.SOPSlug.String), Topics: parseJSONArray(row.ROSTopics.String), - StartCallbackURL: "http://keystone.factory.internal/api/v1/callbacks/start", - FinishCallbackURL: "http://keystone.factory.internal/api/v1/callbacks/finish", + StartCallbackURL: h.callbackURLs.startURL(), + FinishCallbackURL: h.callbackURLs.finishURL(), UserToken: "", } diff --git a/internal/api/handlers/task_callback_config_test.go b/internal/api/handlers/task_callback_config_test.go new file mode 100644 index 0000000..1290caf --- /dev/null +++ b/internal/api/handlers/task_callback_config_test.go @@ -0,0 +1,136 @@ +// SPDX-FileCopyrightText: 2026 ArcheBase +// +// SPDX-License-Identifier: MulanPSL-2.0 + +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/jmoiron/sqlx" + _ "modernc.org/sqlite" +) + +func TestGetTaskConfigUsesConfiguredCallbackPublicBaseURL(t *testing.T) { + db := newTestTaskConfigCallbackDB(t) + defer db.Close() + + handler := NewTaskHandler(db, nil, nil, 0) + handler.SetCallbackPublicBaseURL("http://192.168.1.20:9999") + + gin.SetMode(gin.TestMode) + router := gin.New() + router.GET("/tasks/:id/config", handler.GetTaskConfig) + + w := httptest.NewRecorder() + router.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/tasks/1/config", nil)) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d want=%d body=%s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp TaskConfig + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, w.Body.String()) + } + if resp.StartCallbackURL != "http://192.168.1.20:9999/api/v1/callbacks/start" { + t.Fatalf("start_callback_url=%q", resp.StartCallbackURL) + } + if resp.FinishCallbackURL != "http://192.168.1.20:9999/api/v1/callbacks/finish" { + t.Fatalf("finish_callback_url=%q", resp.FinishCallbackURL) + } +} + +func newTestTaskConfigCallbackDB(t *testing.T) *sqlx.DB { + t.Helper() + db, err := sqlx.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite db: %v", err) + } + + schema := []string{ + `CREATE TABLE tasks ( + id INTEGER PRIMARY KEY, + task_id TEXT NOT NULL, + workstation_id INTEGER, + order_id INTEGER, + factory_id INTEGER, + sop_id INTEGER, + scene_name TEXT, + subscene_name TEXT, + initial_scene_layout TEXT, + status TEXT NOT NULL, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE factories ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE workstations ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + robot_serial TEXT NOT NULL, + robot_id INTEGER NOT NULL, + collector_name TEXT NOT NULL, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE robots ( + id INTEGER PRIMARY KEY, + robot_type_id INTEGER NOT NULL, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE robot_types ( + id INTEGER PRIMARY KEY, + ros_topics TEXT NOT NULL, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE sops ( + id INTEGER PRIMARY KEY, + slug TEXT NOT NULL, + skill_sequence TEXT NOT NULL, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE skills ( + id INTEGER PRIMARY KEY, + slug TEXT NOT NULL, + deleted_at TIMESTAMP NULL + )`, + } + for _, stmt := range schema { + if _, err := db.Exec(stmt); err != nil { + t.Fatalf("create schema: %v", err) + } + } + + now := time.Now().UTC() + seed := []struct { + sql string + args []any + }{ + {`INSERT INTO factories (id, name) VALUES (30, '上海一厂')`, nil}, + {`INSERT INTO orders (id, name) VALUES (10, 'order-a')`, nil}, + {`INSERT INTO robot_types (id, ros_topics) VALUES (12, '["/camera","/tf"]')`, nil}, + {`INSERT INTO robots (id, robot_type_id) VALUES (20, 12)`, nil}, + {`INSERT INTO workstations (id, name, robot_serial, robot_id, collector_name) VALUES (40, 'station-a', 'robot-001', 20, 'collector-a')`, nil}, + {`INSERT INTO sops (id, slug, skill_sequence) VALUES (50, 'sop-a', '["1"]')`, nil}, + {`INSERT INTO skills (id, slug) VALUES (1, 'pick')`, nil}, + {`INSERT INTO tasks (id, task_id, workstation_id, order_id, factory_id, sop_id, scene_name, subscene_name, initial_scene_layout, status) VALUES (1, 'task-a', 40, 10, 30, 50, 'scene-a', 'sub-a', '{}', 'pending')`, []any{now}}, + } + for _, stmt := range seed { + if _, err := db.Exec(stmt.sql, stmt.args...); err != nil { + t.Fatalf("seed data: %v", err) + } + } + return db +} diff --git a/internal/config/config.go b/internal/config/config.go index 4df907e..d45061b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,6 +7,7 @@ package config import ( "fmt" + "net/url" "os" "path/filepath" "strconv" @@ -30,11 +31,12 @@ type Config struct { // ServerConfig server configuration type ServerConfig struct { - Mode string - BindAddr string - ReadTimeout int // seconds - WriteTimeout int // seconds - ShutdownTimeout int // seconds + Mode string + BindAddr string + CallbackPublicBaseURL string + ReadTimeout int // seconds + WriteTimeout int // seconds + ShutdownTimeout int // seconds } // DatabaseConfig database configuration @@ -152,11 +154,12 @@ type AuthConfig struct { func Load() (*Config, error) { cfg := &Config{ Server: ServerConfig{ - Mode: getEnv("KEYSTONE_MODE", "edge"), - BindAddr: getEnv("KEYSTONE_BIND_ADDR", ":8080"), - ReadTimeout: getEnvInt("KEYSTONE_READ_TIMEOUT", 30), - WriteTimeout: getEnvInt("KEYSTONE_WRITE_TIMEOUT", 30), - ShutdownTimeout: getEnvInt("KEYSTONE_SHUTDOWN_TIMEOUT", 10), + Mode: getEnv("KEYSTONE_MODE", "edge"), + BindAddr: getEnv("KEYSTONE_BIND_ADDR", ":8080"), + CallbackPublicBaseURL: getEnv("KEYSTONE_CALLBACK_PUBLIC_BASE_URL", ""), + ReadTimeout: getEnvInt("KEYSTONE_READ_TIMEOUT", 30), + WriteTimeout: getEnvInt("KEYSTONE_WRITE_TIMEOUT", 30), + ShutdownTimeout: getEnvInt("KEYSTONE_SHUTDOWN_TIMEOUT", 10), }, Database: DatabaseConfig{ Driver: "mysql", @@ -262,6 +265,11 @@ func (c *Config) Validate() error { if c.Server.Mode != "edge" { return fmt.Errorf("invalid mode: %s, must be 'edge'", c.Server.Mode) } + callbackPublicBaseURL, err := normalizeCallbackPublicBaseURL(c.Server.CallbackPublicBaseURL) + if err != nil { + return err + } + c.Server.CallbackPublicBaseURL = callbackPublicBaseURL if c.Database.DSN == "" { return fmt.Errorf("database DSN is required") } @@ -326,6 +334,34 @@ func (c *Config) Validate() error { return nil } +func normalizeCallbackPublicBaseURL(raw string) (string, error) { + value := strings.TrimSpace(raw) + if value == "" { + return "", fmt.Errorf("KEYSTONE_CALLBACK_PUBLIC_BASE_URL is required") + } + parsed, err := url.Parse(value) + if err != nil { + return "", fmt.Errorf("KEYSTONE_CALLBACK_PUBLIC_BASE_URL %q is invalid: %w", raw, err) + } + if !parsed.IsAbs() || parsed.Host == "" { + return "", fmt.Errorf("KEYSTONE_CALLBACK_PUBLIC_BASE_URL must be an absolute URL with host") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", fmt.Errorf("KEYSTONE_CALLBACK_PUBLIC_BASE_URL scheme must be http or https") + } + if parsed.Path != "" && parsed.Path != "/" { + return "", fmt.Errorf("KEYSTONE_CALLBACK_PUBLIC_BASE_URL path must be empty") + } + if parsed.RawQuery != "" { + return "", fmt.Errorf("KEYSTONE_CALLBACK_PUBLIC_BASE_URL query must be empty") + } + if parsed.Fragment != "" { + return "", fmt.Errorf("KEYSTONE_CALLBACK_PUBLIC_BASE_URL fragment must be empty") + } + parsed.Path = "" + return parsed.String(), nil +} + func getEnv(key, fallback string) string { if val := os.Getenv(key); val != "" { return val diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4578975..341297e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -15,14 +15,15 @@ import ( func TestLoad(t *testing.T) { // Save original environment variables originalEnv := map[string]string{ - "KEYSTONE_MODE": os.Getenv("KEYSTONE_MODE"), - "KEYSTONE_MYSQL_HOST": os.Getenv("KEYSTONE_MYSQL_HOST"), - "KEYSTONE_MYSQL_PASSWORD": os.Getenv("KEYSTONE_MYSQL_PASSWORD"), - "KEYSTONE_MINIO_ACCESS_KEY": os.Getenv("KEYSTONE_MINIO_ACCESS_KEY"), - "KEYSTONE_MINIO_SECRET_KEY": os.Getenv("KEYSTONE_MINIO_SECRET_KEY"), - "KEYSTONE_FACTORY_ID": os.Getenv("KEYSTONE_FACTORY_ID"), - "KEYSTONE_SYNC_AUTO_SCAN_ENABLED": os.Getenv("KEYSTONE_SYNC_AUTO_SCAN_ENABLED"), - "KEYSTONE_SYNC_DP_CONFIG": os.Getenv("KEYSTONE_SYNC_DP_CONFIG"), + "KEYSTONE_MODE": os.Getenv("KEYSTONE_MODE"), + "KEYSTONE_MYSQL_HOST": os.Getenv("KEYSTONE_MYSQL_HOST"), + "KEYSTONE_MYSQL_PASSWORD": os.Getenv("KEYSTONE_MYSQL_PASSWORD"), + "KEYSTONE_MINIO_ACCESS_KEY": os.Getenv("KEYSTONE_MINIO_ACCESS_KEY"), + "KEYSTONE_MINIO_SECRET_KEY": os.Getenv("KEYSTONE_MINIO_SECRET_KEY"), + "KEYSTONE_FACTORY_ID": os.Getenv("KEYSTONE_FACTORY_ID"), + "KEYSTONE_SYNC_AUTO_SCAN_ENABLED": os.Getenv("KEYSTONE_SYNC_AUTO_SCAN_ENABLED"), + "KEYSTONE_SYNC_DP_CONFIG": os.Getenv("KEYSTONE_SYNC_DP_CONFIG"), + "KEYSTONE_CALLBACK_PUBLIC_BASE_URL": os.Getenv("KEYSTONE_CALLBACK_PUBLIC_BASE_URL"), } defer func() { // Restore original environment variables @@ -42,6 +43,7 @@ func TestLoad(t *testing.T) { os.Setenv("KEYSTONE_MINIO_ACCESS_KEY", "test-access-key") os.Setenv("KEYSTONE_MINIO_SECRET_KEY", "test-secret-key") os.Setenv("KEYSTONE_FACTORY_ID", "factory-test") + os.Setenv("KEYSTONE_CALLBACK_PUBLIC_BASE_URL", "http://127.0.0.1:9999") cfg, err := Load() if err != nil { @@ -56,6 +58,9 @@ func TestLoad(t *testing.T) { if cfg.Server.BindAddr != ":8080" { t.Errorf("Load().Server.BindAddr = %v, want :8080", cfg.Server.BindAddr) } + if cfg.Server.CallbackPublicBaseURL != "http://127.0.0.1:9999" { + t.Errorf("Load().Server.CallbackPublicBaseURL = %q, want http://127.0.0.1:9999", cfg.Server.CallbackPublicBaseURL) + } // Verify reading from environment variables if cfg.Database.DSN == "" { @@ -106,15 +111,16 @@ func TestLoad(t *testing.T) { func TestLoadWithCustomEnv(t *testing.T) { // Save original environment variables originalEnv := map[string]string{ - "KEYSTONE_MODE": os.Getenv("KEYSTONE_MODE"), - "KEYSTONE_BIND_ADDR": os.Getenv("KEYSTONE_BIND_ADDR"), - "KEYSTONE_MYSQL_PASSWORD": os.Getenv("KEYSTONE_MYSQL_PASSWORD"), - "KEYSTONE_MINIO_ACCESS_KEY": os.Getenv("KEYSTONE_MINIO_ACCESS_KEY"), - "KEYSTONE_MINIO_SECRET_KEY": os.Getenv("KEYSTONE_MINIO_SECRET_KEY"), - "KEYSTONE_QA_MAX_WORKERS": os.Getenv("KEYSTONE_QA_MAX_WORKERS"), - "KEYSTONE_MAX_MEMORY_MB": os.Getenv("KEYSTONE_MAX_MEMORY_MB"), - "KEYSTONE_DASHBOARD_DISPLAY_TOKEN": os.Getenv("KEYSTONE_DASHBOARD_DISPLAY_TOKEN"), - "KEYSTONE_SYNC_AUTO_SCAN_ENABLED": os.Getenv("KEYSTONE_SYNC_AUTO_SCAN_ENABLED"), + "KEYSTONE_MODE": os.Getenv("KEYSTONE_MODE"), + "KEYSTONE_BIND_ADDR": os.Getenv("KEYSTONE_BIND_ADDR"), + "KEYSTONE_MYSQL_PASSWORD": os.Getenv("KEYSTONE_MYSQL_PASSWORD"), + "KEYSTONE_MINIO_ACCESS_KEY": os.Getenv("KEYSTONE_MINIO_ACCESS_KEY"), + "KEYSTONE_MINIO_SECRET_KEY": os.Getenv("KEYSTONE_MINIO_SECRET_KEY"), + "KEYSTONE_QA_MAX_WORKERS": os.Getenv("KEYSTONE_QA_MAX_WORKERS"), + "KEYSTONE_MAX_MEMORY_MB": os.Getenv("KEYSTONE_MAX_MEMORY_MB"), + "KEYSTONE_DASHBOARD_DISPLAY_TOKEN": os.Getenv("KEYSTONE_DASHBOARD_DISPLAY_TOKEN"), + "KEYSTONE_SYNC_AUTO_SCAN_ENABLED": os.Getenv("KEYSTONE_SYNC_AUTO_SCAN_ENABLED"), + "KEYSTONE_CALLBACK_PUBLIC_BASE_URL": os.Getenv("KEYSTONE_CALLBACK_PUBLIC_BASE_URL"), } defer func() { for k, v := range originalEnv { @@ -136,6 +142,7 @@ func TestLoadWithCustomEnv(t *testing.T) { os.Setenv("KEYSTONE_MAX_MEMORY_MB", "8192") os.Setenv("KEYSTONE_DASHBOARD_DISPLAY_TOKEN", "display-secret") os.Setenv("KEYSTONE_SYNC_AUTO_SCAN_ENABLED", "true") + os.Setenv("KEYSTONE_CALLBACK_PUBLIC_BASE_URL", "https://keystone.factory.internal") cfg, err := Load() if err != nil { @@ -161,6 +168,9 @@ func TestLoadWithCustomEnv(t *testing.T) { if !cfg.Sync.AutoScanEnabled { t.Error("Load().Sync.AutoScanEnabled = false, want true") } + if cfg.Server.CallbackPublicBaseURL != "https://keystone.factory.internal" { + t.Errorf("Load().Server.CallbackPublicBaseURL = %q, want https://keystone.factory.internal", cfg.Server.CallbackPublicBaseURL) + } } func TestConfigValidate(t *testing.T) { @@ -172,7 +182,7 @@ func TestConfigValidate(t *testing.T) { { name: "Valid configuration", cfg: &Config{ - Server: ServerConfig{Mode: "edge"}, + Server: ServerConfig{Mode: "edge", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, Database: DatabaseConfig{ DSN: "user:pass@tcp(localhost:3306)/db", }, @@ -189,7 +199,7 @@ func TestConfigValidate(t *testing.T) { { name: "Invalid mode", cfg: &Config{ - Server: ServerConfig{Mode: "cloud"}, + Server: ServerConfig{Mode: "cloud", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, Database: DatabaseConfig{ DSN: "user:pass@tcp(localhost:3306)/db", }, @@ -203,7 +213,7 @@ func TestConfigValidate(t *testing.T) { { name: "Empty DSN", cfg: &Config{ - Server: ServerConfig{Mode: "edge"}, + Server: ServerConfig{Mode: "edge", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, Database: DatabaseConfig{ DSN: "", }, @@ -217,7 +227,7 @@ func TestConfigValidate(t *testing.T) { { name: "Empty storage keys", cfg: &Config{ - Server: ServerConfig{Mode: "edge"}, + Server: ServerConfig{Mode: "edge", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, Database: DatabaseConfig{ DSN: "user:pass@tcp(localhost:3306)/db", }, @@ -241,7 +251,7 @@ func TestConfigValidate(t *testing.T) { { name: "Only admin username set (no password)", cfg: &Config{ - Server: ServerConfig{Mode: "edge"}, + Server: ServerConfig{Mode: "edge", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, Database: DatabaseConfig{DSN: "user:pass@tcp(localhost:3306)/db"}, Storage: StorageConfig{AccessKey: "key", SecretKey: "secret"}, Auth: AuthConfig{JWTSecret: "secret", AdminUsername: "admin", AdminPassword: ""}, @@ -251,7 +261,7 @@ func TestConfigValidate(t *testing.T) { { name: "Only admin password set (no username)", cfg: &Config{ - Server: ServerConfig{Mode: "edge"}, + Server: ServerConfig{Mode: "edge", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, Database: DatabaseConfig{DSN: "user:pass@tcp(localhost:3306)/db"}, Storage: StorageConfig{AccessKey: "key", SecretKey: "secret"}, Auth: AuthConfig{JWTSecret: "secret", AdminUsername: "", AdminPassword: "pass"}, @@ -261,7 +271,7 @@ func TestConfigValidate(t *testing.T) { { name: "Valid admin credentials", cfg: &Config{ - Server: ServerConfig{Mode: "edge"}, + Server: ServerConfig{Mode: "edge", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, Database: DatabaseConfig{DSN: "user:pass@tcp(localhost:3306)/db"}, Storage: StorageConfig{AccessKey: "key", SecretKey: "secret"}, Auth: AuthConfig{JWTSecret: "secret", AdminUsername: "admin", AdminPassword: "pass"}, @@ -280,9 +290,54 @@ func TestConfigValidate(t *testing.T) { } } +func TestValidateCallbackPublicBaseURL(t *testing.T) { + validBase := Config{ + Server: ServerConfig{Mode: "edge", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, + Database: DatabaseConfig{DSN: "user:pass@tcp(localhost:3306)/db"}, + Storage: StorageConfig{AccessKey: "key", SecretKey: "secret"}, + Auth: AuthConfig{JWTSecret: "jwt-secret"}, + } + + t.Run("required", func(t *testing.T) { + cfg := validBase + cfg.Server.CallbackPublicBaseURL = "" + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "KEYSTONE_CALLBACK_PUBLIC_BASE_URL") { + t.Fatalf("Validate() error = %v, want callback public base URL error", err) + } + }) + + for _, raw := range []string{ + "192.168.1.20:9999", + "ftp://192.168.1.20:9999", + "http:///api", + "http://gateway.local/keystone", + "http://gateway.local?x=1", + "http://gateway.local#abc", + } { + t.Run("rejects "+raw, func(t *testing.T) { + cfg := validBase + cfg.Server.CallbackPublicBaseURL = raw + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "KEYSTONE_CALLBACK_PUBLIC_BASE_URL") { + t.Fatalf("Validate() error = %v, want callback public base URL error", err) + } + }) + } + + t.Run("normalizes trailing slash", func(t *testing.T) { + cfg := validBase + cfg.Server.CallbackPublicBaseURL = "https://keystone.factory.internal/" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() unexpected error = %v", err) + } + if cfg.Server.CallbackPublicBaseURL != "https://keystone.factory.internal" { + t.Fatalf("CallbackPublicBaseURL = %q, want normalized URL", cfg.Server.CallbackPublicBaseURL) + } + }) +} + func TestValidateSyncDPConfig(t *testing.T) { validBase := Config{ - Server: ServerConfig{Mode: "edge"}, + Server: ServerConfig{Mode: "edge", CallbackPublicBaseURL: "http://127.0.0.1:9999"}, Database: DatabaseConfig{DSN: "user:pass@tcp(localhost:3306)/db"}, Storage: StorageConfig{AccessKey: "key", SecretKey: "secret"}, Auth: AuthConfig{JWTSecret: "jwt-secret"}, diff --git a/internal/server/server.go b/internal/server/server.go index a9522a9..c137788 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -101,6 +101,7 @@ func New(cfg *config.Config, db *sqlx.DB, s3Client *s3.Client, syncWorker *servi stateBroker := services.NewDeviceStateBroker() recorderHub := services.NewRecorderHub() recorderHandler := handlers.NewRecorderHandler(recorderHub, &cfg.AxonRecorder, db) + recorderHandler.SetCallbackPublicBaseURL(cfg.Server.CallbackPublicBaseURL) recorderRPCTimeout := time.Duration(cfg.AxonRecorder.ResponseTimeout) * time.Second // Create TransferHub and TransferHandler for Transfer Service @@ -119,6 +120,7 @@ func New(cfg *config.Config, db *sqlx.DB, s3Client *s3.Client, syncWorker *servi // Create TaskHandler for task configuration taskHandler := handlers.NewTaskHandler(db, transferHub, recorderHub, recorderRPCTimeout, transferWriteTimeout) + taskHandler.SetCallbackPublicBaseURL(cfg.Server.CallbackPublicBaseURL) // Create database-dependent handlers only when DB is available var ( @@ -144,7 +146,7 @@ func New(cfg *config.Config, db *sqlx.DB, s3Client *s3.Client, syncWorker *servi batchHandler = handlers.NewBatchHandler(db, recorderHub, recorderRPCTimeout) robotTypeHandler = handlers.NewRobotTypeHandler(db) robotHandler = handlers.NewRobotHandler(db, recorderHub, transferHub, cfg.Sync.DPConfigPath) - deviceRegistrationHandler = handlers.NewDeviceRegistrationHandler(db) + deviceRegistrationHandler = handlers.NewDeviceRegistrationHandler(db, cfg.Server.CallbackPublicBaseURL) factoryHandler = handlers.NewFactoryHandler(db) dataCollectorHandler = handlers.NewDataCollectorHandler(db) stationHandler = handlers.NewStationHandler(db) From 37bca6909dbdaa6db86246b27cc102ee55f634a0 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Mon, 22 Jun 2026 18:21:48 +0800 Subject: [PATCH 2/6] feat(device): add recorder ws client auth --- internal/api/handlers/axon_rpc.go | 27 +-- internal/api/handlers/device_registration.go | 10 + .../api/handlers/device_registration_test.go | 83 ++++++++ .../recorder_axon_interaction_test.go | 193 +++++++++++++++++- internal/api/handlers/ws_client_auth.go | 118 +++++++++++ .../000007_ws_client_auth_tokens.down.sql | 5 + .../000007_ws_client_auth_tokens.up.sql | 16 ++ 7 files changed, 425 insertions(+), 27 deletions(-) create mode 100644 internal/api/handlers/ws_client_auth.go create mode 100644 internal/storage/database/migrations/000007_ws_client_auth_tokens.down.sql create mode 100644 internal/storage/database/migrations/000007_ws_client_auth_tokens.up.sql diff --git a/internal/api/handlers/axon_rpc.go b/internal/api/handlers/axon_rpc.go index 53fa3c3..9d003a9 100644 --- a/internal/api/handlers/axon_rpc.go +++ b/internal/api/handlers/axon_rpc.go @@ -147,31 +147,8 @@ func (h *RecorderHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request w.WriteHeader(http.StatusBadRequest) return } - - // Validate device exists in robots table (if DB is configured) - if h.db != nil { - // Add independent 5s timeout to avoid blocking on slow DB queries - queryTimeout := 5 * time.Second - queryCtx, cancel := context.WithTimeout(r.Context(), queryTimeout) - defer cancel() - - var count int - // #nosec G701 -- Set aside for now - if err := h.db.GetContext(queryCtx, &count, - "SELECT COUNT(1) FROM robots WHERE device_id = ? AND deleted_at IS NULL", deviceID, - ); err != nil { - if errors.Is(err, context.DeadlineExceeded) || errors.Is(queryCtx.Err(), context.DeadlineExceeded) { - logger.Printf("%s DB query timeout after %s (timeout_ms=%d): %v", recorderLogPrefix(deviceID), timeoutLogValue(queryTimeout), timeoutLogMilliseconds(queryTimeout), err) - } else { - logger.Printf("%s DB query error: %v", recorderLogPrefix(deviceID), err) - } - } - // Check count regardless of DB error (count defaults to 0 on error) - if count == 0 { - logger.Printf("%s robot not found in database", recorderLogPrefix(deviceID)) - w.WriteHeader(http.StatusNotFound) - return - } + if !h.authorizeRecorderWebSocket(w, r, deviceID) { + return } // Allow any origin in dev; tighten in production diff --git a/internal/api/handlers/device_registration.go b/internal/api/handlers/device_registration.go index 06b2be4..e23327c 100644 --- a/internal/api/handlers/device_registration.go +++ b/internal/api/handlers/device_registration.go @@ -52,6 +52,7 @@ type DeviceRegistrationResponse struct { RobotType string `json:"robot_type"` RobotTypeID string `json:"robot_type_id"` RobotID string `json:"robot_id"` + WSClientAuthToken string `json:"ws_client_auth_token"` CallbackAllowlist CallbackAllowlist `json:"callback_allowlist"` } @@ -187,6 +188,14 @@ func (h *DeviceRegistrationHandler) registerDevice(factoryName, robotTypeModel s return DeviceRegistrationResponse{}, fmt.Errorf("get inserted robot id: %w", err) } + wsClientAuthToken, err := generateWSClientAuthToken() + if err != nil { + return DeviceRegistrationResponse{}, fmt.Errorf("generate ws client auth token: %w", err) + } + if err := insertWSClientAuthToken(tx, robotID, wsClientAuthToken, now); err != nil { + return DeviceRegistrationResponse{}, err + } + if err := tx.Commit(); err != nil { return DeviceRegistrationResponse{}, fmt.Errorf("commit transaction: %w", err) } @@ -198,6 +207,7 @@ func (h *DeviceRegistrationHandler) registerDevice(factoryName, robotTypeModel s RobotType: robotType.Model, RobotTypeID: strconv.FormatInt(robotType.ID, 10), RobotID: strconv.FormatInt(robotID, 10), + WSClientAuthToken: wsClientAuthToken, CallbackAllowlist: h.callbackURLs.allowlist(), }, nil } diff --git a/internal/api/handlers/device_registration_test.go b/internal/api/handlers/device_registration_test.go index d7c3ae9..3227bd6 100644 --- a/internal/api/handlers/device_registration_test.go +++ b/internal/api/handlers/device_registration_test.go @@ -6,9 +6,12 @@ package handlers import ( "bytes" + "crypto/sha256" + "encoding/hex" "encoding/json" "net/http" "net/http/httptest" + "strconv" "strings" "testing" @@ -99,6 +102,9 @@ func TestDeviceRegistrationHandlerRegisterDevice_Success(t *testing.T) { if resp.Factory != "上海一厂" || resp.RobotType != "搬运机器人" || resp.RobotID == "" { t.Fatalf("unexpected response fields: %#v", resp) } + if !strings.HasPrefix(resp.WSClientAuthToken, "kws_v1_") { + t.Fatalf("ws_client_auth_token=%q want kws_v1_ prefix", resp.WSClientAuthToken) + } if !isASCII(resp.DeviceID) { t.Fatalf("device_id is not ASCII: %q", resp.DeviceID) } @@ -117,6 +123,33 @@ func TestDeviceRegistrationHandlerRegisterDevice_Success(t *testing.T) { t.Fatalf("robot count=%d want=1", robotCount) } + robotID, err := strconv.ParseInt(resp.RobotID, 10, 64) + if err != nil { + t.Fatalf("parse robot_id: %v", err) + } + tokenHash := sha256.Sum256([]byte(resp.WSClientAuthToken)) + var storedToken struct { + RobotID int64 `db:"robot_id"` + TokenHash string `db:"token_hash"` + TokenVersion string `db:"token_version"` + } + if err := db.Get(&storedToken, ` + SELECT robot_id, token_hash, token_version + FROM ws_client_auth_tokens + WHERE robot_id = ? + `, robotID); err != nil { + t.Fatalf("query ws client token: %v", err) + } + if storedToken.RobotID != robotID || storedToken.TokenVersion != "kws_v1" { + t.Fatalf("unexpected stored token metadata: %#v", storedToken) + } + if storedToken.TokenHash != hex.EncodeToString(tokenHash[:]) { + t.Fatalf("stored token_hash=%q does not match response token", storedToken.TokenHash) + } + if strings.Contains(storedToken.TokenHash, resp.WSClientAuthToken) { + t.Fatalf("stored token hash appears to contain plaintext token") + } + var nextSequence int64 if err := db.Get(&nextSequence, "SELECT next_sequence FROM device_id_sequences WHERE factory_id = 3 AND robot_type_id = 12"); err != nil { t.Fatalf("query next sequence: %v", err) @@ -144,6 +177,12 @@ func TestDeviceRegistrationHandlerRegisterDevice_RepeatedRequestAllocatesNewDevi if first.RobotID == second.RobotID { t.Fatalf("expected distinct robot ids, got %q", first.RobotID) } + if first.WSClientAuthToken == "" || second.WSClientAuthToken == "" { + t.Fatalf("expected non-empty ws client tokens: first=%q second=%q", first.WSClientAuthToken, second.WSClientAuthToken) + } + if first.WSClientAuthToken == second.WSClientAuthToken { + t.Fatalf("expected distinct ws client tokens, got %q", first.WSClientAuthToken) + } var robotCount int if err := db.Get(&robotCount, "SELECT COUNT(*) FROM robots"); err != nil { @@ -152,6 +191,40 @@ func TestDeviceRegistrationHandlerRegisterDevice_RepeatedRequestAllocatesNewDevi if robotCount != 2 { t.Fatalf("robot count=%d want=2", robotCount) } + + var tokenCount int + if err := db.Get(&tokenCount, "SELECT COUNT(*) FROM ws_client_auth_tokens"); err != nil { + t.Fatalf("count ws client tokens: %v", err) + } + if tokenCount != 2 { + t.Fatalf("ws client token count=%d want=2", tokenCount) + } +} + +func TestDeviceRegistrationHandlerRegisterDevice_TokenInsertFailureRollsBackRobot(t *testing.T) { + db := newTestDeviceRegistrationDB(t) + defer db.Close() + seedDeviceRegistrationFixtures(t, db) + if _, err := db.Exec(`DROP TABLE ws_client_auth_tokens`); err != nil { + t.Fatalf("drop ws client token table: %v", err) + } + + router := newTestDeviceRegistrationRouter(t, db) + req := httptest.NewRequest(http.MethodPost, "/api/v1/devices/register", bytes.NewBufferString(`{"factory":"上海一厂","robot_type":"搬运机器人"}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("status=%d want=%d body=%s", w.Code, http.StatusInternalServerError, w.Body.String()) + } + var robotCount int + if err := db.Get(&robotCount, "SELECT COUNT(*) FROM robots"); err != nil { + t.Fatalf("count robots: %v", err) + } + if robotCount != 0 { + t.Fatalf("robot count=%d want=0", robotCount) + } } func TestFormatRegisteredDeviceID_DoesNotTruncateLargeValues(t *testing.T) { @@ -240,6 +313,16 @@ func newTestDeviceRegistrationDB(t *testing.T) *sqlx.DB { updated_at TIMESTAMP, deleted_at TIMESTAMP NULL )`, + `CREATE TABLE ws_client_auth_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + robot_id INTEGER NOT NULL, + token_hash TEXT NOT NULL UNIQUE, + token_version TEXT NOT NULL DEFAULT 'kws_v1', + created_at TIMESTAMP, + last_rotated_at TIMESTAMP NULL, + last_used_at TIMESTAMP NULL, + revoked_at TIMESTAMP NULL + )`, } for _, stmt := range schema { diff --git a/internal/api/handlers/recorder_axon_interaction_test.go b/internal/api/handlers/recorder_axon_interaction_test.go index 2a78660..4e1b659 100644 --- a/internal/api/handlers/recorder_axon_interaction_test.go +++ b/internal/api/handlers/recorder_axon_interaction_test.go @@ -7,6 +7,7 @@ package handlers import ( "bytes" "context" + "database/sql" "encoding/json" "log" "net/http" @@ -794,6 +795,145 @@ func TestOldRecorderDisconnectAfterReplacementDoesNotRevertTasks(t *testing.T) { assertRecorderInteractionTaskStatus(t, db, "task-current-ready", "ready") } +func TestRecorderWebSocketAuthRejectsMissingBearerToken(t *testing.T) { + db := newRecorderInteractionDB(t) + seedRecorderInteractionDevice(t, db, "robot-001", 1, 101) + + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{ResponseTimeout: 1}, db) + wsURL := newRecorderWebSocketTestServer(t, handler, "robot-001") + + _, resp, err := websocket.Dial(context.Background(), wsURL, nil) + if err == nil { + t.Fatalf("dial without bearer token succeeded") + } + if resp == nil || resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("status=%v want=%d err=%v", recorderWebSocketDialStatus(resp), http.StatusUnauthorized, err) + } + if got := resp.Header.Get("WWW-Authenticate"); got != "Bearer" { + t.Fatalf("WWW-Authenticate=%q want Bearer", got) + } +} + +func TestRecorderWebSocketAuthRejectsTokenForDifferentDevice(t *testing.T) { + db := newRecorderInteractionDB(t) + seedRecorderInteractionDevice(t, db, "robot-001", 1, 101) + seedRecorderInteractionDevice(t, db, "robot-002", 2, 102) + + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{ResponseTimeout: 1}, db) + wsURL := newRecorderWebSocketTestServer(t, handler, "robot-002") + + _, resp, err := websocket.Dial(context.Background(), wsURL, recorderWebSocketDialOptions(recorderWSAuthToken("robot-001"))) + if err == nil { + t.Fatalf("dial with another device token succeeded") + } + if resp == nil || resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("status=%v want=%d err=%v", recorderWebSocketDialStatus(resp), http.StatusUnauthorized, err) + } +} + +func TestRecorderWebSocketAuthRejectsInactiveRobot(t *testing.T) { + db := newRecorderInteractionDB(t) + seedRecorderInteractionDevice(t, db, "robot-001", 1, 101) + if _, err := db.Exec(`UPDATE robots SET status = 'inactive' WHERE device_id = 'robot-001'`); err != nil { + t.Fatalf("mark robot inactive: %v", err) + } + + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{ResponseTimeout: 1}, db) + wsURL := newRecorderWebSocketTestServer(t, handler, "robot-001") + + _, resp, err := websocket.Dial(context.Background(), wsURL, recorderWebSocketDialOptions(recorderWSAuthToken("robot-001"))) + if err == nil { + t.Fatalf("dial with inactive robot token succeeded") + } + if resp == nil || resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("status=%v want=%d err=%v", recorderWebSocketDialStatus(resp), http.StatusUnauthorized, err) + } +} + +func TestRecorderWebSocketAuthRejectsRevokedToken(t *testing.T) { + db := newRecorderInteractionDB(t) + seedRecorderInteractionDevice(t, db, "robot-001", 1, 101) + + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{ResponseTimeout: 1}, db) + wsURL := newRecorderWebSocketTestServer(t, handler, "robot-001") + if _, err := db.Exec(` + UPDATE ws_client_auth_tokens + SET revoked_at = ? + WHERE token_hash = ? + `, time.Now().UTC(), hashWSClientAuthToken(recorderWSAuthToken("robot-001"))); err != nil { + t.Fatalf("revoke ws client token: %v", err) + } + + _, resp, err := websocket.Dial(context.Background(), wsURL, recorderWebSocketDialOptions(recorderWSAuthToken("robot-001"))) + if err == nil { + t.Fatalf("dial with revoked token succeeded") + } + if resp == nil || resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("status=%v want=%d err=%v", recorderWebSocketDialStatus(resp), http.StatusUnauthorized, err) + } +} + +func TestRecorderWebSocketAuthRejectsDeletedRobot(t *testing.T) { + db := newRecorderInteractionDB(t) + seedRecorderInteractionDevice(t, db, "robot-001", 1, 101) + if _, err := db.Exec(`UPDATE robots SET deleted_at = ? WHERE device_id = 'robot-001'`, time.Now().UTC()); err != nil { + t.Fatalf("mark robot deleted: %v", err) + } + + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{ResponseTimeout: 1}, db) + wsURL := newRecorderWebSocketTestServer(t, handler, "robot-001") + + _, resp, err := websocket.Dial(context.Background(), wsURL, recorderWebSocketDialOptions(recorderWSAuthToken("robot-001"))) + if err == nil { + t.Fatalf("dial with deleted robot token succeeded") + } + if resp == nil || resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("status=%v want=%d err=%v", recorderWebSocketDialStatus(resp), http.StatusUnauthorized, err) + } +} + +func TestRecorderWebSocketAuthUpdatesLastUsedAt(t *testing.T) { + db := newRecorderInteractionDB(t) + seedRecorderInteractionDevice(t, db, "robot-001", 1, 101) + + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{ResponseTimeout: 1}, db) + wsURL := newRecorderWebSocketTestServer(t, handler, "robot-001") + axon := connectFakeRecorderAxon(t, wsURL) + defer axon.closeNow() + + var lastUsedAt sql.NullTime + if err := db.Get(&lastUsedAt, ` + SELECT last_used_at + FROM ws_client_auth_tokens + WHERE token_hash = ? + `, hashWSClientAuthToken(recorderWSAuthToken("robot-001"))); err != nil { + t.Fatalf("query last_used_at: %v", err) + } + if !lastUsedAt.Valid { + t.Fatalf("last_used_at was not updated after successful websocket auth") + } +} + +func TestRecorderWebSocketAuthDBUnavailableReturnsServiceUnavailable(t *testing.T) { + hub := services.NewRecorderHub() + handler := NewRecorderHandler(hub, &config.RecorderConfig{ResponseTimeout: 1}, nil) + wsURL := newRecorderWebSocketTestServer(t, handler, "robot-001") + + _, resp, err := websocket.Dial(context.Background(), wsURL, recorderWebSocketDialOptions(recorderWSAuthToken("robot-001"))) + if err == nil { + t.Fatalf("dial with nil db succeeded") + } + if resp == nil || resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("status=%v want=%d err=%v", recorderWebSocketDialStatus(resp), http.StatusServiceUnavailable, err) + } +} + func newRecorderInteractionRouter(handler *RecorderHandler) *gin.Engine { gin.SetMode(gin.TestMode) router := gin.New() @@ -869,10 +1009,23 @@ func newRecorderInteractionDB(t *testing.T) *sqlx.DB { if _, err := db.Exec(`CREATE TABLE robots ( id INTEGER PRIMARY KEY, device_id TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'active', deleted_at TIMESTAMP NULL )`); err != nil { t.Fatalf("create robots schema: %v", err) } + if _, err := db.Exec(`CREATE TABLE ws_client_auth_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + robot_id INTEGER NOT NULL, + token_hash TEXT NOT NULL UNIQUE, + token_version TEXT NOT NULL DEFAULT 'kws_v1', + created_at TIMESTAMP, + last_rotated_at TIMESTAMP NULL, + last_used_at TIMESTAMP NULL, + revoked_at TIMESTAMP NULL + )`); err != nil { + t.Fatalf("create ws client token schema: %v", err) + } if _, err := db.Exec(`CREATE TABLE workstations ( id INTEGER PRIMARY KEY, robot_id INTEGER NOT NULL, @@ -901,7 +1054,7 @@ func newRecorderInteractionDB(t *testing.T) *sqlx.DB { func seedRecorderInteractionDevice(t *testing.T, db *sqlx.DB, deviceID string, robotID int64, workstationID int64) { t.Helper() - if _, err := db.Exec(`INSERT INTO robots (id, device_id) VALUES (?, ?)`, robotID, deviceID); err != nil { + if _, err := db.Exec(`INSERT INTO robots (id, device_id, status) VALUES (?, ?, 'active')`, robotID, deviceID); err != nil { t.Fatalf("seed robot: %v", err) } if _, err := db.Exec(`INSERT INTO workstations (id, robot_id) VALUES (?, ?)`, workstationID, robotID); err != nil { @@ -1575,6 +1728,9 @@ type fakeRecorderAxon struct { func newRecorderWebSocketTestServer(t *testing.T, handler *RecorderHandler, deviceID string) string { t.Helper() + if handler.db != nil { + seedRecorderWSClientTokenForDevice(t, handler.db, deviceID) + } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler.HandleWebSocket(w, r, deviceID) })) @@ -1584,7 +1740,7 @@ func newRecorderWebSocketTestServer(t *testing.T, handler *RecorderHandler, devi func connectFakeRecorderAxon(t *testing.T, wsURL string) *fakeRecorderAxon { t.Helper() - conn, _, err := websocket.Dial(context.Background(), wsURL, nil) + conn, _, err := websocket.Dial(context.Background(), wsURL, recorderWebSocketDialOptions(recorderWSAuthToken("robot-001"))) if err != nil { t.Fatalf("dial fake recorder websocket: %v", err) } @@ -1610,6 +1766,39 @@ func connectFakeRecorderAxon(t *testing.T, wsURL string) *fakeRecorderAxon { return axon } +func recorderWebSocketDialOptions(token string) *websocket.DialOptions { + headers := http.Header{} + if token != "" { + headers.Set("Authorization", "Bearer "+token) + } + return &websocket.DialOptions{HTTPHeader: headers} +} + +func recorderWebSocketDialStatus(resp *http.Response) int { + if resp == nil { + return 0 + } + return resp.StatusCode +} + +func recorderWSAuthToken(deviceID string) string { + return "kws_v1_test_token_" + strings.ReplaceAll(deviceID, "-", "_") +} + +func seedRecorderWSClientTokenForDevice(t *testing.T, db *sqlx.DB, deviceID string) { + t.Helper() + var robotID int64 + if err := db.Get(&robotID, `SELECT id FROM robots WHERE device_id = ?`, deviceID); err != nil { + t.Fatalf("query robot id for ws token: %v", err) + } + if _, err := db.Exec(` + INSERT OR IGNORE INTO ws_client_auth_tokens (robot_id, token_hash, token_version, created_at) + VALUES (?, ?, ?, ?) + `, robotID, hashWSClientAuthToken(recorderWSAuthToken(deviceID)), wsClientTokenVersion, time.Now().UTC()); err != nil { + t.Fatalf("seed ws client token: %v", err) + } +} + func (f *fakeRecorderAxon) receiveRPC(t *testing.T, wantAction string) services.RPCRequest { t.Helper() return receiveRecorderRPCRequest(t, f.requests, wantAction) diff --git a/internal/api/handlers/ws_client_auth.go b/internal/api/handlers/ws_client_auth.go new file mode 100644 index 0000000..037e7f3 --- /dev/null +++ b/internal/api/handlers/ws_client_auth.go @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: 2026 ArcheBase +// +// SPDX-License-Identifier: MulanPSL-2.0 + +package handlers + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "archebase.com/keystone-edge/internal/logger" + "github.com/jmoiron/sqlx" +) + +const wsClientTokenVersion = "kws_v1" + +func generateWSClientAuthToken() (string, error) { + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return "", fmt.Errorf("read random bytes: %w", err) + } + return wsClientTokenVersion + "_" + base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +func hashWSClientAuthToken(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:]) +} + +func insertWSClientAuthToken(tx *sqlx.Tx, robotID int64, token string, now time.Time) error { + if _, err := tx.Exec(` + INSERT INTO ws_client_auth_tokens ( + robot_id, + token_hash, + token_version, + created_at + ) VALUES (?, ?, ?, ?) + `, robotID, hashWSClientAuthToken(token), wsClientTokenVersion, now); err != nil { + return fmt.Errorf("insert ws client auth token: %w", err) + } + return nil +} + +func (h *RecorderHandler) authorizeRecorderWebSocket(w http.ResponseWriter, r *http.Request, deviceID string) bool { + if h.db == nil { + writeRecorderWebSocketAuthError(w, http.StatusServiceUnavailable, "service unavailable", false) + return false + } + + token, ok := parseBearerToken(r.Header.Get("Authorization")) + if !ok { + writeRecorderWebSocketAuthError(w, http.StatusUnauthorized, "unauthorized", true) + return false + } + + queryTimeout := 5 * time.Second + queryCtx, cancel := context.WithTimeout(r.Context(), queryTimeout) + defer cancel() + + var tokenID int64 + if err := h.db.GetContext(queryCtx, &tokenID, ` + SELECT t.id + FROM ws_client_auth_tokens t + JOIN robots r ON r.id = t.robot_id + WHERE r.device_id = ? + AND t.token_hash = ? + AND t.revoked_at IS NULL + AND r.status = 'active' + AND r.deleted_at IS NULL + LIMIT 1 + `, deviceID, hashWSClientAuthToken(token)); err != nil { + if errors.Is(err, sql.ErrNoRows) { + writeRecorderWebSocketAuthError(w, http.StatusUnauthorized, "unauthorized", true) + return false + } + logger.Printf("%s ws client auth query error: %v", recorderLogPrefix(deviceID), err) + writeRecorderWebSocketAuthError(w, http.StatusServiceUnavailable, "service unavailable", false) + return false + } + + if _, err := h.db.ExecContext(queryCtx, ` + UPDATE ws_client_auth_tokens + SET last_used_at = ? + WHERE id = ? + `, time.Now().UTC(), tokenID); err != nil { + logger.Printf("%s ws client auth last_used_at update failed: %v", recorderLogPrefix(deviceID), err) + } + + return true +} + +func parseBearerToken(header string) (string, bool) { + parts := strings.Fields(header) + if len(parts) != 2 || parts[0] != "Bearer" || strings.TrimSpace(parts[1]) == "" { + return "", false + } + return parts[1], true +} + +func writeRecorderWebSocketAuthError(w http.ResponseWriter, status int, message string, challenge bool) { + w.Header().Set("Content-Type", "application/json") + if challenge { + w.Header().Set("WWW-Authenticate", "Bearer") + } + w.WriteHeader(status) + if _, err := fmt.Fprintf(w, `{"error":%q}`, message); err != nil { + logger.Printf("[RECORDER] write websocket auth error failed: %v", err) + } +} diff --git a/internal/storage/database/migrations/000007_ws_client_auth_tokens.down.sql b/internal/storage/database/migrations/000007_ws_client_auth_tokens.down.sql new file mode 100644 index 0000000..1469269 --- /dev/null +++ b/internal/storage/database/migrations/000007_ws_client_auth_tokens.down.sql @@ -0,0 +1,5 @@ +-- SPDX-FileCopyrightText: 2026 ArcheBase +-- +-- SPDX-License-Identifier: MulanPSL-2.0 + +DROP TABLE IF EXISTS ws_client_auth_tokens; diff --git a/internal/storage/database/migrations/000007_ws_client_auth_tokens.up.sql b/internal/storage/database/migrations/000007_ws_client_auth_tokens.up.sql new file mode 100644 index 0000000..5a91c7e --- /dev/null +++ b/internal/storage/database/migrations/000007_ws_client_auth_tokens.up.sql @@ -0,0 +1,16 @@ +-- SPDX-FileCopyrightText: 2026 ArcheBase +-- +-- SPDX-License-Identifier: MulanPSL-2.0 + +CREATE TABLE IF NOT EXISTS ws_client_auth_tokens ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + robot_id BIGINT NOT NULL, + token_hash CHAR(64) NOT NULL, + token_version VARCHAR(16) NOT NULL DEFAULT 'kws_v1', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_rotated_at TIMESTAMP NULL, + last_used_at TIMESTAMP NULL, + revoked_at TIMESTAMP NULL, + UNIQUE INDEX idx_ws_client_token_hash (token_hash), + INDEX idx_ws_client_robot_active (robot_id, revoked_at) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; From 35565339315719db07c401497e7605ffc017212d Mon Sep 17 00:00:00 2001 From: chaoliu Date: Tue, 23 Jun 2026 17:10:57 +0800 Subject: [PATCH 3/6] feat(device): add ws client token rotation --- docs/designs/device-registration-api.md | 105 ++++++++- internal/api/handlers/device_registration.go | 114 ++++++++++ .../api/handlers/device_registration_test.go | 206 +++++++++++++++++- internal/server/server.go | 2 + 4 files changed, 417 insertions(+), 10 deletions(-) diff --git a/docs/designs/device-registration-api.md b/docs/designs/device-registration-api.md index 2699255..77b9677 100644 --- a/docs/designs/device-registration-api.md +++ b/docs/designs/device-registration-api.md @@ -204,10 +204,85 @@ The table intentionally does not define a database foreign key. Existing Keyston style keeps these relationships application-managed, and this avoids introducing migration ordering and SQLite fixture complexity. -This version does not provide a rotate endpoint. `last_rotated_at` and `revoked_at` are -reserved for a future explicit token rotation or revocation flow. +`last_rotated_at` and `revoked_at` are used by the explicit token rotation flow described +below. -## 9. Recorder WebSocket Authentication +## 9. WebSocket Client Token Rotation + +Keystone provides an admin-only token rotation endpoint for cases where the robot-side +plaintext token was lost, exposed, or needs to be replaced during maintenance. + +| Method | Path | Auth | Caller | +|------|------|------|------| +| POST | `/api/v1/robots/:id/ws-client-auth-token/rotate` | Admin JWT | Synapse admin or operator-run admin tool | + +The endpoint uses `robots.id` from the path, not `device_id`, matching the existing robot +management API style. + +### 9.1 Rotation Response + +Success returns `200 OK`. + +```json +{ + "device_id": "AB-F0001-T0003-000001", + "robot_id": "9", + "ws_client_auth_token": "kws_v1_3Z2iX5lFh7mYxLQd9P0sAqzF2Z3w4R5t6U7v8W9x0Y", + "rotated_at": "2026-06-23T10:15:30Z" +} +``` + +| Field | Meaning | +|------|------| +| `device_id` | Current `robots.device_id` for the rotated robot | +| `robot_id` | `robots.id`, encoded as a string for existing API style | +| `ws_client_auth_token` | New one-time plaintext token for Axon recorder WebSocket client authentication | +| `rotated_at` | UTC timestamp used for revoking old active tokens and creating the new token | + +The response must not include old token hashes or revoked token rows. The plaintext token +is returned only once and must not be logged. + +### 9.2 Rotation Behavior + +Rotation is transactional: + +1. Lock and validate the target `robots` row. +2. Reject soft-deleted robots as not found. +3. Reject robots whose `status` is not `active`. +4. Generate a new `kws_v1_` token. +5. Set `revoked_at = rotated_at` and `last_rotated_at = rotated_at` on every active + `ws_client_auth_tokens` row for the robot. +6. Insert one new active `ws_client_auth_tokens` row with the SHA-256 hash of the new + plaintext token. +7. Commit and return the new plaintext token once. + +If the robot currently has no active token, rotation still succeeds and inserts a new +active token. This supports recovery from failed manual cleanup or missing seed data. + +Old tokens stop authenticating new WebSocket handshakes immediately after the transaction +commits. Keystone does not proactively close an already-established recorder WebSocket +connection; that connection will need the new token after it disconnects and reconnects. + +### 9.3 Rotation Error Responses + +Errors follow Keystone's usual JSON shape: + +```json +{ + "error": "robot not found" +} +``` + +| Status | Condition | Error | +|------|------|------| +| `400` | Path `id` is not a valid positive integer | `invalid robot id` | +| `400` | Robot exists but `status` is not `active` | `robot is not active` | +| `401` | Missing, expired, or invalid admin JWT | Existing auth middleware response | +| `403` | Authenticated user is not an admin | Existing role middleware response | +| `404` | No non-deleted robot with matching `robots.id` | `robot not found` | +| `500` | Token generation, token update, token insert, or transaction failure | `failed to rotate ws client auth token` | + +## 10. Recorder WebSocket Authentication This token is required only for the Axon recorder WebSocket in this implementation. Axon transfer WebSocket remains unchanged because Axon transfer does not currently send a @@ -277,7 +352,7 @@ Logs may include `device_id` and a broad reason such as `missing bearer token` o `invalid token`, but must not include token plaintext or distinguish "not found" from "belongs to another device". -## 10. Concurrency +## 11. Concurrency Concurrent requests with the same `factory` and `robot_type` are supported. Keystone uses `device_id_sequences` to serialize allocation for each `(factory_id, robot_type_id)` pair. @@ -305,7 +380,12 @@ The selected `next_sequence` is used in `device_id`, then Keystone increments Token insertion is part of the same transaction. If inserting the token row fails, the robot insert and device sequence increment are rolled back with the transaction. -## 11. Install Script Usage +Token rotation also runs in a transaction. Concurrent rotations for the same robot must +serialize on the `robots` row or token rows so only the final committed response contains +the active token. A client should treat any earlier concurrent rotation response as stale if +another successful rotation completes later. + +## 12. Install Script Usage Example: @@ -330,14 +410,14 @@ Expected script behavior: `ws_client_auth_token_file`, it writes the token to `/var/lib/axon/secrets/ws_client.token` or the path supplied by `--ws-client-token-file`. -## 12. Implementation Notes +## 13. Implementation Notes Implementation files: | File | Purpose | |------|------| -| `internal/api/handlers/device_registration.go` | Request validation, transaction, sequence allocation, robot insertion | -| `internal/api/handlers/ws_client_auth.go` | Token generation, hashing, storage, and recorder WebSocket validation | +| `internal/api/handlers/device_registration.go` | Request validation, registration transaction, sequence allocation, robot insertion, token rotation response handling | +| `internal/api/handlers/ws_client_auth.go` | Token generation, hashing, storage, rotation helpers, and recorder WebSocket validation | | `internal/server/server.go` | Handler construction and route registration | | `internal/storage/database/migrations/000002_device_id_sequences.up.sql` | Sequence table migration | | `internal/storage/database/migrations/000002_device_id_sequences.down.sql` | Sequence table rollback | @@ -356,18 +436,25 @@ TDD coverage should include: 6. Recorder WebSocket with correct token and device ID connects. 7. Token for robot A cannot connect as robot B. 8. Deleted or non-active robot cannot authenticate. +9. Rotation returns a new `ws_client_auth_token` with `kws_v1_` prefix. +10. Rotation revokes prior active tokens and stores only the new token hash. +11. Rotation succeeds when the robot has no active token. +12. Rotation returns `404` for missing or soft-deleted robots. +13. Rotation returns `400` for non-active robots. +14. Old token fails and new token succeeds on recorder WebSocket authentication after + rotation. Out of scope for this implementation: - Axon transfer WebSocket token authentication. - Token query parameters, `X-API-Key`, or `Sec-WebSocket-Protocol`. -- Token rotation API. - Register endpoint authentication. Validation performed during implementation: ```bash go test ./internal/api/handlers -run 'TestDeviceRegistration|TestFormatRegisteredDeviceID|TestDeviceRegistrationRoutes' -v +go test ./internal/api/handlers -run 'TestDeviceRegistrationHandlerRotateWSClientAuthToken' -v go test ./... ``` diff --git a/internal/api/handlers/device_registration.go b/internal/api/handlers/device_registration.go index e23327c..6fb6276 100644 --- a/internal/api/handlers/device_registration.go +++ b/internal/api/handlers/device_registration.go @@ -22,6 +22,8 @@ import ( var ( errRegistrationFactoryNotFound = errors.New("factory not found") errRegistrationRobotTypeNotFound = errors.New("robot_type not found") + errRegistrationRobotNotFound = errors.New("robot not found") + errRegistrationRobotNotActive = errors.New("robot is not active") ) // DeviceRegistrationHandler handles install-time device registration requests. @@ -56,6 +58,14 @@ type DeviceRegistrationResponse struct { CallbackAllowlist CallbackAllowlist `json:"callback_allowlist"` } +// RotateWSClientAuthTokenResponse represents a successful token rotation. +type RotateWSClientAuthTokenResponse struct { + DeviceID string `json:"device_id"` + RobotID string `json:"robot_id"` + WSClientAuthToken string `json:"ws_client_auth_token"` + RotatedAt string `json:"rotated_at"` +} + type deviceRegistrationFactoryRow struct { ID int64 `db:"id"` Name string `db:"name"` @@ -71,6 +81,11 @@ func (h *DeviceRegistrationHandler) RegisterRoutes(apiV1 *gin.RouterGroup) { apiV1.POST("/devices/register", h.RegisterDevice) } +// RegisterAdminRoutes registers admin-only device credential routes. +func (h *DeviceRegistrationHandler) RegisterAdminRoutes(apiV1 *gin.RouterGroup) { + apiV1.POST("/robots/:id/ws-client-auth-token/rotate", h.RotateWSClientAuthToken) +} + // RegisterDevice handles install-time robot device registration. // // @Summary Register device @@ -120,6 +135,42 @@ func (h *DeviceRegistrationHandler) RegisterDevice(c *gin.Context) { c.JSON(http.StatusCreated, resp) } +// RotateWSClientAuthToken handles admin-triggered recorder WebSocket token rotation. +// +// @Summary Rotate recorder WebSocket client token +// @Description Revokes active recorder WebSocket client tokens for one robot and returns a new plaintext token once +// @Tags robots +// @Produce json +// @Param id path int true "Robot ID" +// @Success 200 {object} RotateWSClientAuthTokenResponse +// @Failure 400 {object} map[string]string +// @Failure 404 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /robots/{id}/ws-client-auth-token/rotate [post] +func (h *DeviceRegistrationHandler) RotateWSClientAuthToken(c *gin.Context) { + robotID, err := strconv.ParseInt(strings.TrimSpace(c.Param("id")), 10, 64) + if err != nil || robotID <= 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid robot id"}) + return + } + + resp, err := h.rotateWSClientAuthToken(robotID) + if err != nil { + switch { + case errors.Is(err, errRegistrationRobotNotFound): + c.JSON(http.StatusNotFound, gin.H{"error": "robot not found"}) + case errors.Is(err, errRegistrationRobotNotActive): + c.JSON(http.StatusBadRequest, gin.H{"error": "robot is not active"}) + default: + logger.Printf("[DEVICE] Failed to rotate ws client auth token: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to rotate ws client auth token"}) + } + return + } + + c.JSON(http.StatusOK, resp) +} + func (h *DeviceRegistrationHandler) registerDevice(factoryName, robotTypeModel string) (DeviceRegistrationResponse, error) { tx, err := h.db.Beginx() if err != nil { @@ -212,6 +263,69 @@ func (h *DeviceRegistrationHandler) registerDevice(factoryName, robotTypeModel s }, nil } +func (h *DeviceRegistrationHandler) rotateWSClientAuthToken(robotID int64) (RotateWSClientAuthTokenResponse, error) { + tx, err := h.db.Beginx() + if err != nil { + return RotateWSClientAuthTokenResponse{}, fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() //nolint:errcheck // Safe after successful Commit. + + type robotTokenRotationRow struct { + ID int64 `db:"id"` + DeviceID string `db:"device_id"` + Status string `db:"status"` + } + + query := ` + SELECT id, device_id, status + FROM robots + WHERE id = ? AND deleted_at IS NULL + ` + if tx.DriverName() != "sqlite" { + query += " FOR UPDATE" + } + + var robot robotTokenRotationRow + if err := tx.Get(&robot, query, robotID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return RotateWSClientAuthTokenResponse{}, errRegistrationRobotNotFound + } + return RotateWSClientAuthTokenResponse{}, fmt.Errorf("query robot: %w", err) + } + if robot.Status != "active" { + return RotateWSClientAuthTokenResponse{}, errRegistrationRobotNotActive + } + + rotatedAt := time.Now().UTC() + token, err := generateWSClientAuthToken() + if err != nil { + return RotateWSClientAuthTokenResponse{}, fmt.Errorf("generate ws client auth token: %w", err) + } + + if _, err := tx.Exec(` + UPDATE ws_client_auth_tokens + SET revoked_at = ?, last_rotated_at = ? + WHERE robot_id = ? AND revoked_at IS NULL + `, rotatedAt, rotatedAt, robot.ID); err != nil { + return RotateWSClientAuthTokenResponse{}, fmt.Errorf("revoke active ws client auth tokens: %w", err) + } + + if err := insertWSClientAuthToken(tx, robot.ID, token, rotatedAt); err != nil { + return RotateWSClientAuthTokenResponse{}, err + } + + if err := tx.Commit(); err != nil { + return RotateWSClientAuthTokenResponse{}, fmt.Errorf("commit transaction: %w", err) + } + + return RotateWSClientAuthTokenResponse{ + DeviceID: robot.DeviceID, + RobotID: strconv.FormatInt(robot.ID, 10), + WSClientAuthToken: token, + RotatedAt: rotatedAt.Format(time.RFC3339), + }, nil +} + func allocateDeviceIDSequence(tx *sqlx.Tx, factoryID, robotTypeID int64) (int64, error) { if tx.DriverName() == "sqlite" { return allocateDeviceIDSequenceSQLite(tx, factoryID, robotTypeID) diff --git a/internal/api/handlers/device_registration_test.go b/internal/api/handlers/device_registration_test.go index 3227bd6..daacbaf 100644 --- a/internal/api/handlers/device_registration_test.go +++ b/internal/api/handlers/device_registration_test.go @@ -227,6 +227,207 @@ func TestDeviceRegistrationHandlerRegisterDevice_TokenInsertFailureRollsBackRobo } } +func TestDeviceRegistrationHandlerRotateWSClientAuthToken_SuccessRevokesOldToken(t *testing.T) { + db := newTestDeviceRegistrationDB(t) + defer db.Close() + seedDeviceRegistrationFixtures(t, db) + + router := newTestDeviceRegistrationRouter(t, db) + registered := registerTestDevice(t, router) + robotID, err := strconv.ParseInt(registered.RobotID, 10, 64) + if err != nil { + t.Fatalf("parse robot_id: %v", err) + } + oldHashBytes := sha256.Sum256([]byte(registered.WSClientAuthToken)) + oldHash := hex.EncodeToString(oldHashBytes[:]) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/robots/"+registered.RobotID+"/ws-client-auth-token/rotate", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d want=%d body=%s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp struct { + DeviceID string `json:"device_id"` + RobotID string `json:"robot_id"` + WSClientAuthToken string `json:"ws_client_auth_token"` + RotatedAt string `json:"rotated_at"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, w.Body.String()) + } + if resp.DeviceID != registered.DeviceID || resp.RobotID != registered.RobotID { + t.Fatalf("unexpected rotate response identity: %#v", resp) + } + if !strings.HasPrefix(resp.WSClientAuthToken, "kws_v1_") { + t.Fatalf("ws_client_auth_token=%q want kws_v1_ prefix", resp.WSClientAuthToken) + } + if resp.WSClientAuthToken == registered.WSClientAuthToken { + t.Fatalf("rotated token should differ from old token") + } + if strings.TrimSpace(resp.RotatedAt) == "" { + t.Fatalf("rotated_at is empty") + } + + var revokedOldCount int + if err := db.Get(&revokedOldCount, ` + SELECT COUNT(*) + FROM ws_client_auth_tokens + WHERE robot_id = ? AND token_hash = ? AND revoked_at IS NOT NULL AND last_rotated_at IS NOT NULL + `, robotID, oldHash); err != nil { + t.Fatalf("count revoked old token: %v", err) + } + if revokedOldCount != 1 { + t.Fatalf("revoked old token count=%d want=1", revokedOldCount) + } + + newHashBytes := sha256.Sum256([]byte(resp.WSClientAuthToken)) + newHash := hex.EncodeToString(newHashBytes[:]) + var activeTokenHash string + if err := db.Get(&activeTokenHash, ` + SELECT token_hash + FROM ws_client_auth_tokens + WHERE robot_id = ? AND revoked_at IS NULL + `, robotID); err != nil { + t.Fatalf("query active token hash: %v", err) + } + if activeTokenHash != newHash { + t.Fatalf("active token hash=%q does not match rotated token", activeTokenHash) + } + if strings.Contains(activeTokenHash, resp.WSClientAuthToken) { + t.Fatalf("stored token hash appears to contain plaintext token") + } +} + +func TestDeviceRegistrationHandlerRotateWSClientAuthToken_SucceedsWithoutActiveToken(t *testing.T) { + db := newTestDeviceRegistrationDB(t) + defer db.Close() + seedDeviceRegistrationFixtures(t, db) + + router := newTestDeviceRegistrationRouter(t, db) + registered := registerTestDevice(t, router) + if _, err := db.Exec(` + UPDATE ws_client_auth_tokens + SET revoked_at = ?, last_rotated_at = ? + WHERE robot_id = ? + `, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z", registered.RobotID); err != nil { + t.Fatalf("revoke seeded token: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/api/v1/robots/"+registered.RobotID+"/ws-client-auth-token/rotate", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d want=%d body=%s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp RotateWSClientAuthTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, w.Body.String()) + } + if !strings.HasPrefix(resp.WSClientAuthToken, "kws_v1_") { + t.Fatalf("ws_client_auth_token=%q want kws_v1_ prefix", resp.WSClientAuthToken) + } + + var activeTokenCount int + if err := db.Get(&activeTokenCount, ` + SELECT COUNT(*) + FROM ws_client_auth_tokens + WHERE robot_id = ? AND revoked_at IS NULL + `, registered.RobotID); err != nil { + t.Fatalf("count active tokens: %v", err) + } + if activeTokenCount != 1 { + t.Fatalf("active token count=%d want=1", activeTokenCount) + } +} + +func TestDeviceRegistrationHandlerRotateWSClientAuthToken_RobotNotFound(t *testing.T) { + db := newTestDeviceRegistrationDB(t) + defer db.Close() + seedDeviceRegistrationFixtures(t, db) + if _, err := db.Exec(` + INSERT INTO robots (id, robot_type_id, device_id, factory_id, status, deleted_at) + VALUES (99, 12, 'deleted-device', 3, 'active', '2026-01-01T00:00:00Z') + `); err != nil { + t.Fatalf("seed deleted robot: %v", err) + } + + router := newTestDeviceRegistrationRouter(t, db) + for _, path := range []string{ + "/api/v1/robots/42/ws-client-auth-token/rotate", + "/api/v1/robots/99/ws-client-auth-token/rotate", + } { + req := httptest.NewRequest(http.MethodPost, path, nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("%s status=%d want=%d body=%s", path, w.Code, http.StatusNotFound, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "robot not found") { + t.Fatalf("%s unexpected error response: %s", path, w.Body.String()) + } + } +} + +func TestDeviceRegistrationHandlerRotateWSClientAuthToken_RobotNotActive(t *testing.T) { + db := newTestDeviceRegistrationDB(t) + defer db.Close() + seedDeviceRegistrationFixtures(t, db) + if _, err := db.Exec(` + INSERT INTO robots (id, robot_type_id, device_id, factory_id, status) + VALUES (88, 12, 'maintenance-device', 3, 'maintenance') + `); err != nil { + t.Fatalf("seed maintenance robot: %v", err) + } + + router := newTestDeviceRegistrationRouter(t, db) + req := httptest.NewRequest(http.MethodPost, "/api/v1/robots/88/ws-client-auth-token/rotate", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d want=%d body=%s", w.Code, http.StatusBadRequest, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "robot is not active") { + t.Fatalf("unexpected error response: %s", w.Body.String()) + } + + var tokenCount int + if err := db.Get(&tokenCount, "SELECT COUNT(*) FROM ws_client_auth_tokens WHERE robot_id = 88"); err != nil { + t.Fatalf("count tokens: %v", err) + } + if tokenCount != 0 { + t.Fatalf("token count=%d want=0", tokenCount) + } +} + +func TestDeviceRegistrationHandlerRotateWSClientAuthToken_InvalidRobotID(t *testing.T) { + db := newTestDeviceRegistrationDB(t) + defer db.Close() + + router := newTestDeviceRegistrationRouter(t, db) + for _, path := range []string{ + "/api/v1/robots/not-a-number/ws-client-auth-token/rotate", + "/api/v1/robots/0/ws-client-auth-token/rotate", + } { + req := httptest.NewRequest(http.MethodPost, path, nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("%s status=%d want=%d body=%s", path, w.Code, http.StatusBadRequest, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "invalid robot id") { + t.Fatalf("%s unexpected error response: %s", path, w.Body.String()) + } + } +} + func TestFormatRegisteredDeviceID_DoesNotTruncateLargeValues(t *testing.T) { got := formatRegisteredDeviceID(12345, 98765, 1234567) want := "AB-F12345-T98765-1234567" @@ -241,7 +442,9 @@ func TestDeviceRegistrationRoutes_DoNotConflictWithRobotDeviceRoutes(t *testing. v1 := router.Group("/api/v1") NewRobotHandler(nil, nil, nil).RegisterRoutes(v1) - NewDeviceRegistrationHandler(nil, "http://192.168.1.20:9999").RegisterRoutes(v1) + handler := NewDeviceRegistrationHandler(nil, "http://192.168.1.20:9999") + handler.RegisterRoutes(v1) + handler.RegisterAdminRoutes(v1) } func registerTestDevice(t *testing.T, router *gin.Engine) DeviceRegistrationResponse { @@ -270,6 +473,7 @@ func newTestDeviceRegistrationRouter(t *testing.T, db *sqlx.DB) *gin.Engine { handler := NewDeviceRegistrationHandler(db, "http://192.168.1.20:9999") v1 := router.Group("/api/v1") handler.RegisterRoutes(v1) + handler.RegisterAdminRoutes(v1) return router } diff --git a/internal/server/server.go b/internal/server/server.go index c137788..2acefec 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -298,6 +298,8 @@ func (s *Server) buildRoutes() http.Handler { } if s.deviceRegistration != nil { s.deviceRegistration.RegisterRoutes(v1Tasks) + adminDeviceCredentials := v1Routes.Group("", middleware.JWTAuth(&s.cfg.Auth), middleware.RequireRole("admin")) + s.deviceRegistration.RegisterAdminRoutes(adminDeviceCredentials) } if s.factory != nil { s.factory.RegisterRoutes(v1Tasks) From 7c638ad9853d68f09f23b7cc7d06b253d8b930bb Mon Sep 17 00:00:00 2001 From: chaoliu Date: Wed, 24 Jun 2026 15:32:57 +0800 Subject: [PATCH 4/6] feat(transfer): persist recorder writer health --- .../designs/episode-recorder-writer-health.md | 285 ++++++++++++++++++ internal/api/handlers/episode.go | 6 +- .../api/handlers/episode_metadata_test.go | 190 ++++++++++++ internal/api/handlers/transfer.go | 78 ++++- .../transfer_asset_id_snapshot_test.go | 111 +++++++ 5 files changed, 664 insertions(+), 6 deletions(-) create mode 100644 docs/designs/episode-recorder-writer-health.md create mode 100644 internal/api/handlers/episode_metadata_test.go diff --git a/docs/designs/episode-recorder-writer-health.md b/docs/designs/episode-recorder-writer-health.md new file mode 100644 index 0000000..a905c61 --- /dev/null +++ b/docs/designs/episode-recorder-writer-health.md @@ -0,0 +1,285 @@ + + +# Episode Recorder Writer Health Design + +**Status:** proposed + +**Scope:** Keystone transfer upload completion, episode metadata, episode detail API, and Synapse episode detail UI. + +## 1. Problem + +Axon Recorder now writes a `writer_health` summary into the recording sidecar JSON. +The summary describes whether the recorder write path had queue pressure, writer +stall, queue overflow, or partial write failures during a recording. + +Keystone already reads the sidecar JSON during transfer `upload_complete` to +populate episode fields such as duration, file size, and checksum. Keystone +should also persist the sidecar `writer_health` summary on the episode so +Synapse can show recording diagnostics on the episode detail page. + +This diagnostic is not a QA result. It is a recorder-side write-path signal that +helps users understand whether the captured data may need extra review. + +## 2. Current Behavior + +- Axon writes top-level `writer_health` into sidecar JSON when a recording is + finished and sidecar generation is enabled. +- Keystone `POST /api/v1/callbacks/finish` marks the task as `uploading` and + sends an upload request to `axon_transfer`. +- Keystone creates the `episodes` row later, when `axon_transfer` reports + `upload_complete` and Keystone verifies that both MCAP and sidecar objects are + present in MinIO. +- Keystone reads sidecar JSON at that point to extract: + - `recording.duration_sec` + - `recording.file_size_bytes` + - `recording.checksum_sha256` +- Episode metadata currently stores an `asset_id` snapshot when it can be + resolved from the workstation's robot. +- Synapse episode detail currently does not render episode metadata and has no + recording diagnostics panel. + +## 3. Decisions + +### 3.1 Source of Truth + +Keystone should read `writer_health` from the uploaded sidecar JSON, not from +the recorder finish callback. + +Rationale: + +- It matches the existing flow for duration, file size, and checksum. +- It avoids adding finish callback persistence before an episode row exists. +- It keeps the episode diagnostic tied to the uploaded recording artifact. +- Finish callback handling stays focused on task lifecycle and upload request + dispatch. + +### 3.2 Episode Metadata Shape + +Persist the recorder diagnostic under a namespaced metadata path: + +```json +{ + "asset_id": "robot-asset-001", + "recorder": { + "writer_health": { + "state": "critical", + "writer_stall_state": "critical", + "writer_stall_suspected": true, + "writer_partial_failures": 2, + "writer_queue_overflows": 1, + "error": "writer_partial_failures=2" + } + } +} +``` + +Rules: + +- Preserve existing metadata fields such as `asset_id`. +- Preserve existing `metadata.recorder` fields. +- Only overwrite `metadata.recorder.writer_health` when a new sidecar + `writer_health` object is present. +- Do not store `writer_health` at metadata top level. +- Keep `error: null` when the sidecar explicitly contains no error. The frontend + should hide an empty or null diagnostic message. + +### 3.3 Sidecar Parsing Rules + +Keystone should extend its sidecar parser with a top-level optional +`writer_health` object. + +The expected Axon sidecar shape is: + +```json +{ + "writer_health": { + "state": "normal", + "writer_stall_state": "normal", + "writer_stall_suspected": false, + "writer_partial_failures": 0, + "writer_queue_overflows": 0, + "error": null + } +} +``` + +Rules: + +- If top-level `writer_health` exists, persist it even when the state is + `normal` and all counters are zero. +- If top-level `writer_health` is absent, do not write or clear + `metadata.recorder.writer_health`. +- If sidecar reading or unmarshalling fails, keep the existing best-effort + behavior and do not block upload completion. +- Backend should preserve the sidecar values rather than rejecting or rewriting + unknown state strings. + +### 3.4 Idempotency + +Keystone should support both creation-time persistence and later idempotent +repair. + +| Case | Behavior | +| --- | --- | +| New episode, sidecar has `writer_health` | Insert `metadata.recorder.writer_health`. | +| New episode, sidecar lacks `writer_health` | Insert existing metadata only, such as `asset_id`. | +| Episode already exists, sidecar has `writer_health` | Merge and overwrite only `metadata.recorder.writer_health`. | +| Episode already exists, sidecar lacks `writer_health` | Leave metadata unchanged. | +| Sidecar read fails | Leave metadata unchanged and continue existing flow. | + +This allows repeated `upload_complete` handling to backfill or correct writer +health without duplicating episodes or destroying unrelated metadata. + +### 3.5 Database Model + +Do not add a dedicated database column in this change. + +Rationale: + +- The current requirement is detail-page display, not filtering, sorting, or + statistics. +- `episodes.metadata` already exists and is appropriate for artifact-derived + optional diagnostics. +- A future list filter can add a dedicated `writer_health_state` column or JSON + index when the product needs it. + +### 3.6 Episode API + +`GET /api/v1/episodes/:id` should return `metadata`. + +`GET /api/v1/episodes` should not return `metadata`. + +Rationale: + +- Only the episode detail page needs this diagnostic. +- List payloads should remain compact. +- Future list filtering should use explicit fields or query parameters rather + than returning complete metadata for every row. + +## 4. Writer Health Field Meaning + +| Field | Meaning | UI Use | +| --- | --- | --- | +| `state` | Overall recorder write-path health: `normal`, `warning`, `critical`, or a future value. | Main panel status. | +| `writer_stall_state` | Health of MCAP writer latency/stall detection. | Detail row. | +| `writer_stall_suspected` | Whether Axon suspects writer stalls occurred. | Detail row as yes/no. | +| `writer_partial_failures` | Count of partial write failures observed by the writer path. | High-risk counter. | +| `writer_queue_overflows` | Count of writer queue overflows. | High-risk counter. | +| `error` | Human-readable diagnostic summary, often `writer_partial_failures=N`. | Show only when non-empty. | + +Recommended interpretation: + +- `writer_partial_failures > 0` is a serious signal that some messages may not + have been written successfully. +- `writer_queue_overflows > 0` is a serious signal that the writer queue could + not keep up. +- `writer_stall_suspected = true` indicates write-path latency or blocking risk. +- `state` is the user-facing summary. +- `error` is supplemental text, not the primary status source. + +## 5. Synapse Episode Detail UI + +Synapse should add a standalone recording diagnostics panel on the admin episode +detail page. + +Placement: + +1. Identity panel +2. Quality check panel +3. Recording diagnostics panel +4. Cloud sync panel +5. File path panel + +Rules: + +- Read `episode.metadata?.recorder?.writer_health`. +- If `writer_health` is missing, do not render the panel. +- If `writer_health` exists with `state = normal`, render the panel with a light + normal treatment. +- If state is `warning`, render a warning treatment. +- If state is `critical`, render an error treatment. +- If state is present but unknown, render a neutral "unknown" treatment. +- Do not render the full metadata JSON. +- Do not change QA status or trigger QA failure based on writer health in this + change. + +Suggested copy: + +| State | Title | Description | +| --- | --- | --- | +| `normal` | `录制写入正常` | `未发现写入队列溢出或部分写入失败。` | +| `warning` | `录制写入告警` | `录制写入链路存在压力信号,建议关注。` | +| `critical` | `录制写入异常` | `这段录制的写入链路存在风险,建议复核 MCAP 完整性。` | +| unknown | `录制写入状态未知` | `录制诊断状态无法识别。` | + +Suggested detail rows: + +| Label | Source | +| --- | --- | +| `总体状态` | `writer_health.state` | +| `写入阻塞` | `writer_health.writer_stall_state` | +| `疑似阻塞` | `writer_health.writer_stall_suspected` | +| `部分失败` | `writer_health.writer_partial_failures` | +| `队列溢出` | `writer_health.writer_queue_overflows` | +| `诊断说明` | `writer_health.error`, only when non-empty | + +## 6. Non-Goals + +- Do not read or persist finish callback `writer_health`. +- Do not add a new database column. +- Do not add list-page display or filters. +- Do not render full episode metadata JSON in Synapse. +- Do not change QA status. +- Do not auto-fail QA based on writer health. +- Do not block upload completion when sidecar diagnostics are missing or + malformed. + +## 7. Backend Implementation Notes + +Recommended helper responsibilities: + +- Parse sidecar `writer_health` as a raw JSON object or a typed struct that + preserves the known fields and `error`. +- Build episode metadata from existing metadata plus optional + `recorder.writer_health`. +- Merge metadata idempotently: + - parse existing metadata into `map[string]any`; + - ensure `recorder` is a `map[string]any`; + - set `recorder["writer_health"]`; + - marshal back to JSON. +- On duplicate upload completion where the episode already exists, update the + existing episode metadata only if sidecar `writer_health` is present. + +Recommended API changes: + +- Add `metadata` to the `Episode` response model. +- Add `metadata` to the detail query only. +- Keep list queries unchanged. + +## 8. Test Plan + +Backend tests: + +- Sidecar with `writer_health` creates an episode whose + `metadata.recorder.writer_health` matches the sidecar. +- Sidecar without `writer_health` creates an episode without + `metadata.recorder.writer_health`. +- Existing episode plus sidecar `writer_health` backfills or overwrites only + `metadata.recorder.writer_health`. +- Existing metadata fields, including `asset_id` and other `recorder` children, + are preserved during backfill. +- `GET /api/v1/episodes/:id` returns metadata. +- `GET /api/v1/episodes` does not return metadata. + +Frontend verification: + +- `npm run build` passes. +- Episode detail does not show the recording diagnostics panel for old episodes + without writer health. +- Episode detail shows normal, warning, critical, and unknown states correctly. +- Empty or null `error` is hidden. + diff --git a/internal/api/handlers/episode.go b/internal/api/handlers/episode.go index 26bce51..9debc77 100644 --- a/internal/api/handlers/episode.go +++ b/internal/api/handlers/episode.go @@ -104,6 +104,7 @@ type episodeRow struct { CloudSyncedAt sql.NullTime `db:"cloud_synced_at"` CreatedAt time.Time `db:"created_at"` LabelsJSON sql.NullString `db:"labels"` + Metadata sql.NullString `db:"metadata"` } // Episode represents an episode in the API response @@ -135,6 +136,7 @@ type Episode struct { CloudSyncedAt *string `json:"cloud_synced_at"` CreatedAt string `json:"created_at"` Labels []string `json:"labels"` + Metadata any `json:"metadata,omitempty"` } // EpisodeListResponse represents the response for listing episodes @@ -632,7 +634,8 @@ func (h *EpisodeHandler) GetEpisode(c *gin.Context) { e.cloud_processed, e.cloud_synced_at, e.created_at, - e.labels + e.labels, + e.metadata FROM episodes e LEFT JOIN tasks t ON t.id = e.task_id AND t.deleted_at IS NULL LEFT JOIN sops s ON s.id = t.sop_id AND s.deleted_at IS NULL @@ -685,5 +688,6 @@ func (h *EpisodeHandler) GetEpisode(c *gin.Context) { CloudSyncedAt: nullableTime(row.CloudSyncedAt), CreatedAt: row.CreatedAt.UTC().Format(time.RFC3339), Labels: episodeLabelsFromDB(row.LabelsJSON), + Metadata: parseJSONRaw(row.Metadata.String), }) } diff --git a/internal/api/handlers/episode_metadata_test.go b/internal/api/handlers/episode_metadata_test.go new file mode 100644 index 0000000..c2b30ca --- /dev/null +++ b/internal/api/handlers/episode_metadata_test.go @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: 2026 ArcheBase +// +// SPDX-License-Identifier: MulanPSL-2.0 + +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/jmoiron/sqlx" + _ "modernc.org/sqlite" +) + +func TestGetEpisodeReturnsMetadata(t *testing.T) { + db := openEpisodeMetadataTestDB(t) + defer db.Close() + seedEpisodeMetadataTestRow(t, db) + + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewEpisodeHandler(db, nil, "", nil) + router.GET("/episodes/:id", handler.GetEpisode) + + req := httptest.NewRequest(http.MethodGet, "/episodes/1", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + metadata, ok := body["metadata"].(map[string]any) + if !ok { + t.Fatalf("metadata type=%T body=%#v", body["metadata"], body) + } + recorder, ok := metadata["recorder"].(map[string]any) + if !ok { + t.Fatalf("recorder type=%T metadata=%#v", metadata["recorder"], metadata) + } + writerHealth, ok := recorder["writer_health"].(map[string]any) + if !ok { + t.Fatalf("writer_health type=%T recorder=%#v", recorder["writer_health"], recorder) + } + if writerHealth["state"] != "warning" { + t.Fatalf("writer_health.state=%v want warning", writerHealth["state"]) + } +} + +func TestListEpisodesOmitsMetadata(t *testing.T) { + db := openEpisodeMetadataTestDB(t) + defer db.Close() + seedEpisodeMetadataTestRow(t, db) + + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewEpisodeHandler(db, nil, "", nil) + router.GET("/episodes", handler.ListEpisodes) + + req := httptest.NewRequest(http.MethodGet, "/episodes", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + + var body struct { + Items []map[string]any `json:"items"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if len(body.Items) != 1 { + t.Fatalf("items=%d want 1", len(body.Items)) + } + if _, ok := body.Items[0]["metadata"]; ok { + t.Fatalf("list item unexpectedly contains metadata: %#v", body.Items[0]["metadata"]) + } +} + +func openEpisodeMetadataTestDB(t *testing.T) *sqlx.DB { + t.Helper() + db, err := sqlx.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite db: %v", err) + } + for _, stmt := range []string{ + `CREATE TABLE episodes ( + id INTEGER PRIMARY KEY, + episode_id TEXT NOT NULL, + task_id INTEGER NOT NULL, + workstation_id INTEGER, + mcap_path TEXT NOT NULL, + sidecar_path TEXT NOT NULL, + checksum TEXT, + file_size_bytes INTEGER, + duration_sec REAL, + qa_status TEXT, + qa_score REAL, + quality_flag TEXT, + auto_approved BOOLEAN DEFAULT FALSE, + cloud_synced BOOLEAN DEFAULT FALSE, + cloud_processed BOOLEAN DEFAULT FALSE, + cloud_synced_at TIMESTAMP NULL, + created_at TIMESTAMP NOT NULL, + labels TEXT, + metadata TEXT, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE tasks ( + id INTEGER PRIMARY KEY, + task_id TEXT, + sop_id INTEGER, + scene_name TEXT, + subscene_name TEXT, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE sops ( + id INTEGER PRIMARY KEY, + slug TEXT, + version TEXT, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE workstations ( + id INTEGER PRIMARY KEY, + robot_id INTEGER, + data_collector_id INTEGER, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE robots ( + id INTEGER PRIMARY KEY, + device_id TEXT, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE data_collectors ( + id INTEGER PRIMARY KEY, + operator_id TEXT, + deleted_at TIMESTAMP NULL + )`, + `CREATE TABLE inspections ( + episode_id INTEGER, + inspector_id INTEGER, + decision TEXT, + inspected_at TIMESTAMP NULL + )`, + `CREATE TABLE inspectors ( + id INTEGER PRIMARY KEY, + inspector_id TEXT + )`, + } { + if _, err := db.Exec(stmt); err != nil { + t.Fatalf("create schema: %v", err) + } + } + return db +} + +func seedEpisodeMetadataTestRow(t *testing.T, db *sqlx.DB) { + t.Helper() + metadata := `{"asset_id":"asset-1","recorder":{"writer_health":{"state":"warning","writer_stall_state":"normal","writer_stall_suspected":false,"writer_partial_failures":0,"writer_queue_overflows":0,"error":null}}}` + if _, err := db.Exec(` + INSERT INTO tasks (id, task_id, sop_id, scene_name, subscene_name, deleted_at) + VALUES (10, 'task-public-1', NULL, 'scene', 'subscene', NULL) + `); err != nil { + t.Fatalf("seed task: %v", err) + } + if _, err := db.Exec(` + INSERT INTO episodes ( + id, episode_id, task_id, workstation_id, mcap_path, sidecar_path, + checksum, file_size_bytes, duration_sec, qa_status, qa_score, + quality_flag, auto_approved, cloud_synced, cloud_processed, + cloud_synced_at, created_at, labels, metadata, deleted_at + ) VALUES ( + 1, 'episode-public-1', 10, NULL, 'bucket/a.mcap', 'bucket/a.json', + 'abc', 1024, 12.5, 'pending_qa', NULL, + NULL, FALSE, FALSE, FALSE, + NULL, '2026-06-24T00:00:00Z', '[]', ?, NULL + ) + `, metadata); err != nil { + t.Fatalf("seed episode: %v", err) + } +} diff --git a/internal/api/handlers/transfer.go b/internal/api/handlers/transfer.go index 76c667b..25bee9d 100644 --- a/internal/api/handlers/transfer.go +++ b/internal/api/handlers/transfer.go @@ -403,6 +403,7 @@ type sidecarTopicSummary struct { type sidecarJSON struct { Recording sidecarRecording `json:"recording"` TopicsSummary []sidecarTopicSummary `json:"topics_summary"` + WriterHealth json.RawMessage `json:"writer_health"` } // readSidecarFromS3 downloads the sidecar JSON object from MinIO and returns the parsed result. @@ -453,6 +454,47 @@ func assetIDSnapshotMetadata(ctx context.Context, tx *sql.Tx, workstationID sql. return sql.NullString{String: string(data), Valid: true} } +func sidecarWriterHealthMetadata(sc *sidecarJSON) (map[string]any, bool) { + if sc == nil { + return nil, false + } + raw := strings.TrimSpace(string(sc.WriterHealth)) + if raw == "" || raw == "null" { + return nil, false + } + var writerHealth map[string]any + if err := json.Unmarshal([]byte(raw), &writerHealth); err != nil || writerHealth == nil { + return nil, false + } + return writerHealth, true +} + +func mergeRecorderWriterHealthMetadata(existing sql.NullString, writerHealth map[string]any) sql.NullString { + if writerHealth == nil { + return existing + } + + metadata := map[string]any{} + if existing.Valid && strings.TrimSpace(existing.String) != "" && strings.TrimSpace(existing.String) != "null" { + if err := json.Unmarshal([]byte(existing.String), &metadata); err != nil || metadata == nil { + return existing + } + } + + recorder, _ := metadata["recorder"].(map[string]any) + if recorder == nil { + recorder = map[string]any{} + } + recorder["writer_health"] = writerHealth + metadata["recorder"] = recorder + + data, err := json.Marshal(metadata) + if err != nil { + return existing + } + return sql.NullString{String: string(data), Valid: true} +} + func uploadCompleteS3Key(data map[string]interface{}) string { return strings.TrimSpace(stringVal(data, "s3_key")) } @@ -667,15 +709,23 @@ func (h *TransferHandler) onUploadComplete(ctx context.Context, dc *services.Tra // Idempotency: avoid creating duplicate episodes for the same task. // This keeps batches.episode_count correct even if the device retries uploads. - var existingEpisodeID string + var existingEpisode struct { + ID int64 `db:"id"` + EpisodeID string `db:"episode_id"` + Metadata sql.NullString `db:"metadata"` + } err := tx.QueryRowContext(ctx, ` - SELECT episode_id + SELECT id, episode_id, metadata FROM episodes WHERE task_id = ? AND deleted_at IS NULL LIMIT 1 - `, taskRow.ID).Scan(&existingEpisodeID) + `, taskRow.ID).Scan( + &existingEpisode.ID, + &existingEpisode.EpisodeID, + &existingEpisode.Metadata, + ) - if err == nil && existingEpisodeID == "" { + if err == nil && existingEpisode.EpisodeID == "" { // #nosec G706 -- Set aside for now logger.Printf("%s data corruption: empty episode_id found for task_pk=%d", transferTaskLogPrefix(dc.DeviceID, taskID), taskRow.ID) return @@ -686,7 +736,22 @@ func (h *TransferHandler) onUploadComplete(ctx context.Context, dc *services.Tra return } - if errors.Is(err, sql.ErrNoRows) { + if err == nil { + if writerHealth, ok := sidecarWriterHealthMetadata(sc); ok { + mergedMetadata := mergeRecorderWriterHealthMetadata(existingEpisode.Metadata, writerHealth) + if mergedMetadata.Valid && mergedMetadata.String != existingEpisode.Metadata.String { + if _, dbErr := tx.ExecContext(ctx, ` + UPDATE episodes + SET metadata = ?, updated_at = ? + WHERE id = ? AND deleted_at IS NULL + `, mergedMetadata, time.Now().UTC(), existingEpisode.ID); dbErr != nil { + // #nosec G706 -- Set aside for now + logger.Printf("%s DB metadata backfill failed for episode=%s: %v", transferTaskLogPrefix(dc.DeviceID, taskID), existingEpisode.EpisodeID, dbErr) + return + } + } + } + } else if errors.Is(err, sql.ErrNoRows) { episodeID := uuid.New().String() // Extract recording metadata from sidecar JSON (nullable). @@ -705,6 +770,9 @@ func (h *TransferHandler) onUploadComplete(ctx context.Context, dc *services.Tra } } episodeMetadata := assetIDSnapshotMetadata(ctx, tx, taskRow.WorkstationID) + if writerHealth, ok := sidecarWriterHealthMetadata(sc); ok { + episodeMetadata = mergeRecorderWriterHealthMetadata(episodeMetadata, writerHealth) + } insertRes, dbErr := tx.ExecContext(ctx, `INSERT INTO episodes ( diff --git a/internal/api/handlers/transfer_asset_id_snapshot_test.go b/internal/api/handlers/transfer_asset_id_snapshot_test.go index 9d71be7..787a8a5 100644 --- a/internal/api/handlers/transfer_asset_id_snapshot_test.go +++ b/internal/api/handlers/transfer_asset_id_snapshot_test.go @@ -74,6 +74,117 @@ func TestAssetIDSnapshotMetadata_MissingDoesNotFailEpisodeCreationPath(t *testin } } +func TestSidecarWriterHealthMetadata_ReadsTopLevelObject(t *testing.T) { + var sc sidecarJSON + if err := json.Unmarshal([]byte(`{ + "writer_health": { + "state": "critical", + "writer_stall_state": "critical", + "writer_stall_suspected": true, + "writer_partial_failures": 2, + "writer_queue_overflows": 1, + "error": "writer_partial_failures=2" + } + }`), &sc); err != nil { + t.Fatalf("unmarshal sidecar: %v", err) + } + + got, ok := sidecarWriterHealthMetadata(&sc) + if !ok { + t.Fatal("writer_health was not detected") + } + if got["state"] != "critical" { + t.Fatalf("state=%v want critical", got["state"]) + } + if got["writer_stall_suspected"] != true { + t.Fatalf("writer_stall_suspected=%v want true", got["writer_stall_suspected"]) + } +} + +func TestSidecarWriterHealthMetadata_MissingDoesNotWrite(t *testing.T) { + got, ok := sidecarWriterHealthMetadata(&sidecarJSON{}) + if ok || got != nil { + t.Fatalf("writer_health=%#v ok=%t, want nil false", got, ok) + } +} + +func TestMergeRecorderWriterHealthMetadata_PreservesExistingFields(t *testing.T) { + existing := sql.NullString{ + String: `{"asset_id":"asset-1","recorder":{"profile":"high_rate"},"owner":"ops"}`, + Valid: true, + } + writerHealth := map[string]any{ + "state": "critical", + "writer_stall_state": "critical", + "writer_stall_suspected": true, + "writer_partial_failures": float64(2), + "writer_queue_overflows": float64(1), + "error": "writer_partial_failures=2", + } + + got := mergeRecorderWriterHealthMetadata(existing, writerHealth) + if !got.Valid { + t.Fatal("metadata was not written") + } + var decoded map[string]any + if err := json.Unmarshal([]byte(got.String), &decoded); err != nil { + t.Fatalf("unmarshal metadata: %v", err) + } + if decoded["asset_id"] != "asset-1" || decoded["owner"] != "ops" { + t.Fatalf("existing metadata not preserved: %#v", decoded) + } + recorder, ok := decoded["recorder"].(map[string]any) + if !ok { + t.Fatalf("recorder metadata type=%T", decoded["recorder"]) + } + if recorder["profile"] != "high_rate" { + t.Fatalf("recorder.profile=%v want high_rate", recorder["profile"]) + } + health, ok := recorder["writer_health"].(map[string]any) + if !ok { + t.Fatalf("writer_health type=%T", recorder["writer_health"]) + } + if health["state"] != "critical" { + t.Fatalf("writer_health.state=%v want critical", health["state"]) + } +} + +func TestMergeRecorderWriterHealthMetadata_OverwritesOnlyWriterHealth(t *testing.T) { + existing := sql.NullString{ + String: `{"recorder":{"profile":"high_rate","writer_health":{"state":"warning"}}}`, + Valid: true, + } + writerHealth := map[string]any{ + "state": "critical", + "writer_partial_failures": float64(2), + } + + got := mergeRecorderWriterHealthMetadata(existing, writerHealth) + if !got.Valid { + t.Fatal("metadata was not written") + } + var decoded map[string]any + if err := json.Unmarshal([]byte(got.String), &decoded); err != nil { + t.Fatalf("unmarshal metadata: %v", err) + } + recorder := decoded["recorder"].(map[string]any) + if recorder["profile"] != "high_rate" { + t.Fatalf("recorder.profile=%v want high_rate", recorder["profile"]) + } + health := recorder["writer_health"].(map[string]any) + if health["state"] != "critical" { + t.Fatalf("writer_health.state=%v want critical", health["state"]) + } +} + +func TestMergeRecorderWriterHealthMetadata_InvalidExistingPreserved(t *testing.T) { + existing := sql.NullString{String: `{invalid`, Valid: true} + got := mergeRecorderWriterHealthMetadata(existing, map[string]any{"state": "normal"}) + if !got.Valid || got.String != existing.String { + t.Fatalf("metadata=%#v want original invalid metadata", got) + } +} + func createAssetIDSnapshotSchema(t *testing.T, db *sql.DB) { t.Helper() for _, stmt := range []string{ From 1513aecfb50bc785a1b50f11ba8ff637230f04d4 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Wed, 24 Jun 2026 16:51:37 +0800 Subject: [PATCH 5/6] feat(transfer): persist recorder version metadata --- .../api/handlers/episode_metadata_test.go | 9 ++- internal/api/handlers/transfer.go | 66 ++++++++++++------- .../transfer_asset_id_snapshot_test.go | 54 +++++++++++++-- 3 files changed, 98 insertions(+), 31 deletions(-) diff --git a/internal/api/handlers/episode_metadata_test.go b/internal/api/handlers/episode_metadata_test.go index c2b30ca..a2983d8 100644 --- a/internal/api/handlers/episode_metadata_test.go +++ b/internal/api/handlers/episode_metadata_test.go @@ -52,6 +52,13 @@ func TestGetEpisodeReturnsMetadata(t *testing.T) { if writerHealth["state"] != "warning" { t.Fatalf("writer_health.state=%v want warning", writerHealth["state"]) } + recording, ok := recorder["recording"].(map[string]any) + if !ok { + t.Fatalf("recording type=%T recorder=%#v", recorder["recording"], recorder) + } + if recording["recorder_version"] != "axon_recorder 0.5.0" { + t.Fatalf("recorder.recording.recorder_version=%v want axon_recorder 0.5.0", recording["recorder_version"]) + } } func TestListEpisodesOmitsMetadata(t *testing.T) { @@ -165,7 +172,7 @@ func openEpisodeMetadataTestDB(t *testing.T) *sqlx.DB { func seedEpisodeMetadataTestRow(t *testing.T, db *sqlx.DB) { t.Helper() - metadata := `{"asset_id":"asset-1","recorder":{"writer_health":{"state":"warning","writer_stall_state":"normal","writer_stall_suspected":false,"writer_partial_failures":0,"writer_queue_overflows":0,"error":null}}}` + metadata := `{"asset_id":"asset-1","recorder":{"recording":{"recorder_version":"axon_recorder 0.5.0"},"writer_health":{"state":"warning","writer_stall_state":"normal","writer_stall_suspected":false,"writer_partial_failures":0,"writer_queue_overflows":0,"error":null}}}` if _, err := db.Exec(` INSERT INTO tasks (id, task_id, sop_id, scene_name, subscene_name, deleted_at) VALUES (10, 'task-public-1', NULL, 'scene', 'subscene', NULL) diff --git a/internal/api/handlers/transfer.go b/internal/api/handlers/transfer.go index 25bee9d..e63339e 100644 --- a/internal/api/handlers/transfer.go +++ b/internal/api/handlers/transfer.go @@ -388,11 +388,12 @@ func (h *TransferHandler) onUploadProgress(dc *services.TransferConn, msg map[st // sidecarRecording is the subset of the sidecar JSON "recording" block we care about. type sidecarRecording struct { - DurationSec float64 `json:"duration_sec"` - FileSizeBytes int64 `json:"file_size_bytes"` - ChecksumSHA256 string `json:"checksum_sha256"` - MessageCount int64 `json:"message_count"` - TopicsRecorded []string `json:"topics_recorded"` + DurationSec float64 `json:"duration_sec"` + FileSizeBytes int64 `json:"file_size_bytes"` + ChecksumSHA256 string `json:"checksum_sha256"` + MessageCount int64 `json:"message_count"` + TopicsRecorded []string `json:"topics_recorded"` + RecorderVersion string `json:"recorder_version"` } type sidecarTopicSummary struct { @@ -469,8 +470,17 @@ func sidecarWriterHealthMetadata(sc *sidecarJSON) (map[string]any, bool) { return writerHealth, true } -func mergeRecorderWriterHealthMetadata(existing sql.NullString, writerHealth map[string]any) sql.NullString { - if writerHealth == nil { +func sidecarRecorderVersionMetadata(sc *sidecarJSON) (string, bool) { + if sc == nil { + return "", false + } + version := strings.TrimSpace(sc.Recording.RecorderVersion) + return version, version != "" +} + +func mergeRecorderMetadata(existing sql.NullString, writerHealth map[string]any, recorderVersion string) sql.NullString { + recorderVersion = strings.TrimSpace(recorderVersion) + if writerHealth == nil && recorderVersion == "" { return existing } @@ -485,7 +495,17 @@ func mergeRecorderWriterHealthMetadata(existing sql.NullString, writerHealth map if recorder == nil { recorder = map[string]any{} } - recorder["writer_health"] = writerHealth + if writerHealth != nil { + recorder["writer_health"] = writerHealth + } + if recorderVersion != "" { + recording, _ := recorder["recording"].(map[string]any) + if recording == nil { + recording = map[string]any{} + } + recording["recorder_version"] = recorderVersion + recorder["recording"] = recording + } metadata["recorder"] = recorder data, err := json.Marshal(metadata) @@ -737,18 +757,18 @@ func (h *TransferHandler) onUploadComplete(ctx context.Context, dc *services.Tra } if err == nil { - if writerHealth, ok := sidecarWriterHealthMetadata(sc); ok { - mergedMetadata := mergeRecorderWriterHealthMetadata(existingEpisode.Metadata, writerHealth) - if mergedMetadata.Valid && mergedMetadata.String != existingEpisode.Metadata.String { - if _, dbErr := tx.ExecContext(ctx, ` - UPDATE episodes - SET metadata = ?, updated_at = ? - WHERE id = ? AND deleted_at IS NULL - `, mergedMetadata, time.Now().UTC(), existingEpisode.ID); dbErr != nil { - // #nosec G706 -- Set aside for now - logger.Printf("%s DB metadata backfill failed for episode=%s: %v", transferTaskLogPrefix(dc.DeviceID, taskID), existingEpisode.EpisodeID, dbErr) - return - } + writerHealth, _ := sidecarWriterHealthMetadata(sc) + recorderVersion, _ := sidecarRecorderVersionMetadata(sc) + mergedMetadata := mergeRecorderMetadata(existingEpisode.Metadata, writerHealth, recorderVersion) + if mergedMetadata.Valid && mergedMetadata.String != existingEpisode.Metadata.String { + if _, dbErr := tx.ExecContext(ctx, ` + UPDATE episodes + SET metadata = ?, updated_at = ? + WHERE id = ? AND deleted_at IS NULL + `, mergedMetadata, time.Now().UTC(), existingEpisode.ID); dbErr != nil { + // #nosec G706 -- Set aside for now + logger.Printf("%s DB metadata backfill failed for episode=%s: %v", transferTaskLogPrefix(dc.DeviceID, taskID), existingEpisode.EpisodeID, dbErr) + return } } } else if errors.Is(err, sql.ErrNoRows) { @@ -770,9 +790,9 @@ func (h *TransferHandler) onUploadComplete(ctx context.Context, dc *services.Tra } } episodeMetadata := assetIDSnapshotMetadata(ctx, tx, taskRow.WorkstationID) - if writerHealth, ok := sidecarWriterHealthMetadata(sc); ok { - episodeMetadata = mergeRecorderWriterHealthMetadata(episodeMetadata, writerHealth) - } + writerHealth, _ := sidecarWriterHealthMetadata(sc) + recorderVersion, _ := sidecarRecorderVersionMetadata(sc) + episodeMetadata = mergeRecorderMetadata(episodeMetadata, writerHealth, recorderVersion) insertRes, dbErr := tx.ExecContext(ctx, `INSERT INTO episodes ( diff --git a/internal/api/handlers/transfer_asset_id_snapshot_test.go b/internal/api/handlers/transfer_asset_id_snapshot_test.go index 787a8a5..8a07460 100644 --- a/internal/api/handlers/transfer_asset_id_snapshot_test.go +++ b/internal/api/handlers/transfer_asset_id_snapshot_test.go @@ -108,7 +108,33 @@ func TestSidecarWriterHealthMetadata_MissingDoesNotWrite(t *testing.T) { } } -func TestMergeRecorderWriterHealthMetadata_PreservesExistingFields(t *testing.T) { +func TestSidecarRecorderVersionMetadata_ReadsRecordingBlock(t *testing.T) { + var sc sidecarJSON + if err := json.Unmarshal([]byte(`{ + "recording": { + "recorder_version": "axon_recorder 0.5.0" + } + }`), &sc); err != nil { + t.Fatalf("unmarshal sidecar: %v", err) + } + + got, ok := sidecarRecorderVersionMetadata(&sc) + if !ok { + t.Fatal("recorder_version was not detected") + } + if got != "axon_recorder 0.5.0" { + t.Fatalf("recorder_version=%q want axon_recorder 0.5.0", got) + } +} + +func TestSidecarRecorderVersionMetadata_MissingDoesNotWrite(t *testing.T) { + got, ok := sidecarRecorderVersionMetadata(&sidecarJSON{}) + if ok || got != "" { + t.Fatalf("recorder_version=%q ok=%t, want empty false", got, ok) + } +} + +func TestMergeRecorderMetadata_PreservesExistingFields(t *testing.T) { existing := sql.NullString{ String: `{"asset_id":"asset-1","recorder":{"profile":"high_rate"},"owner":"ops"}`, Valid: true, @@ -122,7 +148,7 @@ func TestMergeRecorderWriterHealthMetadata_PreservesExistingFields(t *testing.T) "error": "writer_partial_failures=2", } - got := mergeRecorderWriterHealthMetadata(existing, writerHealth) + got := mergeRecorderMetadata(existing, writerHealth, "axon_recorder 0.5.0") if !got.Valid { t.Fatal("metadata was not written") } @@ -147,11 +173,18 @@ func TestMergeRecorderWriterHealthMetadata_PreservesExistingFields(t *testing.T) if health["state"] != "critical" { t.Fatalf("writer_health.state=%v want critical", health["state"]) } + recording, ok := recorder["recording"].(map[string]any) + if !ok { + t.Fatalf("recording type=%T", recorder["recording"]) + } + if recording["recorder_version"] != "axon_recorder 0.5.0" { + t.Fatalf("recorder.recording.recorder_version=%v want axon_recorder 0.5.0", recording["recorder_version"]) + } } -func TestMergeRecorderWriterHealthMetadata_OverwritesOnlyWriterHealth(t *testing.T) { +func TestMergeRecorderMetadata_OverwritesOnlyRecorderFields(t *testing.T) { existing := sql.NullString{ - String: `{"recorder":{"profile":"high_rate","writer_health":{"state":"warning"}}}`, + String: `{"recorder":{"profile":"high_rate","writer_health":{"state":"warning"},"recording":{"recorder_version":"old","duration_sec":12.5}}}`, Valid: true, } writerHealth := map[string]any{ @@ -159,7 +192,7 @@ func TestMergeRecorderWriterHealthMetadata_OverwritesOnlyWriterHealth(t *testing "writer_partial_failures": float64(2), } - got := mergeRecorderWriterHealthMetadata(existing, writerHealth) + got := mergeRecorderMetadata(existing, writerHealth, "new") if !got.Valid { t.Fatal("metadata was not written") } @@ -175,11 +208,18 @@ func TestMergeRecorderWriterHealthMetadata_OverwritesOnlyWriterHealth(t *testing if health["state"] != "critical" { t.Fatalf("writer_health.state=%v want critical", health["state"]) } + recording := recorder["recording"].(map[string]any) + if recording["recorder_version"] != "new" { + t.Fatalf("recorder.recording.recorder_version=%v want new", recording["recorder_version"]) + } + if recording["duration_sec"] != 12.5 { + t.Fatalf("recorder.recording.duration_sec=%v want 12.5", recording["duration_sec"]) + } } -func TestMergeRecorderWriterHealthMetadata_InvalidExistingPreserved(t *testing.T) { +func TestMergeRecorderMetadata_InvalidExistingPreserved(t *testing.T) { existing := sql.NullString{String: `{invalid`, Valid: true} - got := mergeRecorderWriterHealthMetadata(existing, map[string]any{"state": "normal"}) + got := mergeRecorderMetadata(existing, map[string]any{"state": "normal"}, "axon_recorder 0.5.0") if !got.Valid || got.String != existing.String { t.Fatalf("metadata=%#v want original invalid metadata", got) } From 5578e4862ea4325675a0473b517f8661056ea047 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Thu, 25 Jun 2026 18:18:37 +0800 Subject: [PATCH 6/6] fix: fix lint issues --- internal/api/handlers/ws_client_auth.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/api/handlers/ws_client_auth.go b/internal/api/handlers/ws_client_auth.go index 037e7f3..610c940 100644 --- a/internal/api/handlers/ws_client_auth.go +++ b/internal/api/handlers/ws_client_auth.go @@ -87,6 +87,7 @@ func (h *RecorderHandler) authorizeRecorderWebSocket(w http.ResponseWriter, r *h return false } + // #nosec G701 -- static SQL with placeholder-bound token usage values. if _, err := h.db.ExecContext(queryCtx, ` UPDATE ws_client_auth_tokens SET last_used_at = ?