diff --git a/docs/zh/design/hooks-design.md b/docs/zh/design/hooks-design.md new file mode 100644 index 000000000..fad01a93a --- /dev/null +++ b/docs/zh/design/hooks-design.md @@ -0,0 +1,2419 @@ +# Hooks 系统设计文档 + +> 参考 Claude Code / Cursor / Codex / Hermes Agent 的 shell hook 协议与社区生态。 +> +> 原始产品需求见仓库根目录 `playground_prototype_design.md`(F6 Hooks 系统、F9 Plugins 兼容、智能体设置中的 Hooks 配置)。 +> +> 本文档是 Hooks 模块的**完整可执行方案**,涵盖子进程协议、Canonical 事件体系、多平台配置加载、匹配器、执行引擎、与旧 Callback 的桥接、与权限系统的协作、以及 Claude / Cursor / Hermes 三方 shell hook 生态兼容边界。 + +--- + +## 目录 + +- [1. 现状分析](#1-现状分析) +- [2. 总体架构](#2-总体架构) +- [3. 子进程协议](#3-子进程协议) +- [4. 事件体系](#4-事件体系) +- [5. 配置格式](#5-配置格式) +- [6. 匹配器](#6-匹配器) +- [7. HookRegistry — 配置加载与合并](#7-hookregistry--配置加载与合并) +- [8. HookExecutor — 执行引擎(Dispatcher)](#8-hookexecutor--执行引擎dispatcher--command-后端) +- [9. CallbackToHookBridge — 向后兼容桥](#9-callbacktohookbridge--向后兼容桥) +- [10. 与权限系统的协作](#10-与权限系统的协作) +- [11. 集成点与代码变更](#11-集成点与代码变更) +- [12. 文件结构](#12-文件结构) +- [13. 与外部生态的对比](#13-与外部生态的对比) +- [14. 验证方式](#14-验证方式) +- [15. 多平台生态兼容设计](#15-多平台生态兼容设计) +- [16. 分阶段交付与验收](#16-分阶段交付与验收) +- [17. 扩展 Executor:HTTP / Prompt / Agent](#17-扩展-executorhttppromptagent) +- [附录 A:Hook Handler 类型与应用场景](#附录-ahook-handler-类型与应用场景) +- [附录 B:Hermes 三套 Hook 体系与功能关系](#附录-bhermes-三套-hook-体系与功能关系) +- [附录 C:实现待办与跨文档约定](#附录-c实现待办与跨文档约定) + +--- + +## 1. 现状分析 + +### 1.1 当前 Callback 机制的问题 + +| 问题 | 说明 | +|------|------| +| **需继承 Python 类** | 用户必须写 `class MyCallback(Callback)` 子类,无法用 Shell/Node 等 | +| **需 `trust_remote_code`** | 加载用户脚本时必须开启 trust,存在安全隐患 | +| **仅 5 个固定方法** | `on_task_begin`、`on_generate_response`、`on_tool_call`、`after_tool_call`、`on_task_end` | +| **无阻断能力** | 所有 Callback 方法返回 `None`,无法拒绝或修改工具调用 | +| **无外部脚本扩展** | 无法从外部注入策略脚本,社区生态无法复用 | + +### 1.2 现有 Callback 类盘点 + +```python +# ms_agent/callbacks/base.py +class Callback: + async def on_task_begin(self, runtime, messages) -> None + async def on_generate_response(self, runtime, messages) -> None + async def on_tool_call(self, runtime, messages) -> None + async def after_tool_call(self, runtime, messages) -> None + async def on_task_end(self, runtime, messages) -> None +``` + +唯一内置实现:`InputCallback` — 在 `after_tool_call` 中等待用户输入,控制多轮对话。 + +### 1.3 工具管线 + +`ms_agent/tools/tool_manager.py` 已重写,工具调用统一经 `single_call_tool()`: + +- `LLMAgent.parallel_tool_call()` → `ToolManager.parallel_call_tool()` → **N × `single_call_tool()`** +- 权限双层检查已就位:`SafetyGuard`(L296–308)→ `PermissionEnforcer`(L309–315)→ `call_tool()`(L337–343) +- 工具名格式:`{server_name}---{tool_name}`(`TOOL_SPLITER = '---'`),与 Hooks matcher 及 `permission-design.md` 一致 + +Hooks 模块的 **PreToolUse / PostToolUse 应插入此函数**,而非依赖 `Callback.on_tool_call`(触发时机在 `parallel_tool_call` 之前,无法拦截 `single_call_tool` 内部逻辑)。 + +### 1.4 设计目标 + +对齐 `playground_prototype_design.md` F6 / 智能体设置「Hooks:支持 python 和 sh,不需要继承父类」: + +1. **语言中立**:支持 Python、Shell、任意可执行文件,无需继承 `Callback` +2. **可阻断**:关键事件(`PreToolUse`、`UserPromptSubmit`、`Stop`)支持策略性阻断与参数改写 +3. **社区兼容**:协议对标 Claude Code / Codex / Cursor 的 `stdin/stdout/exit code` 约定;**优先兼容三家的 shell-based third-party hooks** +4. **多源配置**:除 `agent.yaml` 外,可选加载 `.claude/settings.json`、`.cursor/hooks.json`、Hermes shell hooks(`config.yaml`) +5. **向后兼容**:旧 Callback 与 `HookRuntime` 共存,不废弃 +6. **工具管线优先**:`PreToolUse` / `PostToolUse` / `PermissionRequest` 嵌入 `ToolManager.single_call_tool()`,与 `SafetyGuard` / `PermissionEnforcer` 同层协作 +7. **生命周期精确挂点**:`SessionStart` / `UserPromptSubmit` / `Stop` 在 `LLMAgent` 主循环的**语义等价位置**触发(见 §4.5、§9),**不复用** `on_generate_response` / `on_task_end` +8. **Plugin 联动**:`hooks/hooks.json` 经 `PluginLoader` 转换后并入 `HookRegistry`(F9) + +--- + +## 2. 总体架构 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ HookLoaders(多平台配置 → Canonical IR) │ +│ - NativeYamlLoader agent.yaml 或 ~/.ms_agent/hooks.yaml│ +│ - ClaudeSettingsLoader .claude/settings.json │ +│ - CursorHooksLoader .cursor/hooks.json │ +│ - HermesShellLoader ~/.hermes/config.yaml (hooks: 段) │ +│ - PluginHooksLoader plugin hooks/hooks.json (F9) │ +├─────────────────────────────────────────────────────────────┤ +│ HookRegistry(Canonical 事件 → MatcherGroup 索引) │ +│ - merge(sources...) 按优先级合并 │ +│ - get_handlers(canonical_event, normalized_tool_name) │ +├─────────────────────────────────────────────────────────────┤ +│ HookExecutorDispatcher(按 handler.type 路由) │ +│ - command → CommandHookExecutor(子进程,§8) │ +│ - http → HttpHookExecutor(§17.2) │ +│ - prompt → PromptHookExecutor(§17.3) │ +│ - agent → AgentHookExecutor(§17.4) │ +│ 统一出口 → ResponseAdapter → HookResult │ +├─────────────────────────────────────────────────────────────┤ +│ ToolNameMapper + PatternMatcher │ +│ - Bash/Shell/terminal ↔ ms-agent tool 名 │ +│ - fnmatch + | 分隔(与 permission 共用) │ +├─────────────────────────────────────────────────────────────┤ +│ HookRuntime(facade:registry + executor + mapper + adapter) │ +│ - run_pre_tool_use / run_post_tool_use → ToolManager 调用 │ +│ - run_session_start → CallbackToHookBridge │ +│ - run_user_prompt_submit / run_stop → LLMAgent 直接调用 │ +├─────────────────────────────────────────────────────────────┤ +│ CallbackToHookBridge(仅 SessionStart) │ +│ UserPromptSubmit / Stop → LLMAgent 直接调 HookRuntime(§9) │ +└─────────────────────────────────────────────────────────────┘ +``` + +**数据流** + +``` +【工具事件 — 主路径】 +LLMAgent.parallel_tool_call() + └─ ToolManager.parallel_call_tool() + └─ ToolManager.single_call_tool() + ├─ 1. SafetyGuard.check() + ├─ 2. HookRuntime.run_pre_tool_use() ← PreToolUse + ├─ 3. PermissionEnforcer.check() ← hooks pass 时 + ├─ 4. tool_ins.call_tool() + └─ 5. HookRuntime.run_post_tool_use() ← PostToolUse + +【非工具事件 — LLMAgent 主循环挂点】 +run_loop() / step() + ├─ create_messages() 或 InputCallback 追加 user 后 + │ └─ HookRuntime.run_user_prompt_submit() ← UserPromptSubmit + ├─ on_task_begin (round==0) + │ └─ HookRuntime.run_session_start() ← SessionStart + └─ after_tool_call() 判定 should_stop 之前 + └─ HookRuntime.run_stop() ← Stop +``` + +`on_generate_response` **不**映射 `UserPromptSubmit`——它在每轮 LLM 调用前触发,语义是「turn 内 pre-LLM」而非「用户提交 prompt」。`on_task_end` **不**映射 `Stop`——此时主循环已结束,无法再 `block` 停止决策。 + +`on_tool_call` **不再**作为 PreToolUse 的主触发点——它在 LLM 产出 tool_calls 之后、`parallel_tool_call` 之前触发,时机偏晚且无法改写 `single_call_tool` 内的参数与返回值。PR#906 之后工具类 hook 以 `ToolManager` 为准。 + +--- + +## 3. 子进程协议 + +### 3.1 核心协议 + +对标 Claude Code 和 Codex 的共同约定,语言中立: + +``` +stdin ──→ JSON 事件数据(紧凑格式,一次性写入后关闭 stdin) +stdout ←── JSON 决策数据(可选,仅需返回有意义的字段) +stderr ←── 错误信息或阻断原因文本 +exit code: + 0 = 通过(解析 stdout JSON 获取详细决策) + 2 = 阻断(策略性拒绝,stderr 为原因) + 1 / 其他 = 非阻断错误(脚本 bug 不应误拦,记 warning 后继续) +``` + +### 3.2 为什么 exit 2 专用于阻断 + +- `exit 1` 是最常见的脚本错误退出码,如果 `exit 1 = 阻断`,那么脚本中的任何 uncaught exception 都会导致工具被拒绝 +- `exit 2` 需要用户显式选择,必须有意为之才会触发 +- 这是 Claude Code 和 Codex 共同验证过的约定,避免脚本 bug 误拦 + +### 3.3 stdin 事件数据格式(CanonicalPayload) + +最小字段: + +```json +{ + "event": "PreToolUse", + "session_id": "abc123", + "tool_name": "code_executor---shell_executor", + "tool_args": { + "command": "pip install requests" + } +} +``` + +启用多平台兼容时,Executor 应附加 `tool_input`(与 `tool_args` 同值)及 `tool_name_claude` / `tool_name_cursor` / `tool_name_hermes` 等别名,详见 [§15.6](#156-stdin-canonicalpayload-格式)。 + +### 3.4 stdout 决策数据格式 + +```json +{ + "decision": "deny", + "reason": "Package installation not allowed in production", + "additionalContext": "Consider using a requirements.txt instead" +} +``` + +可选字段(ms-agent **原生 / Canonical** 格式): + +| 字段 | 类型 | 说明 | +|------|------|------| +| `decision` | `"allow"` / `"deny"` / `"block"` | 决策(不提供则默认 `"pass"` 即通过) | +| `reason` | `str` | 阻断/放行的原因 | +| `additionalContext` | `str` | 注入到后续 LLM 上下文的附加信息 | +| `updatedArgs` | `dict` | 修改后的工具参数(仅 PreToolUse) | + +### 3.6 外部生态 stdout 格式适配(ResponseAdapter) + +执行引擎在解析 stdout 时,除 Canonical 字段外,还应识别以下**社区常见格式**并归一化为 `HookResult`: + +| 来源 | 阻断/放行字段示例 | 归一化 `HookResult.action` | +|------|------------------|---------------------------| +| **ms-agent / Codex** | `{"decision": "deny\|allow", ...}` | 直接映射 | +| **Claude Code** | `hookSpecificOutput.permissionDecision: "deny\|allow\|ask"` | → `deny` / `allow` / `ask` | +| **Claude Code** | `{"decision": "approve"}` / `{"decision": "block"}` | → `allow` / `deny` | +| **Cursor** | `{"permission": "deny", "user_message": "..."}` | → `deny` | +| **Hermes shell** | `{"decision": "block", ...}` / `{"action": "block", ...}` | → `deny` | +| **通用** | exit code `2` + stderr 文本 | → `deny` | + +参数改写字段映射: + +| 来源 | 字段 | Canonical | +|------|------|-----------| +| Claude `hookSpecificOutput.updatedInput` | 工具参数对象 | `updated_args` | +| Cursor `updated_input` | 同上 | `updated_args` | +| ms-agent | `updatedArgs` | `updated_args` | + +> **社区兼容要点**:仅 `updated_args`、无 `permissionDecision` / `decision` 时 → `action=pass`(passthrough),只改参数,**不**改变 permission 决策。对齐 Claude `toolHooks.ts` L556–563。 + +上下文注入字段映射: + +| 来源 | 注入字段 | Canonical | +|------|---------|-----------| +| Claude / Cursor | `additional_context` / `agent_message` | `additionalContext` | +| Hermes | `{"context": "..."}` on `pre_llm_call` | `additionalContext`(映射到 `UserPromptSubmit` 或 turn 前注入) | +| Cursor `preToolUse` | `updated_input` | `updatedArgs` | + +#### 3.6.1 生态兼容 ≠ 原生执行所有 hook 类型 + +我们要兼容的是 **Claude Code / Cursor / Hermes /(部分)OpenClaw** 这些框架,但每个框架内部 hook 有**多种实现形态**,按优先级支持能在 ms-agent 里「原样加载、原样跑」。 + +| 层次 | v1 目标 | 说明 | +|------|---------|------| +| **生态层** | ✅ 兼容 | 识别各平台配置路径、Plugin 目录、Skills、`hooks.json` / `settings.json` | +| **可移植层(shell command)** | ✅ v1 主路径 | 子进程 + stdin/stdout JSON;三方社区 hook **绝大多数**属于此类 | +| **平台原生运行时** | ⏳ v2+ 或适配器 | 需 ms-agent 自己实现对应 **Executor 后端**,不能靠 spawn 脚本完成 | + +因此 §3.6 下列类型标为 **v1 非目标**,指的是 **不在 v1 实现其原生执行后端**,而不是「不做这些框架的兼容」: + +| 类型 | 所属生态 | 为何不纳入 v1 | v1 仍如何兼容该生态 | +|------|---------|--------------|-------------------| +| Claude **HTTP** hook | Claude Code | v1 仅 `command`;**P2** 原生 `HttpHookExecutor`(§17.2) | v1:loader warning + 可用 command 包装 `curl` | +| Claude **prompt** hook | Claude Code | v1 仅 `command`;**P2** `PromptHookExecutor`(§17.3) | 同上 | +| Claude **agent** hook | Claude Code | v1 仅 `command`;**P3** `AgentHookExecutor`(§17.4) | Stop 验证类场景 P3 补齐 | +| Hermes **Python plugin** hook | Hermes | `ctx.register_hook()` 是 **Hermes 进程内 API**,不是独立脚本 | v1 兼容 **Hermes shell hooks** + **同仓库内的 command 脚本**;Python plugin 需在 Hermes 中运行,或作者提供等价 `.sh`(见 [附录 B](#附录-bhermes-三套-hook-体系与功能关系)) | +| OpenClaw **typed `api.on()`** | OpenClaw | **TypeScript 进程内**中间件 | OpenClaw 对 Claude `hooks.json` 本身也是 detect-only;v1 不 ingest TS 模块 | +| Cursor **`type: prompt`** | Cursor | 同 Claude prompt,需 LLM 后端 | v2;v1 兼容 command hook | + +**总结**:兼容框架 = 兼容其 **配置发现、Plugin 打包、shell hook 脚本与阻断语义**;不等于在 ms-agent 内嵌 Claude/Hermes/OpenClaw 的完整 hook **虚拟机**。 + +#### 3.6.2 扩展 Executor 路线 + +| 能力 | Executor | 阶段 | 详见 | +|------|----------|------|------| +| `type: command` | `CommandHookExecutor` | **P0** | §8 | +| `type: http` | `HttpHookExecutor` | **P2** | §17.2 | +| `type: prompt` | `PromptHookExecutor` | **P2** | §17.3 | +| `type: agent` | `AgentHookExecutor` | **P3** | §17.4 | +| Hermes plugin 迁移 | 文档 + `hermes-to-shell` | P1 文档 | 附录 B | +| OpenClaw HOOK pack | command 转换或 TS 沙箱 | P3 | §17.6 | + +社区 hook 的 **形态分布**(经验值):command/shell **>80%**;HTTP/prompt/agent 多见于企业集成与官方 partner,v1 用 shell 覆盖主体场景后,再按需求加 Executor 类型。详见 [附录 A](#附录-ahook-handler-类型与应用场景)。 + +### 3.5 exit code 解析逻辑 + +```python +if exit_code == 0: + # 解析 stdout JSON + if stdout_json.get("decision") == "deny": + return HookResult(action="deny", reason=stdout_json.get("reason", "")) + elif stdout_json.get("decision") == "allow": + return HookResult(action="allow", ...) + else: + return HookResult(action="pass") # 无明确决策 = 通过 +elif exit_code == 2: + return HookResult(action="deny", reason=stderr_text) +else: + # 非阻断错误:记录 warning,继续(除非 fail_closed / handler.fail_closed 为 true,则视为 deny) + logger.warning(f"Hook script error (exit {exit_code}): {stderr_text}") + return HookResult(action="error", reason=stderr_text) +``` + +> **fail_closed**:全局 `hooks.fail_closed` 或 per-handler `failClosed` 为 `true` 时,超时、命令不存在、exit 1 等非 exit-2 错误在可阻断事件上视为 `deny`(§8.6)。 + +### 3.7 `deny` / `block` 与事件类型的归一化链 + +子进程协议层(exit 2、stdout `decision:block`)统一先产出 `HookResult(action="deny")`;**按事件类型**在消费端再映射: + +| 阶段 | PreToolUse / UserPromptSubmit / PermissionRequest | Stop | +|------|---------------------------------------------------|------| +| `CommandHookExecutor` / `ResponseAdapter` | `exit 2` → `deny`;`decision:block` → `deny` | 同上 | +| `HookExecutor.execute_all`(可阻断) | 短路时 `action="deny"` | 短路时保留 `action="block"`(若脚本直接返回 `block`) | +| `HookRuntime._run_event` | 原样传递 | `deny` → **`block`**(对齐 Claude「阻止停止并继续」) | +| 消费端 | `deny` 拒绝工具 / prompt | `block` → `append_stop_blocking_feedback()` | + +因此社区脚本在 **Stop** 上使用 `exit 2` 或 `{"decision":"block"}` 均可生效;无需脚本感知 ms-agent 的 `block` 与 `deny` 差异。 + +--- + +## 4. 事件体系 + +### 4.1 事件总览 + +| 事件 | 触发时机 | 可阻断 | 关键字段 | +|------|---------|--------|---------| +| `SessionStart` | `run_loop()` 开始 | 否 | `session_id`, `project_path` | +| `PreToolUse` | 工具调用执行前 | 是(`deny` / `allow`) | `tool_name`, `tool_args` | +| `PostToolUse` | 工具调用完成后 | 否(但可注入 `additionalContext`) | `tool_name`, `tool_args`, `tool_result` | +| `UserPromptSubmit` | **用户消息进入 agent 循环前**(见 §4.5) | 是 | `prompt` | +| `Stop` | **Agent 本轮将结束、尚未退出循环前**(见 §4.5) | 是(`block` = 阻止停止并继续) | `reason`, `last_assistant_message` | +| `PermissionRequest` | 权限请求时(interactive 模式,`resolve_hook_permission_decision` 内) | 是 | `tool_name`, `tool_args` | + +**配置可加载、运行时触发待 P2**(loader 已映射入 `VALID_EVENTS`,尚无 `HookRuntime.run_*` 挂点): + +| 事件 | 触发时机 | 可阻断 | +|------|---------|--------| +| `SubagentStop` | 子 agent 任务结束(**P2**) | 否(可注入 context) | + +**P2 可选扩展**(主要为 Cursor / Hermes 生态独立事件,见 §15.3): + +| 事件 | 触发时机 | 可阻断 | +|------|---------|--------| +| `ShellBefore` | 仅 shell 类工具执行前 | 是 | +| `FileAfterEdit` | 文件写入/编辑后 | 否(可触发 format) | + +### 4.2 事件数据结构 + +```python +from dataclasses import dataclass, field, asdict +from typing import Any + +@dataclass(frozen=True) +class SessionStartEvent: + session_id: str + project_path: str = "" + event: str = field(default="SessionStart", init=False) + +@dataclass(frozen=True) +class PreToolUseEvent: + session_id: str + tool_name: str + tool_args: dict[str, Any] = field(default_factory=dict) + event: str = field(default="PreToolUse", init=False) + +@dataclass(frozen=True) +class PostToolUseEvent: + session_id: str + tool_name: str + tool_args: dict[str, Any] = field(default_factory=dict) + tool_result: str = "" + event: str = field(default="PostToolUse", init=False) + +@dataclass(frozen=True) +class UserPromptSubmitEvent: + session_id: str + prompt: str + event: str = field(default="UserPromptSubmit", init=False) + +@dataclass(frozen=True) +class StopEvent: + session_id: str + reason: str = "" + last_assistant_message: str = "" # 与 Claude Stop hook 输入对齐 + stop_hook_active: bool = False # 是否处于 Stop hook 反馈后的重入 + event: str = field(default="Stop", init=False) + +@dataclass(frozen=True) +class PermissionRequestEvent: + session_id: str + tool_name: str + tool_args: dict[str, Any] = field(default_factory=dict) + event: str = field(default="PermissionRequest", init=False) +``` + +### 4.3 HookResult(统一返回信封) + +```python +@dataclass(frozen=True) +class HookResult: + action: str # "allow" | "deny" | "ask" | "block" | "pass" | "error" + reason: str = "" + additional_context: str = "" + updated_args: dict[str, Any] | None = None + exit_code: int = 0 + stderr: str = "" +``` + +- `"pass"` 表示 hook 无明确决策,继续后续 permission 流程 +- `"allow"`(PreToolUse):**建议免交互弹窗**,但仍须过 **规则层校验**(blacklist / ask rule,见 §10.6);**不**等同跳过整个 PermissionEnforcer +- `"ask"`(PreToolUse):强制进入 PermissionEnforcer / handler,可携带 hook 的 `reason` 作为 `force_decision` 文案 +- `"deny"`(PreToolUse / UserPromptSubmit / PermissionRequest):直接拒绝,不再询问用户 +- `"block"`(**仅 Stop 事件**):阻止 Agent 停止,继续执行;其他事件上 `block` 在协议层归一为 `deny` +- `"error"`:脚本出错,不阻断流程(除非 `fail_closed`) + +### 4.4 事件的可阻断语义 + +**PreToolUse**(对齐 Claude `resolveHookPermissionDecision`,见 §10.6): + +| Hook 返回 | 行为 | +|-----------|------| +| `deny` | 直接拒绝,不进入 PermissionEnforcer | +| `ask` | 进入 PermissionEnforcer,弹窗文案优先用 hook `reason` | +| `allow` | **跳过「无规则命中」时的 ask 弹窗**;blacklist **仍 deny**;显式 ask rule **仍弹窗** | +| `pass` / `{}` / 无返回 | 完整 PermissionEnforcer 流程(含 ask) | +| `updated_args` 且无 permission 字段 | 仅改参(passthrough),permission 用新参数再匹配 | +| `allow` + `updated_args` | 规则校验与放行均基于改写后参数 | + +**UserPromptSubmit:** + +| Hook 返回 | 行为 | +|-----------|------| +| `deny` | 拒绝该 prompt,不送入 LLM | +| `pass` / 无返回 | 正常提交 | + +**Stop:** + +| Hook 返回 | 行为(对齐 Claude `query/stopHooks.ts`) | +|-----------|----------------------------------------| +| `block` / `deny`(exit 2 或 `decision:block`) | 注入 **Stop hook feedback** 元消息,**不**设置 `should_stop`,主循环继续 | +| `pass` / 无返回 | 允许停止(`should_stop` 保持 `True`) | +| `additionalContext` | 写入 `hook_additional_context`(见 §8.4、§9.5) | + +> Claude 另有 stdout `continue: false` → `preventContinuation`,语义为**确认结束本轮**(v1 经 `ResponseAdapter` 映射为 `pass`)。Cursor `stop` 的 `followup_message` 对齐 ms-agent 的 `block` + 注入 user 消息。 + +### 4.5 来源框架中的执行位置(UserPromptSubmit / Stop) + +两事件在 Claude Code、Cursor、Hermes 中**均存在**(名称不同),但触发粒度不同: + +| Canonical | Claude Code | 触发位置(源码) | Cursor | 触发位置 | Hermes Shell | 触发位置 | +|-----------|-------------|-----------------|--------|---------|--------------|---------| +| `UserPromptSubmit` | `UserPromptSubmit` | `processUserInput.ts`:用户输入进入 query **之前**(`executeUserPromptSubmitHooks`) | `beforeSubmitPrompt` | 用户提交 prompt、送模型**之前** | `pre_llm_call` | 每次 LLM 调用前(**粒度更宽**) | +| `Stop` | `Stop` / `SubagentStop` | `query.ts`:`!needsFollowUp` 时、`handleStopHooks`(一轮结束、无后续 tool) | `stop` | Agent 任务完成时 | `on_session_end` 等(**无完全等价 Stop**) | 会话级 | + +**Claude Code — UserPromptSubmit** + +``` +用户输入 → processUserInput() + → executeUserPromptSubmitHooks(prompt) # 在 query 循环之前 + → deny/block → shouldQuery=false,不进入 LLM + → additionalContext → hook_additional_context 附件 + → 通过后才进入 query 主循环 +``` + +**Claude Code — Stop** + +``` +query 主循环一轮结束(assistant 无待执行 tool) + → handleStopHooks() / executeStopHooks() + → block(exit 2)→ 注入 Stop hook feedback 元 user 消息 → 继续 query(stopHookActive=true) + → pass → 正常结束 +``` + +**ms-agent 对齐挂点** + +| 事件 | ms-agent 挂点 | 不复用 | +|------|--------------|--------| +| `UserPromptSubmit` | ① `run_loop()` 中 `create_messages()` 之后、首步 `step()` 之前;② `InputCallback.after_tool_call` 追加 user 之后、下一轮 `step()` 之前 | ~~`on_generate_response`~~(每轮 LLM 前,非用户提交) | +| `Stop` | `after_tool_call()` 内:判定 `should_stop` **之前**(assistant 无 `tool_calls` 时) | ~~`on_task_end`~~(循环已退出) | +| `SessionStart` | `on_task_begin`(`round==0`) | — | + +`CallbackToHookBridge` 仅负责 `SessionStart`;`UserPromptSubmit` / `Stop` 由 `LLMAgent` 直接调用 `HookRuntime`(见 §9)。 + +--- + +## 5. 配置格式 + +### 5.1 YAML 配置结构 + +配置位于 `agent.yaml` 或独立的 hooks 配置文件中,支持全局和项目两级: + +```yaml +hooks: + PreToolUse: + - matcher: "file_system---*" + hooks: + - type: command + command: "./hooks/validate-path.py" + timeout: 30 + - matcher: "code_executor---shell_executor" + hooks: + - type: command + command: "./hooks/check-shell.sh" + timeout: 10 + # --- P2 扩展 handler(§17)--- + - type: http + url: "https://policy.corp.example/hooks/pre-tool" + timeout: 10 + headers: + Authorization: "Bearer ${POLICY_TOKEN}" + allowed_env_vars: ["POLICY_TOKEN"] + - type: prompt + prompt: | + Evaluate whether this shell command is safe for production. + Input: $ARGUMENTS + Reply JSON only: {"ok": true} or {"ok": false, "reason": "..."} + model: "qwen-plus" + timeout: 30 + + PostToolUse: + - matcher: "*" + hooks: + - type: command + command: "./hooks/log-tool-use.py" + timeout: 5 + + SessionStart: + - hooks: + - type: command + command: "./hooks/session-init.sh" + + UserPromptSubmit: + - hooks: + - type: command + command: "./hooks/validate-prompt.py" + + Stop: + - hooks: + - type: command + command: "./hooks/cleanup.sh" + # --- P3 agent hook(§17.4)--- + - type: agent + prompt: | + Verify the agent completed the plan in $ARGUMENTS. + Read transcript if needed. Return structured ok/reason. + model: "qwen-plus" + max_turns: 20 + timeout: 120 +``` + +### 5.2 配置层级 + +三层嵌套:**事件类型 → MatcherGroup 列表 → Hook 处理器列表** + +```python +@dataclass(frozen=True) +class HookHandlerConfig: + type: str = "command" # command | http | prompt | agent + timeout: float = 30.0 + fail_closed: bool = False + # command + command: str | None = None + # http(对齐 Claude HttpHook) + url: str | None = None + headers: dict[str, str] = field(default_factory=dict) + allowed_env_vars: tuple[str, ...] = () # headers 内 ${VAR} 可解析的白名单 + # prompt / agent(对齐 Claude PromptHook / AgentHook) + prompt: str | None = None # 支持 $ARGUMENTS / ${ARGUMENTS} 占位符 + model: str | None = None # 默认 hooks.default_model 或 fast 模型 + max_turns: int = 20 # 仅 agent;Claude MAX_AGENT_TURNS=50,ms-agent 默认更保守 + +@dataclass(frozen=True) +class MatcherGroup: + matcher: str | None # 工具名匹配模式(非工具事件为 None) + hooks: tuple[HookHandlerConfig, ...] +``` + +### 5.3 配置来源与合并规则 + +对齐 Playground「全局 + 项目继承」与 F9 Plugin 加载: + +| 优先级(低→高,后者追加) | 路径 | Loader | +|--------------------------|------|--------| +| 1 全局原生 | `~/.ms_agent/hooks.yaml` | `NativeYamlLoader`(**需** `hooks.enabled_sources` 含 `native`) | +| 2 全局 Claude | `~/.claude/settings.json` → `hooks` | `ClaudeSettingsLoader`(`enabled_sources` 含 `claude` 时) | +| 3 全局 Cursor | `~/.cursor/hooks.json` | `CursorHooksLoader`(`enabled_sources` 含 `cursor` 时) | +| 4 项目 Claude | `.claude/settings.json` | 同上 | +| 5 项目 Cursor | `.cursor/hooks.json` | 同上 | +| 6 项目原生 | `agent.yaml` → `hooks:` | `HookRegistry.from_dict`(**需** `enabled_sources` 含 `native`) | +| 7 项目目录 | `.ms-agent/hooks.json` | `NativeJsonLoader`(**需** `native`) | +| 8 Plugin | `hooks/hooks.json` | `PluginHooksLoader`(`enabled_sources` 含 `plugin` 时) | +| 9 Hermes(可选) | `~/.hermes/config.yaml` → `hooks:` | `HermesShellLoader`(`enabled_sources` 含 `hermes` 时) | + +**合并策略**: + +- 按事件类型 **追加**(append),同优先级内保持文件声明顺序 +- **执行顺序**:低优先级先执行,高优先级后执行;可阻断事件上首个 `deny` 短路 +- **默认**:仅加载 ms-agent 原生配置(`enabled_sources: [native]`);外部生态需显式开启 +- **`agent.yaml` 的 `hooks:` 事件段**与 `~/.ms_agent/hooks.yaml`、`.ms-agent/hooks.json` 同属 **native** 源,需 `enabled_sources` 含 `native` 才会加载 + +```yaml +hooks: + enabled_sources: [native] # 唯一配置键;可选: claude, cursor, hermes, plugin + enabled_executors: [command] # P2: http, prompt;P3: agent + default_model: "qwen-plus" # prompt/agent 默认模型 + fail_closed: false # 脚本崩溃/超时是否阻断(对标 Cursor failClosed) + # allowed_http_hook_urls: [...] # P2 HTTP 白名单,见 §17.1 +``` + +> **配置键命名**:统一使用 `enabled_sources`(非 `hooks.enabled`)。所有外部生态加载均受此开关控制;`agent.yaml` 内事件段同属 `native` 源。 + +**Playground 工作区约定**:项目级 hooks 脚本推荐放在 `.ms-agent/hooks/`,配置放在 `.ms-agent/hooks.json` 或 `agent.yaml`,与 session log、memory 等同属 `.ms-agent/` 命名空间。 + +### 5.4 匹配器规则 + +- 仅 **工具事件**(`PreToolUse`、`PostToolUse`、`PermissionRequest`)使用 matcher +- 非工具事件(`SessionStart`、`UserPromptSubmit`、`Stop`)无 matcher,所有 hooks 都触发 +- matcher 格式与权限系统一致:`server_name---tool_name`,支持 `*`/`?` 通配符和 `|` 分隔 + +--- + +## 6. 匹配器 + +### 6.1 共享 PatternMatcher + +Hooks 和 Permission 模块共用同一个通配符匹配函数,提取到 `ms_agent/utils/pattern_matcher.py`。 + +> **实现注记(permission 已落地)**:`ms_agent/permission/matcher.py` 中 `PermissionMatcher.match()` 已内联 fnmatch + `|` 逻辑(与下文等价)。P0 实施时**提取**为 `match_pattern()` 并让 `PermissionMatcher` 委托,避免两套实现漂移。Hooks matcher **v1 仅匹配工具名**(`server---tool`),**不**支持 Permission 的 `:content_pattern` 后缀;内容级策略用 PreToolUse 脚本内判断。 + +```python +import fnmatch + +def match_pattern(pattern: str, target: str) -> bool: + """fnmatch 通配符匹配,支持 | 分隔的多模式。 + + Examples: + match_pattern("file_system---*", "file_system---read_file") → True + match_pattern("read_file|write_file", "read_file") → True + match_pattern("code_executor---shell_*", "web_search---*") → False + """ + for alt in pattern.split('|'): + alt = alt.strip() + if alt and fnmatch.fnmatch(target, alt): + return True + return False +``` + +### 6.2 Permission 模块适配 + +`ms_agent/permission/matcher.py` 中的 `PermissionMatcher.match()` 改为调用 `match_pattern()`,保持接口不变: + +```python +from ms_agent.utils.pattern_matcher import match_pattern + +class PermissionMatcher: + def match(self, pattern: str, tool_call: str) -> bool: + return match_pattern(pattern, tool_call) + + def match_with_content(self, pattern, tool_name, tool_args) -> bool: + # ... 保持不变,内部 self.match() 已委托到 match_pattern +``` + +### 6.3 HookRegistry 中的使用 + +```python +class HookRegistry: + def get_handlers(self, event_type: str, tool_name: str | None = None) -> list[HookHandlerConfig]: + groups = self._index.get(event_type, []) + result = [] + for group in groups: + if event_type not in cls.TOOL_EVENTS: + result.extend(group.hooks) + elif group.matcher is None: + result.extend(group.hooks) + elif tool_name is not None and match_pattern(group.matcher, tool_name): + result.extend(group.hooks) + return result +``` + +> **实现注记**:工具事件在 `tool_name is None` 时不匹配任何带 matcher 的组,避免误触发全部 handler。 + +--- + +## 7. HookRegistry — 配置加载与合并 + +### 7.1 类设计 + +```python +@dataclass(frozen=True) +class HookRegistry: + _index: dict[str, tuple[MatcherGroup, ...]] + + VALID_EVENTS: ClassVar[frozenset[str]] = frozenset({ + "SessionStart", "PreToolUse", "PostToolUse", + "UserPromptSubmit", "Stop", "PermissionRequest", + "SubagentStop", # 配置可加载;运行时触发见 §4.1(P2) + }) + + TOOL_EVENTS: ClassVar[frozenset[str]] = frozenset({ + "PreToolUse", "PostToolUse", "PermissionRequest", + }) + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> HookRegistry: ... + + def merge(self, other: HookRegistry) -> HookRegistry: ... + + def get_handlers(self, event_type: str, + tool_name: str | None = None) -> list[HookHandlerConfig]: ... + + @property + def is_empty(self) -> bool: ... +``` + +### 7.2 from_dict 解析逻辑 + +```python +@classmethod +def from_dict( + cls, + d: dict[str, Any], + *, + enabled_executors: frozenset[str] = frozenset({"command"}), + source: str = "config", +) -> HookRegistry: + if not d: + return cls(_index={}) + + index: dict[str, tuple[MatcherGroup, ...]] = {} + for event_type, groups_raw in d.items(): + if event_type in ("enabled_sources", "enabled_executors", "default_model", + "fail_closed", "allowed_http_hook_urls", + "http_hook_allowed_env_vars"): + continue + if event_type not in cls.VALID_EVENTS: + logger.warning(f"Unknown hook event type: {event_type}") + continue + groups = [] + for g in (groups_raw or []): + matcher = g.get("matcher") if event_type in cls.TOOL_EVENTS else None + hooks_raw = g.get("hooks", []) + handlers = _filter_handlers_by_executor(hooks_raw, enabled_executors, source=source) + if handlers: + groups.append(MatcherGroup(matcher=matcher, hooks=handlers)) + if groups: + index[event_type] = tuple(groups) + return cls(_index=index) + + +def _parse_hook_handler(h: dict[str, Any]) -> HookHandlerConfig | None: + t = h.get("type", "command") + timeout = float(h.get("timeout", 30.0)) + fail_closed = bool(h.get("failClosed", h.get("fail_closed", False))) + if t == "command": + if not h.get("command"): + return None + return HookHandlerConfig(type="command", command=h["command"], + timeout=timeout, fail_closed=fail_closed) + if t == "http": + if not h.get("url"): + return None + return HookHandlerConfig( + type="http", url=h["url"], headers=dict(h.get("headers") or {}), + allowed_env_vars=tuple(h.get("allowedEnvVars", h.get("allowed_env_vars", []))), + timeout=timeout, fail_closed=fail_closed, + ) + if t in ("prompt", "agent"): + if not h.get("prompt"): + return None + return HookHandlerConfig( + type=t, prompt=h["prompt"], model=h.get("model"), + max_turns=int(h.get("maxTurns", h.get("max_turns", 20))), + timeout=timeout, fail_closed=fail_closed, + ) + logger.warning(f"Unknown hook handler type: {t}") + return None +``` + +### 7.3 merge 合并逻辑 + +```python +def merge(self, other: HookRegistry) -> HookRegistry: + """合并两个 registry(全局 + 项目),同事件类型下 append。""" + merged: dict[str, tuple[MatcherGroup, ...]] = {} + all_events = set(self._index) | set(other._index) + for event in all_events: + self_groups = self._index.get(event, ()) + other_groups = other._index.get(event, ()) + merged[event] = self_groups + other_groups + return HookRegistry(_index=merged) +``` + +--- + +## 8. HookExecutor — 执行引擎(Dispatcher + Command 后端) + +### 8.1 类设计(Dispatcher) + +`HookExecutor` 为**调度门面**,按 `HookHandlerConfig.type` 路由到各后端;所有后端统一返回 `HookResult`,stdout/HTTP body 均经 `ResponseAdapter`(§3.6)。 + +```python +class HookExecutor: + def __init__( + self, + working_dir: str | None = None, + *, + command: CommandHookExecutor, + http: HttpHookExecutor | None = None, # P2 + prompt: PromptHookExecutor | None = None, # P2 + agent: AgentHookExecutor | None = None, # P3 + response_adapter: ResponseAdapter, + ) -> None: ... + + async def execute(self, handler: HookHandlerConfig, + event_data: dict[str, Any], + ctx: HookExecutionContext) -> HookResult: + backend = self._backends.get(handler.type) + if backend is None: + return HookResult(action="error", + reason=f"Hook type '{handler.type}' not enabled") + return await backend.execute(handler, event_data, ctx) + +@dataclass +class HookExecutionContext: + """prompt/agent 后端需要 session 上下文;command/http 可选。""" + session_id: str + project_path: str + llm: LLM | None = None # 当前 session LLM(prompt/agent) + messages: list[Message] | None = None # agent hook 可读历史 + abort_signal: asyncio.Event | None = None + tool_manager: ToolManager | None = None # agent hook 受限工具集 +``` + +P0 仅注册 `command`;P2 起按 `hooks.enabled_executors: [command, http, prompt]` 启用扩展后端。 + +### 8.2 CommandHookExecutor — 子进程执行 + +```python +async def execute(self, handler: HookHandlerConfig, + event_data: dict[str, Any]) -> HookResult: + stdin_data = json.dumps(event_data, ensure_ascii=False).encode("utf-8") + + try: + proc = await asyncio.create_subprocess_exec( + *shlex.split(handler.command), + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=self._working_dir, + ) + stdout, stderr = await asyncio.wait_for( + proc.communicate(input=stdin_data), + timeout=handler.timeout, + ) + except asyncio.TimeoutError: + # 超时:非阻断错误 + proc.kill() + return HookResult(action="error", reason=f"Hook timed out after {handler.timeout}s", + exit_code=-1) + except FileNotFoundError: + return HookResult(action="error", reason=f"Hook command not found: {handler.command}", + exit_code=-1) + except Exception as e: + return HookResult(action="error", reason=str(e), exit_code=-1) + + exit_code = proc.returncode or 0 + stderr_text = stderr.decode("utf-8", errors="replace").strip() + stdout_text = stdout.decode("utf-8", errors="replace").strip() + + # exit 2 → 策略性阻断 + if exit_code == 2: + return HookResult(action="deny", reason=stderr_text or "Blocked by hook", + exit_code=exit_code, stderr=stderr_text) + + # exit 非 0 且非 2 → 非阻断错误 + if exit_code != 0: + logger.warning(f"Hook '{handler.command}' exited {exit_code}: {stderr_text}") + return HookResult(action="error", reason=stderr_text, exit_code=exit_code, + stderr=stderr_text) + + # exit 0 → 经 ResponseAdapter 解析 stdout(§3.6),勿在此重复解析 + return self._response_adapter.parse(stdout_text, exit_code, stderr_text, event_data.get("event")) +``` + +### 8.3 execute_all — 批量执行与合并 + +对齐 Claude `executeHooks`(`hooks.ts`):多个 handler **顺序执行**(v1;Claude 为并行,语义等价),合并 permission 决策时优先级 **`deny` > `ask` > `allow`**。仅 `deny`/`block` 立即短路;`allow` **不**短路后续 handler(审计 hook 可继续跑)。 + +```python +async def execute_all( + self, + handlers: list[HookHandlerConfig], + event_data: dict[str, Any], + *, + blockable: bool = False, + on_handler_complete: Callable | None = None, # 上报 hook_name / duration_ms +) -> HookResult: + ... + for handler in handlers: + result = await self.execute(handler, event_data, ctx) + if on_handler_complete: + await on_handler_complete(handler, result, duration_ms) + + if blockable and result.action in ("deny", "block"): + # Stop 保留 block;其余事件为 deny + return HookResult( + action=result.action if result.action == "block" else "deny", + ... + ) + ... +``` + +> **注意**:`allow` 不再在 `execute_all` 层跳过 PermissionEnforcer;跳过弹窗逻辑统一在 `resolve_hook_permission_decision()`(§10.6)。Stop 事件的 `deny` → `block` 映射在 `HookRuntime._run_event`(§3.7)。 + +### 8.4 子进程环境变量与 command 解析 + +`HookExecutor.execute()` 启动子进程时注入(兼容 Claude 社区脚本): + +| 变量 | 含义 | +|------|------| +| `MS_AGENT_PROJECT_DIR` | 项目根目录 | +| `MS_AGENT_PLUGIN_ROOT` | 当前 plugin 根(如有) | +| `CLAUDE_PROJECT_DIR` | 上者别名 | +| `CLAUDE_PLUGIN_ROOT` | 上者别名 | + +**command 解析(v1)**: + +- 默认 `shlex.split(handler.command)`——**按空格分词**,不支持 shell 引号语义;`bash -c "foo bar"` 会被错误拆分 +- 推荐写法:可执行脚本路径无空格(`./hooks/check.sh`、`/abs/path/hook.py`),复杂逻辑写在脚本内部 +- v2 可选支持 `command: ["bash", "-c", "..."]` 列表形式 +- 相对路径以 `agent.yaml` / hooks 配置所在目录为 cwd(`HookExecutor(working_dir=...)`) + +### 8.5 PostToolUse — additionalContext 回流路径 + +对齐 Claude Code `services/tools/toolHooks.ts` + `utils/messages.ts`:**不**把 context 拼进 tool result 字符串(避免破坏 JSON 工具输出)。 + +**Claude 数据流** + +``` +tool 执行完成 + → executePostToolHooks() + → hook 返回 additionalContext + → createAttachmentMessage({ type: 'hook_additional_context', hookEvent: 'PostToolUse', toolUseID }) + → 插入 transcript:assistant(tool_use) → [hook 附件] → user(tool_result) + → 渲染为 isMeta user 消息,包在 内送给模型 + → smooshSystemReminderSiblings:合并进同轮 tool_result 旁侧(Gap F) +``` + +**ms-agent 对齐方案** + +```python +# ms_agent/hooks/context.py +@dataclass(frozen=True) +class HookAttachment: + type: Literal["hook_additional_context", "hook_blocking_feedback", "hook_stopped_continuation"] + hook_event: str + tool_call_id: str | None + content: str | list[str] + +# ToolManager.single_call_tool — Pre/PostToolUse +pre_result, pre_attachments = await hook_runtime.run_pre_tool_use(...) +... +post_result, post_attachments = await hook_runtime.run_post_tool_use(...) +hook_attachments = list(pre_attachments) + list(post_attachments) +# 挂到 tool Message.hook_attachments(见 parallel_tool_call) +``` + +`LLMAgent.parallel_tool_call()` 组装 `Message(role='tool', ...)` 后: + +1. 将 `hook_attachments` 挂到对应 `Message` 的 **`hook_attachments` 字段**(与 `tool_detail` 同级,**不进** `to_dict_clean()`) +2. `step()` 在调用 LLM 前,经 `condense_hook_attachments_for_llm(messages)` 把附件转为 **user 元消息**(`[hook:PostToolUse]` 前缀或 Stop 的 `Stop hook feedback:`),插入 **原消息之后** +3. **禁止**修改 `tool` 消息的 `content` 本体 + +**PreToolUse additionalContext**:与 PostToolUse 相同,经 `hook_attachments` 挂在 tool 消息上,下轮 LLM 前 condense(v1 不 smoosh 进 tool_result 字符串)。 + +**Stop blocking feedback**:挂在 assistant 消息上的 `HookAttachment(type=hook_blocking_feedback)`,由 `append_stop_blocking_feedback()` 写入(§9.4)。 + +**WebUI(Hold)**:`hook_attachments` 经 SSE/API 透出供 UI 展示;具体渲染见 §9.6 预留接口。 + +### 8.6 fail_closed 粒度 + +| 粒度 | v1 | 说明 | +|------|-----|------| +| `hooks.fail_closed`(全局) | ✅ | 超时/exit≠2/命令不存在 → 可阻断事件视为 `deny` | +| per-handler `failClosed` | ✅ | 覆盖单条 handler;与全局为 OR 关系 | + +--- + +## 9. 生命周期 Hook 集成与阻断消费 + +### 9.1 职责划分(PR#906 后) + +| 事件 | 集成位置 | 说明 | +|------|---------|------| +| `PreToolUse` / `PostToolUse` / `PermissionRequest` | **`ToolManager.single_call_tool()`** | 工具名/参数/返回值 | +| `SessionStart` | **`CallbackToHookBridge.on_task_begin`** | `round==0` | +| `UserPromptSubmit` | **`LLMAgent.run_loop()` / `InputCallback` 路径** | 用户消息进循环前(§4.5) | +| `Stop` | **`LLMAgent.after_tool_call()`** | `should_stop` 判定前(§4.5) | + +`CallbackToHookBridge` **仅**转发 `SessionStart`;`UserPromptSubmit` / `Stop` 由 `LLMAgent` 直接调 `HookRuntime`,避免误绑到 `on_generate_response` / `on_task_end`。 + +### 9.2 CallbackToHookBridge(SessionStart 专用) + +```python +class CallbackToHookBridge(Callback): + def __init__(self, config, hook_runtime: HookRuntime) -> None: + super().__init__(config) + self._hooks = hook_runtime + + async def on_task_begin(self, runtime, messages) -> None: + await self._hooks.run_session_start(runtime, messages) +``` + +### 9.3 UserPromptSubmit — 挂点与消费 + +**挂点 A — 首条用户消息**(`run_loop()`,`create_messages()` 之后): + +执行顺序:**`SessionStart`(`on_task_begin`)→ `UserPromptSubmit`**。SessionStart 负责会话初始化;UserPromptSubmit 在校验通过后才进入 `step()` / LLM。 + +```python +# llm_agent.py — run_loop() round==0 +messages = await self.create_messages(messages) +await self.on_task_begin(messages) # SessionStart(CallbackToHookBridge) +prompt_text = _extract_latest_user_prompt(messages) +submit = await self._hook_runtime.run_user_prompt_submit(prompt=prompt_text, ...) +if submit.action == "deny": + # 对齐 Claude processUserInput.ts:不进入 step() + messages.append(Message( + role="system", + content=f"UserPromptSubmit operation blocked by hook:\n{submit.reason}\n\nOriginal prompt: {prompt_text}", + )) + await self.on_task_end(messages) + yield messages + return +_apply_hook_attachments(messages, submit) # additionalContext → hook_additional_context +``` + +**挂点 B — 多轮 `InputCallback`**(`after_tool_call` 追加 user 后、下一轮 `step()` 前): + +在 `InputCallback` 之后、`runtime.should_stop = False` 分支内,对新增 user 内容再跑一遍 `run_user_prompt_submit`;`deny` 时撤销该 user 消息并 `should_stop = True`。 + +**消费语义**(对齐 Claude `processUserInput.ts`): + +| HookResult | ms-agent 行为 | +|------------|--------------| +| `deny` / exit 2 | **不调用** `step()` / LLM;写入 system 警告 + 原始 prompt 摘要;结束或等待新输入 | +| `additional_context` | 追加 `HookAttachment(type=hook_additional_context, hook_event=UserPromptSubmit)`,下轮 LLM 前渲染为元 user 消息 | +| `pass` | 正常进入 `step()` | + +### 9.4 Stop — 挂点与消费 + +**挂点** — `LLMAgent.after_tool_call()`,在现有 `should_stop` 逻辑**之前**: + +```python +async def after_tool_call(self, messages: List[Message]) -> None: + assistant = messages[-1] + would_stop = assistant.role == "assistant" and not assistant.tool_calls + + if would_stop and self._hook_runtime is not None: + last_text = assistant.content if isinstance(assistant.content, str) else "" + stop = await self._hook_runtime.run_stop( + reason="no_tool_calls", + last_assistant_message=last_text, + stop_hook_active=getattr(self.runtime, "stop_hook_active", False), + ) + if stop.action in ("block", "deny"): + # 对齐 Claude stopHooks.ts:HookAttachment 承载,下轮 condense 为 user 元消息 + append_stop_blocking_feedback(messages, stop.reason) + self.runtime.should_stop = False + self.runtime.stop_hook_active = True + await self.loop_callback("after_tool_call", messages) + return + apply_hook_result_to_messages(messages, stop, hook_event="Stop") + + if would_stop: + self.runtime.should_stop = True + await self.loop_callback("after_tool_call", messages) +``` + +| HookResult | ms-agent 行为 | +|------------|--------------| +| `block` / `deny` | `should_stop = False`;`append_stop_blocking_feedback` → 下轮 condense 为 `Stop hook feedback` 元 user 消息;`stop_hook_active = True` | +| `additional_context` | `hook_additional_context` 附件,下轮 LLM 前注入 | +| `pass` | `should_stop = True`(默认停止) | + +### 9.5 HookAttachment 统一消费(阻断 + context) + +```python +# ms_agent/hooks/context.py +def apply_hook_result_to_messages(...) -> bool: + """返回 False 表示调用方应中止后续流程(UserPromptSubmit deny)。""" + ... + +def append_stop_blocking_feedback(messages, reason: str) -> None: + """Stop block:挂 hook_blocking_feedback 到当前 assistant 消息。""" + ... + +def condense_hook_attachments_for_llm(messages: list[Message]) -> list[Message]: + """hook_additional_context → [hook:Event];hook_blocking_feedback → Stop hook feedback。""" + ... +``` + +### 9.6 WebUI 预留接口(实现 Hold) + +以下接口在 `ms_agent/hooks/` 与 `webui/backend/` 间预留,**具体 UI 延后**: + +```python +# 供 SSE / agent_runner 消费 +class HookEventNotification(TypedDict, total=False): + hook_event: str + hook_name: str + action: str + reason: str + duration_ms: float + +# HookRuntime 可选 callback +on_hook_event: Callable[[HookEventNotification], Awaitable[None]] | None = None +``` + +WebUI 仅需订阅 `on_hook_event` 与 message 上的 `hook_attachments`;阻断态用现有 run 中止 + system 消息展示,不做专用弹窗(P2)。 + +### 9.7 注册方式 + +```python +async def prepare_tools(self): + ... + session_id = self.runtime.session_id or self.tag or str(uuid.uuid4()) + hook_runtime = build_hook_runtime(self.config, session_id=session_id) + + self.tool_manager = ToolManager(..., hook_runtime=hook_runtime, ...) + if hook_runtime.has_session_handlers: + self.register_callback(CallbackToHookBridge(self.config, hook_runtime)) + self._hook_runtime = hook_runtime + await self.tool_manager.connect() +``` + +--- + +## 10. 与权限系统的协作 + +> **权限模块已落地**(见 `docs/zh/design/permission-design.md` §1.1):`SafetyGuard` + `PermissionEnforcer` 已在 `ToolManager.single_call_tool()` L294–344 运行。本文档仅描述 **Hooks 插入点**;不在 `permission-design.md` 重复实现,但需在 permission 文档 §2 判定流程图补一行「1.5 Hooks PreToolUse」(见 [附录 C](#附录-c实现待办与跨文档约定))。 + +### 10.1 权限基线(PR#906)与 Hooks 集成后流程 + +PR#906 落地时 `single_call_tool` 仅有 SafetyGuard + PermissionEnforcer;**Hooks 已在此基础上插入**(见 §10.2)。以下为插入前的权限摘录(步骤 1 仍保持不变): + +```python +# ms_agent/tools/tool_manager.py — single_call_tool() 摘录 +args_dict = dict(tool_args) if isinstance(tool_args, dict) else {} + +if self._safety_guard is not None: + safety_decision = self._safety_guard.check(tool_name, args_dict) + if safety_decision.action == 'deny': + return f'Blocked by safety policy: {safety_decision.reason}' + # ask → resolve_ask() ... + +if self._permission_enforcer is not None: + perm_decision = await self._permission_enforcer.check(tool_name, args_dict) + if perm_decision.action == 'deny': + return f'Tool call denied: {perm_decision.reason}' + if perm_decision.updated_args is not None: + tool_args = perm_decision.updated_args + +response = await asyncio.wait_for(tool_ins.call_tool(...), timeout=wait_sec) +return response +``` + +### 10.2 目标执行顺序(插入 Hooks 后) + +``` +ToolManager.single_call_tool(tool_info) + │ + ├─ 1. SafetyGuard.check() ← 安全底线(不可绕过,已实现) + │ + ├─ 2. HookRuntime.run_pre_tool_use() ← PreToolUse(已实现) + │ └─ 产出 HookResult + pre_attachments(additionalContext) + │ + ├─ 3. resolve_hook_permission_decision() ← Hook × Permission 合并(§10.6) + │ ├─ deny → return 'Blocked by hook: ...' + │ ├─ allow → 规则层无异议则放行;blacklist / ask rule 仍可拦截 + │ └─ pass/ask → PermissionEnforcer.check()(ask 可带 hook reason) + │ + ├─ 4. tool_ins.call_tool() ← 执行(已实现) + │ + └─ 5. HookRuntime.run_post_tool_use() ← PostToolUse(已实现) + └─ pre + post hook_attachments → §8.5 +``` + +与 `permission-design.md` §2 对齐:**SafetyGuard → PreToolUse →(resolve)→ PermissionEnforcer → call_tool → PostToolUse**。 + +#### 10.2.1 F7 MCP Runtime 扩展(交叉引用) + +当 Playground 启用 `MCPRuntime` 时,[`mcp_runtime_management.md` §7.4](../../design/mcp_runtime_management.md#74-与-hooks-管线协作single_call_tool-完整顺序) 在**本节前**插入: + +1. `_tool_index` 快照 +2. **MCP callable 检查**(`degraded` / `error` 短路,不进入下文 SafetyGuard) + +此后步骤与本节 §10.2 编号对齐(本文步骤 1 → MCP 文档步骤 2,依此类推)。`degraded` 的 MCP 工具**不触发** PreToolUse,因其在 MCP callable 步骤已拒绝 RPC。 + +### 10.3 为什么 PreToolUse 在 PermissionEnforcer 之前、SafetyGuard 之后 + +- **SafetyGuard 不可绕过**:已在步骤 1 拒绝的调用不会进入 Hook(与 Claude bypass-immune safety checks 同层) +- **Hook 可提前 deny**:避免无意义的 confirm 弹窗 +- **Hook `allow` ≠ 全权放行**:对齐 Claude `resolveHookPermissionDecision`——`allow` 仅表示**建议免弹窗**,`permission:` blacklist 与显式 ask rule **仍可覆盖** +- **参数改写前置**:`updated_args` 在 permission 匹配前生效,黑白名单匹配最终参数 + +### 10.4 ToolManager 集成代码(目标补丁) + +```python +# ms_agent/tools/tool_manager.py + +from ms_agent.hooks.permission_resolve import resolve_hook_permission_decision + +args_dict = dict(tool_args) if isinstance(tool_args, dict) else {} + +# 1. SafetyGuard(已有) +if self._safety_guard is not None: + ... + +# 2. PreToolUse +hook_result: HookResult | None = None +pre_attachments: list[HookAttachment] = [] +if self._hook_runtime is not None: + hook_result, pre_attachments = await self._hook_runtime.run_pre_tool_use( + tool_name=tool_name, + tool_args=args_dict, + ) + if hook_result.updated_args is not None: + tool_args = hook_result.updated_args + args_dict = dict(hook_result.updated_args) + tool_info['arguments'] = tool_args + +# 3. Hook × Permission 合并 +perm_out = await resolve_hook_permission_decision( + hook_result=hook_result, + tool_name=tool_name, + tool_args=args_dict, + permission_enforcer=self._permission_enforcer, + permission_config=self._permission_config, +) +if isinstance(perm_out, str): + return perm_out +if perm_out.action == 'deny': + return f'Tool call denied: {perm_out.reason}' +if perm_out.updated_args is not None: + tool_args = perm_out.updated_args + tool_info['arguments'] = tool_args + +response = await asyncio.wait_for(tool_ins.call_tool(...), timeout=wait_sec) + +# 5. PostToolUse +if self._hook_runtime is not None: + post = await self._hook_runtime.run_post_tool_use(...) +return response +``` + +### 10.5 hooks 与权限的边界 + +| 维度 | Permission 系统 | Hooks 系统 | +|------|---------------|------------| +| 职责 | 内置的工具访问控制(YAML 规则 + 用户确认) | 可扩展策略脚本(社区 hook) | +| 配置来源 | YAML `permission:` 段 | YAML `hooks:` / `.claude/settings.json` 等 | +| 执行方式 | In-process Python | 子进程(语言中立) | +| PreToolUse `allow` | blacklist **仍可 deny**;无规则命中时可免弹窗 | 产出 `allow` **建议**,由 `resolve_hook_permission_decision` 合并 | +| PreToolUse `deny` | 不再执行 | 硬拒绝,优先于 permission | +| PreToolUse `pass` / `{}` | 完整 `check()` 流程 | 社区脚本「只审计不干预」的默认写法 | +| 用户交互 | `ask` → handler 弹窗 | `ask` 可强制带 hook 文案进入 handler | +| 记忆持久化 | PermissionMemory(allow_always) | 无(每次执行脚本) | + +`ToolManager.__init__` 新增 `hook_runtime: HookRuntime | None = None`;`LLMAgent.prepare_tools()` 构造共享实例并传入。 + +### 10.6 `resolve_hook_permission_decision` — 社区 Hook 兼容核心 + +对齐 Claude Code `services/tools/toolHooks.ts` → `resolveHookPermissionDecision()` 与 `utils/permissions/permissions.ts` → `checkRuleBasedPermissions()`。 + +**设计原则**:Hook 产出 **permission 建议**,不与 Permission 整层互斥;`allow` **不**等于 `hook_skip_permission=True`。 + +```python +# ms_agent/hooks/permission_resolve.py + +async def check_rule_based_permissions( + tool_name: str, + tool_args: dict[str, Any], + config: PermissionConfig, + matcher: PermissionMatcher, +) -> PermissionDecision | None: + """仅规则层:blacklist deny、显式 ask rule。不跑 handler 弹窗。 + 返回 None 表示规则层无异议(对齐 Claude checkRuleBasedPermissions → null)。""" + for pattern in config.blacklist: + if matcher.match_with_content(pattern, tool_name, tool_args): + return PermissionDecision(action='deny', reason=f'Denied by blacklist: {pattern}') + for pattern in config.ask_rules: + if matcher.match_with_content(pattern, tool_name, tool_args): + return PermissionDecision(action='ask', reason=f'Ask rule matched: {pattern}') + return None + + +async def resolve_hook_permission_decision( + hook_result: HookResult | None, + tool_name: str, + tool_args: dict[str, Any], + *, + permission_enforcer: PermissionEnforcer | None, + permission_config: PermissionConfig | None, + hook_runtime: HookRuntime | None = None, +) -> PermissionDecision | str: + """合并 PreToolUse 与 PermissionEnforcer。返回 str 表示工具层错误文案。""" + + # Hook deny — 直接拒绝(优先于 permission) + if hook_result and hook_result.action == 'deny': + return f'Blocked by hook: {hook_result.reason}' + + args = hook_result.updated_args if (hook_result and hook_result.updated_args) else tool_args + + # Hook allow — 跳过「无规则命中」时的 ask,但规则层仍可拦截 + if hook_result and hook_result.action == 'allow': + if permission_config: + rule = await check_rule_based_permissions( + tool_name, args, permission_config, PermissionMatcher()) + if rule and rule.action == 'deny': + return rule # blacklist 覆盖 hook allow(Claude inc-4788) + if rule and rule.action == 'ask': + # 显式 ask rule:仍走 enforcer / handler + if permission_enforcer: + return await permission_enforcer.check( + tool_name, args, force_decision=rule) + return PermissionDecision( + action='allow', + reason=hook_result.reason or 'Allowed by PreToolUse hook', + ) + + # Hook ask — 带 hook 文案进入完整 permission 流程 + if hook_result and hook_result.action == 'ask': + if permission_enforcer: + return await permission_enforcer.check( + tool_name, args, + force_decision=PermissionDecision( + action='ask', reason=hook_result.reason), + ) + + # pass / 无 hook — PermissionRequest(P1,interactive 模式)→ PermissionEnforcer + if hook_runtime and permission_config and permission_config.mode == 'interactive': + pr = await hook_runtime.run_permission_request(tool_name, args) + if pr.action == 'deny': + return f'Blocked by hook: {pr.reason}' + if pr.action == 'ask' and permission_enforcer: + return await permission_enforcer.check( + tool_name, args, + force_decision=PermissionDecision(action='ask', reason=pr.reason), + ) + + if permission_enforcer: + return await permission_enforcer.check(tool_name, args) + return PermissionDecision(action='allow', reason='No permission enforcer') +``` + +**`PermissionEnforcer.check()` 扩展**(小改,已实现类上追加可选参数): + +```python +async def check( + self, + tool_name: str, + tool_args: dict[str, Any], + *, + force_decision: PermissionDecision | None = None, +) -> PermissionDecision: + if force_decision and force_decision.action == 'ask': + # handler.ask() 使用 force_decision.reason 作为弹窗说明 + ... +``` + +**社区脚本典型场景对照**: + +| 社区脚本写法 | ms-agent 行为 | +|-------------|--------------| +| `echo '{}'` / 只写日志 | `pass` → 完整 permission(含 ask) | +| `permissionDecision: "allow"` | 无 blacklist 时免弹窗;**blacklist 仍 deny** | +| `permissionDecision: "deny"` | 直接拒绝 | +| `permissionDecision: "ask"` | 强制弹窗,带 hook reason | +| 仅 `updatedInput` | 改参后走完整 permission | +| `decision: "approve"`(Codex 风格) | 同 `allow` | + +**与 ms-agent 双层架构的映射**: + +| Claude 概念 | ms-agent 等价 | +|------------|-------------| +| `tool.checkPermissions` + safetyCheck bypass-immune | **SafetyGuard**(在 Hook 之前) | +| `checkRuleBasedPermissions` | `check_rule_based_permissions()`(blacklist / ask rule) | +| `canUseTool` 弹窗 | `PermissionEnforcer.check()` + handler | +| Hook `allow` 跳过弹窗 | `resolve` 在规则层无异议时直接 `allow` | + +--- + +## 11. 集成点与代码变更 + +PR#906 已合入,集成策略调整为 **双入口、单 HookRuntime**: + +| 模块 | 变更 | 侵入度 | +|------|------|--------| +| `ms_agent/tools/tool_manager.py` | `hook_runtime` + Pre/Post;Post 返回 `hook_attachments` | **必要** | +| `ms_agent/agent/llm_agent.py` | `prepare_tools()`;`run_loop` UserPromptSubmit;`after_tool_call` Stop;`condense_hook_attachments_for_llm` | **中等** | +| `ms_agent/llm/utils.py` | `Message.hook_attachments`;`Runtime.stop_hook_active` | 小 | +| `ms_agent/hooks/*` | Hooks 模块(含 `permission_resolve.py`、loaders) | **已实现** | +| `ms_agent/permission/enforcer.py` | `check(..., force_decision=)` 可选扩展 | 小改 | +| `ms_agent/utils/pattern_matcher.py` | 从 `permission/matcher.py` 提取 | 重构 | +| `ms_agent/hooks/bridge.py` | `CallbackToHookBridge`(仅 SessionStart) | **已实现** | + +### 11.1 `prepare_tools()` 接线(`llm_agent.py`) + +`session_id` 与 `Runtime.session_id` 同步(默认 `agent.tag`,否则 UUID),并写入 hook stdin。 + +```python +async def prepare_tools(self): + safety_guard, permission_enforcer, perm_config = self._build_permission_objects() + session_id = self.runtime.session_id or self.tag or str(uuid.uuid4()) + hook_runtime = build_hook_runtime(self.config, session_id=session_id) + + self.tool_manager = ToolManager( + self.config, + self.mcp_config, + self.mcp_client, + permission_enforcer=permission_enforcer, + safety_guard=safety_guard, + permission_mode=perm_config.mode, + read_policy=perm_config.safety.read_policy, + hook_runtime=hook_runtime, + trust_remote_code=self.trust_remote_code, + ) + if hook_runtime.has_session_handlers: + self.register_callback(CallbackToHookBridge(self.config, hook_runtime)) + self._hook_runtime = hook_runtime + await self.tool_manager.connect() +``` + +### 11.2 `parallel_tool_call` 与并发 + +`parallel_call_tool()` 对每个 `ToolCall` 独立调用 `single_call_tool()`。PreToolUse / PostToolUse **按单工具粒度**触发,与 Claude `PreToolUse` 一致。并发下: + +- 各调用使用独立 `tool_info` 副本,避免 `updated_args` 竞态 +- `HookExecutor` 子进程彼此隔离;handler 脚本须自身保证可重入 +- `session_id` 从 `HookRuntime` 或 `LLMAgent.runtime` 读取,跨并行 tool 共享 + +### 11.3 与旧 Callback 共存 + +- `InputCallback` 等内置 Callback **保留**,与 `CallbackToHookBridge` 同链执行 +- 旧 Python Callback **不废弃**;仅新增 shell hook 能力 +- `trust_remote_code` 与 shell hook **无关**——hook 脚本通过配置路径显式声明,不经 `importlib` 加载 + +--- + +## 12. 文件结构 + +``` +ms_agent/ +├── utils/ +│ └── pattern_matcher.py # 共享 fnmatch(从 permission/matcher 提取) +├── hooks/ +│ ├── __init__.py +│ ├── events.py # Canonical 事件 + HookResult +│ ├── registry.py +│ ├── executor.py # Dispatcher 门面 +│ ├── executors/ +│ │ ├── __init__.py +│ │ ├── command.py # P0 +│ │ ├── http.py # P2 §17.2 +│ │ ├── prompt.py # P2 §17.3 +│ │ └── agent.py # P3 §17.4 +│ ├── runtime.py +│ ├── factory.py +│ ├── response_adapter.py +│ ├── tool_name_mapper.py +│ ├── context.py +│ ├── bridge.py +│ ├── hook_helpers.py # add_arguments_to_prompt、HookOkReasonSchema +│ ├── permission_resolve.py # resolve_hook_permission_decision(§10.6) +│ └── loaders/ +│ ├── __init__.py +│ ├── native.py +│ ├── claude.py +│ ├── cursor.py +│ ├── hermes.py +│ └── plugin.py # F9 Plugin hooks/hooks.json +├── permission/ # 已实现,见 permission-design.md +│ └── matcher.py # 委托 pattern_matcher.match_pattern +docs/zh/design/ +├── hooks-design.md +└── permission-design.md +tests/ +├── test_hooks.py +├── test_hooks_loaders.py +├── test_hooks_context.py +└── fixtures/hooks/ +``` + +> **注**:F9 通用 `PluginLoader`(manifest 发现)若独立于 hooks,可置于 `ms_agent/plugins/`;当前 **hooks 加载** 由 `hooks/loaders/plugin.py` 的 `PluginHooksLoader` 完成。 + +--- + +## 13. 与外部生态的对比 + +### 13.1 执行模型 + +| 平台 | Hook 形态 | 执行方式 | ms-agent v1 | +|------|-----------|----------|-------------| +| Claude Code | settings + plugin `hooks.json` | 子进程 / HTTP / prompt / agent | **command 子进程** | +| Cursor | `.cursor/hooks.json` | 子进程 / prompt | **command 子进程**(兼容 Claude third-party) | +| Hermes | Plugin / Shell / Gateway | Python 进程内 / 子进程 | **Shell hook 子进程** | +| OpenClaw | Typed `api.on()` + HOOK pack | TS 进程内 | Claude `hooks.json` **不执行** | +| ms-agent | Canonical + 多源 loader | asyncio 子进程 | 原生 | + +### 13.2 ms-agent vs Claude Code(核心协议) + +| 特性 | Claude Code | MS-Agent | +|------|-------------|----------| +| 协议 | stdin/stdout/exit code | 一致 | +| 阻断 exit code | exit 2 | exit 2 | +| v1 事件 | 30+ | 6 核心 + 3 可选扩展(§15.3) | +| 原生配置 | `.claude/settings.json` | `agent.yaml` / `.ms-agent/hooks.json` | +| Plugin | `hooks/hooks.json` | F9 转换 merge | +| 权限协作 | Hook `allow` 仍受 settings deny/ask 约束 | `resolve_hook_permission_decision`(§10.6) | +| handler v1 | command/http/prompt/agent | **command only**(P2: http/prompt;P3: agent) | + +--- + +## 14. 验证方式 + +### 14.1 单元测试 + +| 模块 | 测试要点 | +|------|---------| +| `pattern_matcher` | 通配符匹配、`\|` 分隔、空模式、边界情况 | +| `HookRegistry.from_dict` | YAML 解析、未知事件 warning、空配置 | +| `HookRegistry.merge` | 全局 + 项目追加、事件独立、空合并 | +| `HookRegistry.get_handlers` | matcher 过滤、非工具事件全匹配、无 handler | +| `HookExecutor.execute` | exit 0 + JSON、exit 2 阻断、exit 1 非阻断、超时、找不到命令 | +| `HookExecutor.execute_all` | deny 短路;deny > ask > allow 合并;allow 不短路后续 handler | +| `resolve_hook_permission_decision` | allow + blacklist 覆盖;pass 走完整 enforcer;ask 带 force_decision | +| `HookRuntime` + `ToolManager` | 端到端 PreToolUse;PostToolUse `hook_attachments` | +| `CallbackToHookBridge` | 仅 SessionStart | +| `LLMAgent` UserPromptSubmit / Stop | §9.3 / §9.4 deny 与 block 消费 | +| `condense_hook_attachments_for_llm` | PostToolUse context 不污染 tool content | +| `HttpHookExecutor`(P2) | URL 白名单、header env 插值、SSRF、ResponseAdapter 解析 body | +| `PromptHookExecutor`(P2) | `$ARGUMENTS`、ok/reason schema、不触发 UserPromptSubmit | +| `AgentHookExecutor`(P3) | max_turns、structured output、工具过滤、Stop block | +| `_parse_hook_handler` | command/http/prompt/agent 字段解析;未知 type warning | + +### 14.2 集成测试 + +| 场景 | 验证内容 | +|------|---------| +| 真实脚本执行 | 写一个 Python hook 脚本,验证 stdin 收到 JSON、stdout 返回 JSON、exit code 正确处理 | +| Shell 脚本执行 | 写一个 bash hook 脚本(`exit 2 + stderr`),验证阻断行为 | +| Bridge + LLMAgent | Mock 生命周期,验证 SessionStart / Stop 等非工具事件 | +| ToolManager + hooks | allow 免弹窗;blacklist 覆盖 hook allow;`{}` 仍走 permission ask | +| 配置合并 | 全局 + 项目配置合并后,同事件下 handlers 追加且顺序正确 | +| Claude 配置加载 | 解析 `.claude/settings.json` 中 `PreToolUse` 嵌套结构,脚本可执行 | +| Cursor 配置加载 | 解析 `.cursor/hooks.json` 扁平结构,`beforeShellExecution` 映射正确 | +| Hermes block 格式 | `decision:block` 与 `action:block` 均归一化为 `deny` | +| 跨平台脚本 | 同一份 `block-rm.sh` 经 wrapper 在三方配置下均可阻断 | +| HTTP Policy hook(P2) | mock 远端返回 `decision:deny`,PreToolUse 短路 | +| Prompt guardrail(P2) | mock LLM `ok:false`,UserPromptSubmit 阻断 | +| Agent Stop 验证(P3) | mock 子 agent `ok:false`,Stop 被 block、agent 继续 | + +### 14.3 Hook 脚本示例 + +**PreToolUse:社区脚本放行(须显式 allow)** + +```python +#!/usr/bin/env python3 +import json, sys +event = json.load(sys.stdin) +if event.get("tool_name", "").endswith("shell_executor"): + cmd = event.get("tool_args", {}).get("command", "") + if cmd.startswith("pip install"): + # 仅审计、不干预 → pass(仍会走 permission ask) + print(json.dumps({})) + sys.exit(0) + if cmd.startswith("npm test"): + # 建议免弹窗 → allow(blacklist 仍可覆盖) + print(json.dumps({ + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "allow", + } + })) + sys.exit(0) +print(json.dumps({})) +``` + +**PreToolUse:硬拒绝** +```python +#!/usr/bin/env python3 +import json, sys +event = json.load(sys.stdin) +if event.get("tool_name", "").endswith("shell_executor"): + cmd = event.get("tool_args", {}).get("command", "") + if cmd.startswith("pip install"): + print(json.dumps({"decision": "deny", "reason": "pip install not allowed"})) + sys.exit(0) +print(json.dumps({})) +``` + +**PostToolUse:日志记录(Shell)** +```bash +#!/bin/bash +read event_json +tool_name=$(echo "$event_json" | jq -r '.tool_name') +echo "[$(date)] Tool used: $tool_name" >> /tmp/hook-log.txt +echo '{}' +``` + +**PreToolUse:阻断(Shell exit 2)** +```bash +#!/bin/bash +read event_json +tool_name=$(echo "$event_json" | jq -r '.tool_name') +if [[ "$tool_name" == *"shell_executor"* ]]; then + echo "Shell commands are disabled" >&2 + exit 2 +fi +echo '{}' +``` + +--- + +## 15. 多平台生态兼容设计 + +### 15.1 兼容目标与边界 + +**目标(对外承诺)**: + +> ms-agent 兼容 Claude Code、Cursor、Hermes 的 **shell-based third-party hooks**(工具拦截、审计、auto-format、context 注入)。用户可将社区 hook **脚本** 与 **配置文件** 以最小改动迁移到 ms-agent 或与之并存。 + +**边界(v1 执行后端)**: + +> 上表「不兼容项」指 **v1 不实现其原生 Executor**,不是放弃对应框架。各框架的 shell/command hook、Plugin 清单、Skills 仍在兼容范围内(见 §3.6.1)。 + +| v1 不原生执行 | 原因 | 该框架 v1 仍兼容什么 | +|--------------|------|---------------------| +| Claude HTTP / prompt / agent hook | 需独立 HTTP/LLM/子 agent 后端 | `type: command` hook、plugin `hooks.json`、settings 加载 | +| Hermes Python `register_hook()` | Hermes 进程内 API | Hermes **shell** hooks、`config.yaml` 加载 | +| OpenClaw `api.on()` | TS 进程内 | Skills/MCP bundle;Claude hooks.json 仅 detect(与 OpenClaw 一致) | +| 全量专有事件 | 非最小交集 | 8 个 Canonical 事件 + P2 扩展 | + +**覆盖率预期**: + +| 阶段 | 覆盖社区 hook 场景 | +|------|-------------------| +| P0 原生 + Bridge | ms-agent 自有配置,~40% | +| P1 Claude + Cursor loader | ~70%(审计/阻断/format 类) | +| P1 Hermes shell + Plugin | ~80% | +| P2 扩展事件 | ~85%,仍非 100% | + +### 15.2 三家 shell hook 生态关系 + +``` + ┌──────────────────────────────┐ + │ 社区 hook 脚本 (.sh/.py) │ ← 最易复用 + └──────────────┬───────────────┘ + │ + ┌─────────────────────────┼─────────────────────────┐ + ▼ ▼ ▼ + Claude Code Cursor Hermes Shell + settings.json hooks.json config.yaml hooks: + │ │ │ + └─────────────────────────┴─────────────────────────┘ + │ + ┌──────────────▼───────────────┐ + │ ms-agent ExternalLoaders │ + │ → Canonical IR → HookExecutor │ + └──────────────────────────────┘ +``` + +- **Claude ↔ Cursor**:互通最好;Cursor 官方支持 third-party Claude hooks,事件名为 camelCase 映射 +- **Hermes shell**:文档明确接受 Claude 风格 `{"decision":"block","reason":"..."}`;配置为 YAML +- **OpenClaw**:单列;Claude `hooks.json` 不执行,需 Codex `HOOK.md` 布局或 native plugin + +### 15.3 Canonical 事件模型与映射表 + +ms-agent 内部统一使用 **Canonical 事件名**(PascalCase,与 Claude 对齐),各 loader 负责入站映射: + +| Canonical | 触发时机(ms-agent) | Claude Code | Cursor | Hermes Shell | +|-----------|---------------------|-------------|--------|--------------| +| `SessionStart` | `on_task_begin` | `SessionStart` | `sessionStart` | `on_session_start` | +| `UserPromptSubmit` | 用户输入进入 run 前 | `UserPromptSubmit` | `beforeSubmitPrompt` | `pre_llm_call`(注入 context) | +| `PreToolUse` | 工具执行前 | `PreToolUse` | `preToolUse` | `pre_tool_call` | +| `PostToolUse` | 工具执行后 | `PostToolUse` | `postToolUse` | `post_tool_call` | +| `Stop` | `after_tool_call()` 内、`should_stop` 判定前(§9.4) | `Stop` | `stop` | `on_session_end`(**近似**,会话级) | +| `PermissionRequest` | `PermissionEnforcer.check` 前(interactive,`P1`) | `PermissionRequest` | — | `pre_approval_request`(仅观察) | +| `SubagentStop` | 配置可加载(`VALID_EVENTS`);**运行时触发 P2** | `SubagentStop` | `subagentStop` | `subagent_stop`(loader 映射) | +| `ShellBefore`(P2 独立事件) | v1 经 Cursor 合成 `PreToolUse`+shell matcher | `PreToolUse(Bash)` | `beforeShellExecution` → **P1 合成** | `pre_tool_call`+`terminal` matcher | +| `FileAfterEdit`(P2 独立事件) | v1 经 Cursor 合成 `PostToolUse`+write matcher | `PostToolUse(Write)` | `afterFileEdit` → **P1 合成** | `post_tool_call`+`write_file` matcher | + +> **⚠️ Hermes 兼容边界(必读)** +> +> | Hermes 事件 | ms-agent 映射 | 语义差异 | +> |-------------|--------------|---------| +> | `pre_llm_call` | `UserPromptSubmit` | Hermes:**每次** LLM 调用前;ms-agent:**仅用户消息进入循环时** | +> | `on_session_end` | `Stop` | Hermes:会话级结束;ms-agent:单轮 assistant 无 tool_calls 时,支持 **block 继续** | +> | `pre_approval_request` | `PermissionRequest` | Hermes 偏观察;ms-agent 在 interactive permission 流程中可阻断 | +> +> 迁移 Hermes shell hooks 时,勿假设触发频率与 Hermes 完全一致。 + +> **Cursor P1 策略**:`beforeShellExecution` / `afterFileEdit` 在 v1 通过 `CursorHooksLoader` 合成为带默认 matcher 的 `PreToolUse` / `PostToolUse`;P2 可拆为独立 Canonical 事件 `ShellBefore` / `FileAfterEdit`。 + +未知外部事件:**记录 warning 并跳过**,不导致 agent 崩溃。 + +### 15.4 工具名归一化(ToolNameMapper) + +外部 matcher 常按平台工具名编写,执行前需双向映射: + +| 语义 | Claude | Cursor | Hermes | ms-agent(示例) | +|------|--------|--------|--------|------------------| +| 执行命令 | `Bash` | `Shell` | `terminal` | `code_executor---shell_executor` | +| 读文件 | `Read` | `Read` | `read_file` | `file_system---read_file` | +| 写文件 | `Write` / `Edit` | `Write` | `write_file` / `patch` | `file_system---write_file` | +| 子 agent | `Task` | `Task` | `delegate_task` | `agent_tool---*` | + +`ToolNameMapper` 职责: + +1. **出站**(构造 stdin):Canonical payload 携带 `tool_name` 及按 `enabled_sources` 启用的 `tool_name_claude` / `tool_name_cursor` / `tool_name_hermes` 别名(见 §15.6) +2. **入站**(matcher 匹配):各 ExternalLoader 在加载时将外部 matcher **转换为** ms-agent `server---tool` 格式(`ToolNameMapper.external_matcher_to_native`);运行时按 ms-agent 工具名匹配,不做二次 `tool_name_*` 字段过滤 + +### 15.5 ExternalHookLoaders 设计 + +```python +class HookLoader(Protocol): + def load(self, ctx: HookLoadContext) -> HookRegistry: ... + +@dataclass(frozen=True) +class HookLoadContext: + project_root: str + global_ms_agent_dir: str # ~/.ms_agent + plugin_roots: tuple[str, ...] + enabled_sources: frozenset[str] +``` + +#### ClaudeSettingsLoader + +- 输入:`.claude/settings.json` 或 `~/.claude/settings.json` 的 `hooks` 段 +- 解析 Claude 三层嵌套:`event → [{matcher, hooks:[{type, ...}]}]` +- **P0/P1**:`type` 缺失或为 `command` 时入库;`http` / `prompt` / `agent` 若未在 `enabled_executors` 中启用 → **warning + 跳过**(不进入 registry;P2 `hooks doctor` 可扫描源文件提示) +- **P2+**:解析全部 `type`,字段映射见 §17.1 +- 路径变量:`${CLAUDE_PROJECT_DIR}` → `project_root`;`${CLAUDE_PLUGIN_ROOT}` → plugin root(F9) + +#### CursorHooksLoader + +- 输入:`.cursor/hooks.json` 的 `hooks` 对象 +- 扁平结构:`{ "preToolUse": [{ "command", "matcher", "timeout", "failClosed" }] }` +- 事件名 camelCase → Canonical PascalCase +- `beforeShellExecution` → 合成 `ShellBefore` 或带 `tool_class: shell` 的 `PreToolUse` matcher +- `failClosed` 透传到 handler 元数据 + +#### HermesShellLoader + +- 输入:`~/.hermes/config.yaml` 的 `hooks:` 段(**v1 仅全局**;项目级 Hermes 配置 P2) +- 事件名 snake_case → Canonical +- 仅加载 shell hook 条目(非 Python plugin) + +#### PluginHooksLoader(F9) + +对齐 `playground_prototype_design.md` F9: + +```python +# ms_agent/plugins/loader.py(示意) +def load_plugin_hooks(manifest: PluginManifest) -> HookRegistry: + hooks_path = manifest.root / "hooks" / "hooks.json" + if manifest.format == "claude": + return ClaudeSettingsLoader.parse_hooks_file(hooks_path, plugin_root=manifest.root) + ... +``` + +环境变量(脚本运行时注入,兼容 Claude Code plugin): + +| 变量 | 含义 | +|------|------| +| `MS_AGENT_PROJECT_DIR` | 项目根目录 | +| `MS_AGENT_PLUGIN_ROOT` | 当前 plugin 根目录 | +| `MS_AGENT_PLUGIN_DATA` | 可变数据目录 `~/.ms_agent/plugins/data//` | +| `CLAUDE_PROJECT_DIR` | **别名**,便于复用 Claude 社区脚本 | + +### 15.6 stdin CanonicalPayload 格式 + +对外部脚本,ms-agent 统一发送: + +```json +{ + "event": "PreToolUse", + "hook_event_name": "PreToolUse", + "session_id": "abc123", + "project_path": "/path/to/project", + "tool_name": "code_executor---shell_executor", + "tool_name_claude": "Bash", + "tool_name_cursor": "Shell", + "tool_name_hermes": "terminal", + "tool_args": {"command": "rm -rf /tmp/x"}, + "tool_input": {"command": "rm -rf /tmp/x"}, + "cwd": "/path/to/project", + "extra": {} +} +``` + +- `tool_args` / `tool_input` **同值**,兼容 Claude(`tool_input`)与 Hermes(`tool_input`)习惯 +- 多平台工具名字段可选;简单脚本可只读 `tool_args` + +### 15.7 兼容矩阵(能否直接换用) + +| 从 → 到 ms-agent | 配置文件 | 脚本 | 说明 | +|------------------|---------|------|------| +| Claude Code | 经 loader 转换 | **高** | 改 jq 路径即可跑大多数社区脚本 | +| Cursor | 经 loader 转换 | **高** | third-party Claude hooks 同理 | +| Hermes shell | 经 loader 转换 | **高** | block 双格式已在 ResponseAdapter 处理 | +| Claude plugin `hooks.json` | F9 merge | **高** | 需 `${CLAUDE_PLUGIN_ROOT}` 别名 | +| Hermes Python plugin | ✗ | ✗ | 需改写为 shell 或 ms-agent 原生 | +| OpenClaw Claude bundle | ✗(detect only) | 视脚本 | 仅当用户单独提供可执行脚本 | + +### 15.8 可移植脚本编写规范(推荐) + +供 Agent Hub / Playground 导出与社区文档使用: + +```bash +#!/usr/bin/env bash +# portable-pre-tool.sh — 尽量只依赖 jq 与 Canonical 字段 +payload=$(cat) +tool=$(echo "$payload" | jq -r '.tool_name_claude // .tool_name_cursor // .tool_name // empty') +cmd=$(echo "$payload" | jq -r '.tool_input.command // .tool_args.command // empty') + +if [[ "$tool" =~ ^(Bash|Shell|terminal)$ ]] && echo "$cmd" | grep -qE 'rm[[:space:]]+-rf'; then + jq -n '{"decision":"deny","reason":"rm -rf blocked","action":"block","message":"rm -rf blocked"}' + exit 0 +fi +printf '{}\n' +``` + +--- + +## 16. 分阶段交付与验收 + +对齐 `playground_prototype_design.md` F6(P0.5)与 F9(P1): + +### 16.1 P0 — 引擎 + ToolManager 主路径(权限集成点已就绪) + +> PR#906 / `permission-design.md` 已落地双层权限;**P0 + P1 loader 生态已实现**(`ms_agent/hooks/`),P2/P3 扩展 Executor 待做。 + +| 交付项 | 验收 | +|--------|------| +| `HookRegistry` / `HookExecutor` / `HookRuntime` / `pattern_matcher` | 单元测试通过 | +| **`ToolManager.single_call_tool` 集成** | PreToolUse deny/allow/updated_args;PostToolUse `hook_attachments` | +| `LLMAgent` UserPromptSubmit + Stop 挂点 | §9.3 / §9.4 语义测试 | +| `condense_hook_attachments_for_llm` | PostToolUse context 进入下轮 LLM,不污染 tool content | +| `CallbackToHookBridge` | 仅 SessionStart | +| `ResponseAdapter`(统一 stdout 解析) | `permissionDecision` / `approve`/`block` / `updatedInput` 归一化 | + +### 16.2 P1 — 三方生态 + Plugin(已实现) + +| 交付项 | 验收 | +|--------|------| +| `ClaudeSettingsLoader` | ✅ 加载 Claude `PreToolUse` 并在 ToolManager 执行 | +| `CursorHooksLoader` | ✅ `preToolUse` / `beforeShellExecution` 合成 | +| `HermesShellLoader` | ✅ `pre_tool_call` shell 配置 | +| `PluginHooksLoader`(F9) | ✅ plugin `hooks/hooks.json` merge | +| `ToolNameMapper` | ✅ Bash/Shell/terminal matcher | +| `PermissionRequest` hook | ✅ interactive 模式下 `resolve_hook_permission_decision` 内触发 | + +### 16.3 P2 — 扩展 Executor + Playground 集成 + +| 交付项 | 验收 | +|--------|------| +| `HttpHookExecutor` + URL 白名单 / SSRF 防护 | §17.2;企业 Policy POST 可阻断 PreToolUse | +| `PromptHookExecutor` + `$ARGUMENTS` 替换 | §17.3;`ok:false` 阻断;不递归触发 UserPromptSubmit | +| `hooks.enabled_executors` | 默认 `[command]`;解析时过滤未启用 type(P2 开启 `http` / `prompt`) | +| `SubagentStop` 运行时挂点 / `ShellBefore` / `FileAfterEdit` 独立事件 | P2:配置已可加载 `SubagentStop`;独立事件与运行时待做 | +| `fail_closed` / `hooks doctor` | 对标 Cursor/Hermes 运维体验;doctor 列出被跳过的非 command handler | +| WebUI Hooks 设置页 | 展示 enabled_sources、enabled_executors、脚本路径、测试触发 | +| Agent Hub 导出 | 导出 `.ms-agent/hooks.json` + 可选 Claude/Cursor 并列配置 | + +### 16.4 P3 — Agent Hook 与高级集成 + +| 交付项 | 验收 | +|--------|------| +| `AgentHookExecutor` | §17.4;Stop 验证、受限工具、`dontAsk` 模式、结构化 `{ok, reason}` | +| OpenClaw HOOK pack 适配 | §17.6;detect-only 或 command 转换 | +| 子 agent transcript 路径注入 | agent hook 可读 session log,对齐 Claude `getTranscriptPath()` | + +### 16.5 对外表述(产品 / 文档) + +建议使用: + +> ms-agent 支持原生 Hooks,并兼容 Claude Code、Cursor、Hermes 的 **shell hook 脚本与配置**(通过 `hooks.enabled_sources` 开启)。v1 完整支持 `command` handler;`http` / `prompt` 在 P2、`agent` 在 P3 以独立 Executor 补齐(见 [§17](#17-扩展-executorhttppromptagent) 与 [附录 A](#附录-ahook-handler-类型与应用场景))。 + +--- + +## 17. 扩展 Executor:HTTP / Prompt / Agent + +> 对齐 Claude Code `execHttpHook.ts` / `execPromptHook.ts` / `execAgentHook.ts`;与 §8 Dispatcher 共用 `ResponseAdapter` 与 `HookResult` 语义。 + +### 17.1 统一路由与配置 + +`HookExecutor`(§8.1)按 `handler.type` 分发;扩展后端与 `CommandHookExecutor` **并列**,不嵌套。 + +```python +# ms_agent/hooks/executor.py +class HookExecutor: + def __init__(self, ..., enabled_executors: frozenset[str] = frozenset({"command"})): + self._backends: dict[str, HookExecutorBackend] = {} + if "command" in enabled_executors: + self._backends["command"] = command_executor + if "http" in enabled_executors: + self._backends["http"] = http_executor + # prompt / agent 同理 +``` + +**全局开关**(`agent.yaml` / `hooks.yaml`): + +```yaml +hooks: + enabled_executors: [command] # P2: 追加 http, prompt;P3: 追加 agent + default_model: "qwen-plus" # prompt/agent 未指定 model 时的 fast 模型 + # HTTP 策略(对齐 Claude allowedHttpHookUrls / httpHookAllowedEnvVars) + allowed_http_hook_urls: # undefined=不限制;[]=全禁;非空=通配符白名单 + - "https://policy.corp.example/*" + http_hook_allowed_env_vars: ["POLICY_TOKEN", "AUDIT_API_KEY"] +``` + +**Claude settings → HookHandlerConfig 字段映射**: + +| Claude 字段 | ms-agent | 说明 | +|-------------|----------|------| +| `type: http` + `url` | `type`, `url` | 必填 | +| `headers` | `headers` | 值支持 `$VAR` / `${VAR}`,仅 `allowed_env_vars` 白名单内解析 | +| `allowedEnvVars` | `allowed_env_vars` | 与全局 `http_hook_allowed_env_vars` 取交集 | +| `type: prompt` + `prompt` | `type`, `prompt` | `$ARGUMENTS` 替换为事件 JSON 字符串 | +| `model` | `model` | 缺省 → `hooks.default_model` | +| `type: agent` + `prompt` | `type`, `prompt`, `max_turns` | 默认 `max_turns=20`(Claude 硬上限 50) | +| `timeout` | `timeout` | 秒;各后端独立默认见 §17.2–17.4 | + +**Loader 行为**:`ClaudeSettingsLoader` / `NativeYamlLoader` **解析并保留**全部 type;若 executor 未启用,`HookRegistry.get_handlers()` 过滤或 `HookExecutor.execute()` 返回 `action=error` + doctor warning,避免静默丢配置。 + +### 17.2 HttpHookExecutor + +**职责**:将 Canonical 事件 JSON **POST** 到 `handler.url`,响应 body 经 `ResponseAdapter` 解析(与 command stdout 相同 schema:`decision` / `permissionDecision` / `updatedInput` / `additional_context` 等)。 + +**执行流程**(对齐 `execHttpHook.ts`): + +``` +event_data ──json.dumps──► POST url + │ │ + │ headers + Content-Type: application/json + │ │ + ▼ ▼ + URL 白名单校验 SSRF lookup(无代理时) + │ │ + └──────┬───────┘ + ▼ + response body (text) + ▼ + ResponseAdapter.parse() + ▼ + HookResult +``` + +**类设计**: + +```python +# ms_agent/hooks/executors/http.py +class HttpHookExecutor: + async def execute( + self, handler: HookHandlerConfig, + event_data: dict[str, Any], + ctx: HookExecutionContext, + ) -> HookResult: + # 1. allowed_http_hook_urls 通配符匹配(* 语义同 Claude MCP allowlist) + # 2. 构建 headers:interpolate_env_vars(value, allowed_env_vars ∩ policy) + # sanitize CR/LF/NUL 防 header injection + # 3. aiohttp/httpx POST,timeout=handler.timeout,max_redirects=0 + # 4. 有 HTTP_PROXY 或沙箱代理时跳过直连 SSRF guard(与 Claude 一致) + # 5. 2xx → ResponseAdapter.parse(body);非 2xx / 网络错误 → action=error +``` + +**安全要点**: + +| 项 | 策略 | +|----|------| +| URL 白名单 | `hooks.allowed_http_hook_urls`:`undefined` 不限制;`[]` 禁止全部;否则 `urlMatchesPattern()` | +| SSRF | 直连时 DNS 解析后拒绝 private/link-local(可配置允许 loopback 用于本地 dev) | +| 环境变量 | 仅 `allowed_env_vars` ∩ 全局白名单可注入 header;其余替换为空串 | +| 重定向 | `max_redirects=0`,防开放重定向绕过白名单 | +| fail_closed | handler 或全局 `fail_closed=true` 时,网络/解析失败 → 可阻断事件上视为 `deny` | + +**与 command + curl 的差异**:统一超时、白名单、SSRF、响应 schema;Playground / 企业 MDM 可只开放 URL 而不分发脚本。 + +**典型配置**(Claude `settings.json` 等价): + +```json +{ + "hooks": { + "PreToolUse": [{ + "matcher": "Bash", + "hooks": [{ + "type": "http", + "url": "https://policy.example/v1/pre-tool", + "timeout": 10, + "headers": { "Authorization": "Bearer ${POLICY_TOKEN}" }, + "allowedEnvVars": ["POLICY_TOKEN"] + }] + }] + } +} +``` + +### 17.3 PromptHookExecutor + +**职责**:将 hook 事件 JSON 填入 `handler.prompt` 的 `$ARGUMENTS` / `${ARGUMENTS}` 占位符,调用**单次** LLM(非完整 agent 循环),解析结构化输出 `{ok: bool, reason?: string}`。 + +**执行流程**(对齐 `execPromptHook.ts`): + +``` +event_data ──json.dumps──► add_arguments_to_prompt(prompt, json) + │ + ▼ + 构造单条 user Message(不经过 run_loop / InputCallback) + │ + ▼ + llm.generate(structured output / json_schema) + model = handler.model ?? hooks.default_model + system: "return {ok:true} or {ok:false, reason}" + │ + ┌───────────────┴───────────────┐ + ▼ ▼ + ok == true ok == false + HookResult(pass/allow) HookResult(deny, reason) +``` + +**关键约束**: + +| 约束 | 原因 | +|------|------| +| **禁止递归 UserPromptSubmit** | 不得走 `create_messages()` / `processUserInput()`;直接构造 hook 专用 message(Claude L40–41) | +| **不注入完整 session history(默认)** | v1 仅 hook prompt + 可选 `ctx.messages` 尾部摘要;避免 token 爆炸 | +| **结构化输出** | `ok:false` → `action=deny`(可阻断事件)或 `block`(Stop);解析失败 → `action=error`(非阻断,除非 fail_closed) | +| **与 PreToolUse permission** | prompt 返回 `ok:true` 等价 `pass`(`{}`),**不**自动 `allow`;若需免弹窗须响应含 `permissionDecision: allow` 并经 `resolve_hook_permission_decision` | +| **工具不可用** | prompt hook **不**暴露 ToolManager;复杂判断用 agent hook | + +**类设计**: + +```python +# ms_agent/hooks/executors/prompt.py +class PromptHookExecutor: + async def execute(...) -> HookResult: + processed = add_arguments_to_prompt(handler.prompt, json.dumps(event_data)) + response = await ctx.llm.generate( + messages=[Message(role="user", content=processed)], + system_prompt=HOOK_PROMPT_SYSTEM, + model=handler.model or self._default_model, + response_format=HookOkReasonSchema, # {ok, reason?} + timeout=handler.timeout, + ) + parsed = parse_hook_ok_reason(response) + if parsed.ok: + return HookResult(action="pass") + return HookResult(action="deny", reason=parsed.reason or "Prompt hook condition not met") +``` + +**Cursor `type: prompt`**:`CursorHooksLoader` 映射为 `type: prompt`,共用本 Executor。 + +### 17.4 AgentHookExecutor + +**职责**:启动**短生命周期子 agent**(多轮 tool loop),用于需读仓库 / transcript / 多步验证的场景;主要用于 **`Stop`** 事件(Claude Stop 验证),亦可用于高成本 `PreToolUse`(需显式配置)。 + +**执行流程**(对齐 `execAgentHook.ts`): + +``` +event_data ──► add_arguments_to_prompt ──► 子 agent run_loop(受限) + │ + ┌─────────────────────────┼─────────────────────────┐ + ▼ ▼ ▼ + 过滤危险工具 permission mode=dontAsk max_turns 上限 + (无 spawn subagent) transcript 路径可读 timeout abort + │ │ │ + └─────────────────────────┴─────────────────────────┘ + ▼ + StructuredOutputTool → {ok, reason} + ▼ + HookResult(deny|pass) +``` + +**类设计**: + +```python +# ms_agent/hooks/executors/agent.py +class AgentHookExecutor: + DISALLOWED_TOOLS = frozenset({ + "agent_tool", "plan_mode", ... # 禁止子 agent 再 spawn / 进 plan + }) + + async def execute(...) -> HookResult: + hook_agent_id = f"hook-agent-{uuid4()}" + tools = filter_tools(ctx.tool_manager, disallow=self.DISALLOWED_TOOLS) + tools.append(StructuredOutputTool(schema=HookOkReasonSchema)) + + sub_ctx = HookAgentContext( + parent=ctx, + agent_id=hook_agent_id, + permission_mode="dontAsk", # 对齐 Claude getAppState().mode + extra_allow_rules=[f"read:{transcript_path}"], + max_turns=min(handler.max_turns, 50), + ) + result = await run_hook_agent_loop( + messages=[user_msg_from_prompt], + system_prompt=HOOK_AGENT_SYSTEM.format(transcript=transcript_path), + tools=tools, + ctx=sub_ctx, + timeout=handler.timeout, + ) + if result is None: # 超时 / 未调用 structured output + return HookResult(action="error", reason="Agent hook did not complete") + if not result.ok: + return HookResult(action="deny", reason=result.reason) + return HookResult(action="pass") +``` + +**与 prompt 的选型**: + +| | PromptHook | AgentHook | +|---|-----------|-----------| +| LLM 调用 | 单次 | 多轮 + 工具 | +| 延迟 / 成本 | 低 | 高 | +| 可读文件 / 跑命令 | 否 | 是(受限工具集) | +| 典型事件 | UserPromptSubmit、轻量 PreToolUse | **Stop**、复杂合规 | + +**Stop 语义**:`ok:false` → `block`(阻止停止,agent 继续);映射到 `HookResult(action="block", reason=...)`,由 `LLMAgent` §9.4 消费。 + +**安全**:子 agent 继承父 session 的工具面但经白名单过滤;`dontAsk` 仅作用于 hook 子会话;禁止修改 hooks 配置或启动新 top-level session。 + +### 17.5 扩展 Executor 与权限 / 阻断事件 + +三类扩展 Executor 的输出**统一**进入既有管线: + +``` +Executor → HookResult + ├─ PreToolUse + deny → 短路,不调用工具 + ├─ PreToolUse + allow/pass → resolve_hook_permission_decision(§10.6) + ├─ UserPromptSubmit + deny → 拒绝用户消息进入循环 + ├─ Stop + block/deny → 取消 should_stop,注入 reason 到 assistant 上下文 + └─ PostToolUse → additional_context → HookAttachment(§8.5) +``` + +**prompt/agent 的 `ok` 与 permission JSON 的关系**: + +- 仅 `{ok:false}` → 策略性阻断(等价 exit 2 / `decision:deny`) +- `{ok:true}` + stdout 风格 `permissionDecision: allow` → 走 `resolve_hook_permission_decision` +- HTTP 响应 body 可同时携带 Claude 完整 JSON(`updatedInput` 等),由 `ResponseAdapter` 一次解析 + +### 17.6 OpenClaw 与其它扩展 + +OpenClaw **typed `api.on()`** hook 为 TS 进程内中间件,ms-agent **不**原生执行。P3 可选路径: + +1. **detect-only**:识别 OpenClaw bundle,文档引导作者导出等价 `hooks.json` command 脚本 +2. **command 转换**:将简单 HOOK pack 译为 shell 包装(只读场景) +3. **TS 沙箱**(远期):独立 Node 子进程,不在 P2/P3 范围 + +### 17.7 测试与验收 + +| Executor | 单测要点 | 集成要点 | +|----------|---------|---------| +| Http | URL 白名单、env 插值、SSRF mock、2xx/4xx body 解析 | mock Policy server 阻断 PreToolUse | +| Prompt | `$ARGUMENTS` 替换、ok/schema 失败、不触发 UserPromptSubmit | UserPromptSubmit deny 端到端 | +| Agent | max_turns、structured output 缺失、工具过滤 | Stop block → agent 继续一轮 | + +--- + +## 附录 A:Hook Handler 类型与应用场景 + +各平台(尤其 Claude Code、Cursor)的 hook 配置里,`type` 字段决定**用哪种执行后端**处理同一生命周期事件。与 §3.6.1 一致:ms-agent **兼容这些框架**,但 v1 仅原生实现 `command`;其余类型见下表「ms-agent 规划」列。 + +### A.1 四种 Handler 对比 + +| 类型 | 执行模型 | 确定性 | 典型延迟 | 社区占比(经验) | +|------|---------|--------|---------|-----------------| +| **command** | 子进程 + stdin/stdout JSON | 高 | 毫秒~秒级 | **>80%** | +| **http** | POST 事件 JSON 到 URL,解析响应 | 高(依赖远端) | 百毫秒~秒级 | 企业 / Partner 为主 | +| **prompt** | 将 hook input 填入 prompt,调 LLM 判断 | 低~中 | 秒级 | 少量 | +| **agent** | spawn 短生命周期子 agent 多步验证 | 中 | 秒~分钟级 | 很少 | + +### A.2 command(shell / 可执行文件) + +**机制:** + +``` +生命周期事件 → fork 子进程 → stdin(JSON) → 脚本 → stdout(JSON) 或 exit 2 +``` + +**典型场景:** + +| 场景 | 示例 | +|------|------| +| 硬规则拦截 | 拒绝 `rm -rf`、拒绝 `pip install` | +| 自动格式化 | `PostToolUse` 后对刚写入的 `.py` 跑 `black` | +| 本地审计 | 追加 tool 调用日志到 `/var/log/agent-audit.jsonl` | +| Secret 扫描 | 脚本内用 regex / trufflehog 扫描命令参数 | +| 会话初始化 | `SessionStart` 时检查环境变量、git 状态 | + +**为何作为 v1 主路径:** 与 Claude / Cursor / Hermes shell hook 协议一致,可移植、可审计、无额外 LLM 成本。 + +### A.3 http + +**机制:** + +``` +生命周期事件 → HTTP POST(JSON body)→ 远端服务 → 响应 JSON(allow/deny/...) +``` + +**典型场景:** + +| 场景 | 示例 | +|------|------| +| 企业统一策略中心 | 每次 `PreToolUse` 询问公司 Policy API 是否允许 | +| SIEM / 可观测 | 将 tool 调用异步上报 Splunk、Datadog、自建 audit 服务 | +| Secrets / 合规 SaaS | Cursor Partner 类集成:POST 到厂商治理 endpoint | +| 集中留痕 | 金融/医疗:所有 shell 必须经合规网关登记 | +| 跨团队通知 | `Stop` / `agent:end` 时 POST Slack/Teams webhook | + +**与 command + curl 的区别:** 平台对 http hook 约定统一超时、鉴权头、async、响应 schema;企业可只开放 URL 白名单,无需在每台机器分发脚本。 + +**ms-agent 规划:** P2 `HttpHookExecutor`(§17.2);v1 加载配置时对未启用的 `type: http` **warning + 跳过**,或文档建议用户用 shell 脚本包装同一 HTTP 调用。 + +### A.4 prompt + +**机制:** + +``` +生命周期事件 → 构造策略 prompt(含 hook input)→ 调 LLM → 解析 allow/deny +``` + +**典型场景:** + +| 场景 | 示例 | +|------|------| +| 自然语言策略 | 「只允许只读操作」— 难以用正则穷举 | +| 意图判断 | 「这条 shell 是否在执行生产部署?」 | +| Prompt 合规 | `UserPromptSubmit` 前检查是否含 PII、违规内容 | +| 轻量 guardrail | 策略频繁变更,不想维护大量 shell | + +**代价:** 多一次 LLM 调用(慢、花钱、非完全确定)。**不适合**必须 100% 确定的硬安全底线(应由 `SafetyGuard` + command hook 承担)。 + +**ms-agent 规划:** P2 `PromptHookExecutor`(§17.3,复用 `hooks.default_model`);Cursor `type: prompt` 共用同一后端。 + +### A.5 agent + +**机制:** + +``` +生命周期事件 → 启动子 agent(只读/受限工具)→ 多步推理 → 返回决策 +``` + +**典型场景:** + +| 场景 | 示例 | +|------|------| +| 复杂合规 | 子 agent 读内部 runbook + 当前 diff,判断 DB migration 是否允许 | +| 多文件上下文 | 需结合多个文件状态才能决定能否执行某命令 | +| 动态威胁分析 | 不仅看单条命令,还要看 branch、近期 commits、CI 状态 | + +**与 prompt 的区别:** agent hook 可**调用工具、读仓库**,不仅是一次 LLM 问答。 + +**ms-agent 规划:** P3 `AgentHookExecutor`(§17.4),对接 ms-agent 子 agent / `AgentTool`,主要用于 Stop 验证。 + +### A.6 选型建议(产品 / 实施) + +``` +需要 100% 确定、可审计? → command(+ SafetyGuard) +策略在远端、组织统一治理? → http(P2,§17.2) +策略难脚本化、可接受 LLM? → prompt(P2,§17.3) +需多步读库/读文件才能判断? → agent(P3,§17.4) +``` + +--- + +## 附录 B:Hermes 三套 Hook 体系与功能关系 + +Hermes Agent 的 hook 常口语说成「两套」,实为 **三套**,按**注册方式、运行范围、能否阻断 agent 循环**划分。理解差异有助于 ms-agent 对齐 **Hermes shell hooks**(v1)而**不**追求原生执行 Python plugin hook 或 Gateway hook。 + +### B.1 三套体系总览 + +| 体系 | 注册方式 | 配置位置 | 语言 | 运行范围 | 能否 block 工具 | +|------|---------|---------|------|---------|----------------| +| **Plugin hooks** | `ctx.register_hook()` in `register(ctx)` | Python plugin 内 | Python 进程内 | CLI + Gateway + Cron | ✅ `pre_tool_call` 等 | +| **Shell hooks** | `hooks:` in `config.yaml` | `~/.hermes/config.yaml` | 任意(子进程) | CLI + Gateway | ✅ 同 Plugin | +| **Gateway hooks** | `HOOK.yaml` + `handler.py` | `~/.hermes/hooks//` | Python(Gateway 内) | **仅 Gateway** | ❌(观察/副作用) | + +### B.2 为何拆成多套? + +**1. 信任边界** + +| 体系 | 信任模型 | +|------|---------| +| Shell hooks | 子进程隔离;每个 `(event, command)` 首次需用户 consent(`hooks_auto_accept` / `--accept-hooks`) | +| Plugin hooks | 与 agent 同进程;靠 `plugins.enabled` 白名单显式启用 | +| Gateway hooks | 信任 `~/.hermes/hooks/` 目录;错误只 log,不 crash Gateway | + +**2. 事件命名空间不同** + +**Plugin + Shell** 共用 `VALID_HOOKS`(agent 循环): + +``` +pre_tool_call, post_tool_call, pre_llm_call, post_llm_call, +on_session_start, on_session_end, on_session_reset, on_session_finalize, +subagent_stop, pre_gateway_dispatch, pre_approval_request, ... +``` + +**Gateway hooks** 使用 Gateway 生命周期事件: + +``` +gateway:startup, session:start, session:end, session:reset, +agent:start, agent:step, agent:end, command:*, ... +``` + +CLI 没有「Telegram 用户发消息」等上下文,故 Gateway hooks **故意不在 CLI 加载**。 + +**3. 设计目标不同** + +| 目标 | 适用体系 | +|------|---------| +| 拦截危险工具、注入 turn context | Plugin / Shell | +| 运维不写 Python、只要一个脚本 | Shell | +| 插件作者与 `register_tool` 同包发布 | Plugin | +| Gateway 启动巡检、IM 告警、slash 命令审计 | Gateway hooks | + +### B.3 功能关系图 + +``` + ┌─────────────────────────────────────┐ + │ Agent 循环(CLI + Gateway) │ + │ │ + Python plugin ──►│ Plugin hooks ──┐ │ + │ ├──► invoke_hook() │ + config.yaml ──►│ Shell hooks ──┘ 分发器 │ + │ │ │ + │ ▼ │ + │ pre_tool_call / pre_llm_call / ... │ + │ (可 block / 可注入 context) │ + └─────────────────────────────────────┘ + + ┌─────────────────────────────────────┐ + │ 仅 Gateway(Telegram/Discord/…) │ + ~/.hermes/hooks/ │ Gateway hooks(HOOK.yaml) │ + │ gateway:startup / agent:step / ... │ + │ (观察为主,不 block 工具循环) │ + └─────────────────────────────────────┘ +``` + +### B.4 Plugin hooks 与 Shell hooks 的协作 + +二者经 **同一 `invoke_hook()` 分发器**: + +1. **执行顺序**:先 Plugin hooks(按插件发现顺序),后 Shell hooks +2. **`pre_tool_call` 阻断**:第一个有效 `{"action":"block"}` / `{"decision":"block"}` 胜出 +3. **能力重叠**:同一事件可既有 Plugin 又有 Shell;Plugin 适合复杂逻辑,Shell 适合运维一键脚本 + +Hermes 文档中的 **BOOT.md 启动清单** 是 Gateway hooks 的典型模式:在 `gateway:startup` 后台起一个 one-shot agent 执行 `~/.hermes/BOOT.md` 里的自然语言指令(与 Plugin/Shell 的 `pre_tool_call` 无关)。 + +### B.5 能力对照(节选) + +| 能力 | Plugin | Shell | Gateway | +|------|--------|-------|---------| +| 阻断 `pre_tool_call` | ✅ | ✅ | ❌ | +| `pre_llm_call` 注入 context | ✅ | ✅ | ❌ | +| `post_tool_call` 后处理(format) | ✅ | ✅ | ❌ | +| Gateway 启动时跑 BOOT 检查 | ❌ | ❌ | ✅ | +| `agent:step` 超过 N 步发 Telegram 告警 | ❌ | ❌ | ✅ | +| 记录所有 `/command` 使用 | 部分 | 部分 | ✅(`command:*`) | +| 子进程隔离 | ❌ | ✅ | ❌(Gateway 进程内) | + +### B.6 与 ms-agent 设计的映射 + +| Hermes | ms-agent 策略 | +|--------|--------------| +| **Shell hooks** | v1 **主兼容路径**(`HermesShellLoader` + `HookExecutor`) | +| **Plugin hooks** | 不执行 Python `register_hook()`;等价逻辑用 shell 或 ms-agent 自有 plugin | +| **Gateway hooks** | v1 不实现(无 IM Gateway 产品面);若未来有 Gateway,可单列 `gateway:*` 事件族 | +| block 双格式 | `ResponseAdapter` 同时识别 `decision:block` 与 `action:block`(§3.6) | + +### B.7 常见误解澄清 + +| 误解 | 说明 | +|------|------| +| 「兼容 Hermes = 能跑 Hermes Python plugin」 | ❌ Plugin hook 是 Hermes 运行时 API;兼容的是 **shell hook + 配置语义** | +| 「Gateway hook 也能拦工具」 | ❌ Gateway hooks 不进入 `invoke_hook()` 的 block 路径 | +| 「Shell 和 Plugin hook 是两套互斥系统」 | ❌ 互补,同一事件可叠加;Plugin 优先 | +| 「Hermes 只有两套」 | 通常指 **进程内(Plugin)vs 可配置(Shell/Gateway)**;严格说是三套 | + +--- + +## 附录 C:实现状态与跨文档约定 + +permission 模块**已在当前仓库实现**(`ms_agent/permission/`)。Hooks **P0 + P1 已实现**;下表标注剩余 P2/P3 项。 + +| # | 项 | 状态 | 文档位置 | +|---|-----|------|---------| +| 4 | ResponseAdapter 统一解析 | ✅ | §8.2 | +| 5 | 子进程 env | ✅ | §8.4 | +| 6 | 多 hook 合并 | ✅ | §8.3 | +| 7 | `resolve_hook_permission_decision` | ✅ | §10.6 | +| 8 | pattern_matcher 提取 | ✅ | §6.1 | +| 9 | fail_closed(全局 + per-handler) | ✅ | §8.6 | +| 10 | permission 交叉引用 | 待补 `permission-design.md` §2 | §10 | +| 11 | HttpHookExecutor | P2 | §17.2 | +| 12 | PromptHookExecutor | P2 | §17.3 | +| 13 | AgentHookExecutor | P3 | §17.4 | +| 14 | `enabled_executors` 扩展后端注册 | 解析✅ / http·prompt·agent 执行器 P2/P3 | §17.1 | +| 15 | `SubagentStop` 运行时挂点 | P2 | §4.1 | +| 16 | `hooks doctor` | P2 | §16.3 | + +**`permission-design.md` 建议补丁** — 在 §2 判定流程中,将原步骤 2 拆为: + +``` + ├─ 1.5. HookRuntime.run_pre_tool_use() ← 社区 hook(可选) + ├─ 2. resolve_hook_permission_decision() ← Hook allow/deny/ask × 规则层合并 + │ └─ 内部调用 PermissionEnforcer.check()(非 allow 短路整层) + └─ 3. tool_ins.call_tool() +``` + +**`pass` ≠ `allow`**:社区脚本 `echo '{}'` 走完整 permission;仅 `permissionDecision: allow` / `decision: approve` 才触发免弹窗路径。 + +**Hook 脚本安全**:子进程 hook 等价于用户显式授权执行命令;v1 不做沙箱,依赖配置路径 + Playground 工作区隔离;与 `trust_remote_code` 无关但需在用户文档中警告。 + +**Cursor `beforeSubmitPrompt` 阻断**:官方 IDE 对输出 JSON 阻断支持仍在演进;ms-agent v1 按 Claude 语义实现 deny,经 `ResponseAdapter` 兼容 Cursor 字段名即可。 diff --git a/docs/zh/design/mcp_runtime_management.md b/docs/zh/design/mcp_runtime_management.md new file mode 100644 index 000000000..167d019bd --- /dev/null +++ b/docs/zh/design/mcp_runtime_management.md @@ -0,0 +1,1101 @@ +# MCP 运行时管理 — 方案设计 + +> 基于 [`playground_prototype_design.md`](../../playground_prototype_design.md) F3(分层配置)与 F7(Skill / MCP 运行时管理)细化;Hooks 集成对齐 [`hooks-design.md`](../zh/design/hooks-design.md)(F6,已落地)。 +> +> 状态:方案设计 v0.5 | 2026-06-16 + +--- + +## 1. 背景与目标 + +实验场(Playground)需要支持: + +| 能力 | 产品语义 | +|------|----------| +| 配置分层 | 全局 → 项目 → session 多级合并,MCP server 按 name 并集去重 | +| 持久化 CRUD | UI / CLI 可增删改 MCP server 定义,写入 `~/.ms_agent/` 或项目 `.ms-agent/` | +| 运行时开关 | `enabled: false` 的 server **完全不可用**(不出现在 tool 列表、不可被调用) | +| 热更新 | 修改配置或切换 enabled 后,无需重启整个 Agent 进程即可生效 | +| 多形态复用 | 同一套能力供 WebUI、TUI、CLI、Workflow 共用 | + +**核心结论(回应「MCP 原本就可以独立初始化」)**: + +现有 `MCPClient` **已经**是一个可独立创建、连接、复用的 MCP 连接层;`LLMAgent` / `ToolManager` 也支持外部注入。F7 不应重写连接逻辑,而是在此之上补齐 **配置持久化 + 运行时状态机 + ToolManager 索引同步**。 + +--- + +## 2. 现状分析 + +### 2.1 已有能力(可直接复用) + +#### MCPClient — 独立初始化 + +`ms_agent/tools/mcp_client.py` 的 `MCPClient` 继承 `ToolBase`,**不依赖** `LLMAgent` 即可使用: + +```python +# 方式 1:async context manager +async with MCPClient(mcp_config) as client: + tools = await client.get_tools() + +# 方式 2:手动生命周期 +client = MCPClient(mcp_config) +await client.connect() +await client.cleanup() + +# 方式 3:运行时增量添加 +await client.add_mcp_config(extra_config) +``` + +单元测试 `tests/tools/test_mcp_client.py::test_outside_init` 专门验证了「Agent 外部独立初始化」场景。 + +#### 外部注入链路 + +``` +MCPClient (可选,外部创建) + ↓ mcp_client= +LLMAgent.__init__ + ↓ prepare_tools() +ToolManager(mcp_config, mcp_client) + ↓ connect() + ├─ 有外部 client → 复用,调用 add_mcp_config 合并增量配置 + └─ 无外部 client → 内部 new MCPClient 并 connect +``` + +关键代码路径: + +| 组件 | 行为 | +|------|------| +| `LLMAgent` | `kwargs['mcp_client']` 透传;`parse_mcp_servers()` 支持 `mcp_server_file` + 内联 `mcp_config` 合并 | +| `ToolManager` | `_managed_client = (mcp_client is None)`;外部 client 时 **不** 在 `cleanup()` 中断开连接;**已集成** `hook_runtime` + SafetyGuard + PermissionEnforcer(见 hooks-design §10) | +| `MCPClient` | 构造时可同时接收 `config`(agent.yaml tools)和 `mcp_config`(JSON 格式) | +| `Config.convert_mcp_servers_to_json` | 将 agent.yaml 中 `mcp: true` 的 tool 条目转为 `mcpServers` 字典 | + +#### 配置来源(当前分散) + +| 来源 | 格式 | 消费方 | +|------|------|--------| +| `agent.yaml` → `tools.*`(`mcp: true`) | YAML | `MCPClient(config=...)` | +| CLI `--mcp_config` / `mcp_server_file` | JSON 文件 | `LLMAgent.parse_mcp_servers` | +| WebUI `ConfigManager` | `~/.ms_agent/config.json` → `mcp_servers` | 仅 WebUI 层,未接入 SDK | +| 运行时 `mcp_config` 参数 | `{"mcpServers": {...}}` | `LLMAgent` / `MCPClient` | + +### 2.2 缺口(F7 需补齐) + +| 缺口 | 说明 | +|------|------| +| 无 `MCPRuntime` | 缺少统一的运行时状态机封装 | +| 无 per-server `enabled` | 配置和运行时均不支持开关 | +| 无 `disconnect_server` | `MCPClient` 只有 `add_mcp_config` / 全局 `cleanup`,无法单 server 下线 | +| 无 `MCPConfigManager` | 仅有 WebUI 侧简单 CRUD,无全局/项目分层、无与 ConfigResolver 集成 | +| ToolManager 无动态 reindex API | `reindex_tool()` **只追加、不清除** `_tool_index`;重复调用会触发 duplicate assert;无「移除 disabled server 工具」的公开接口 | +| 配置语义不统一 | YAML tools 与 JSON mcpServers 两套格式,缺少归一化层 | +| Session 级覆盖未定义 | F3 提到 session 层,MCP 热更新时 session 与全局的生效边界未明确 | + +--- + +## 3. 设计原则 + +1. **连接与治理分离**:`MCPClient` = 传输连接;`MCPRuntime` = 启停/开关/状态;`MCPConfigManager` = 持久化 CRUD;`ConfigResolver` = 多层合并。 +2. **注入优先于自建**:Session / Workflow 多 Agent 共享连接时,由上层创建 `MCPClient` 并注入,避免重复拉起 stdio 子进程。 +3. **软禁用优先、硬断开按需**:一期 `enabled=false` 通过 ToolManager 索引过滤即可满足「完全不可用」;二期再补 per-server 硬断开以释放资源。 +4. **不破坏 CLI 兼容**:`Config.from_task()` 路径保持不变;新能力通过 `ConfigResolver` + `MCPRuntime` 供 Playground 使用。 +5. **连接只发生一次(模式 A)**:Playground 由 `MCPRuntime.start()` 独占 connect;`ToolManager` 注入外部 client 且 `mcp_config` 为空时不再 `add_mcp_config`。 +6. **单向依赖**:`MCPRuntime → ToolManager`(`bind_tool_manager` + `sync_tools`);`ToolManager` 不引用 `MCPRuntime`,失败上报走可选 `mcp_failure_handler` 回调。 +7. **部分失败可启动**:`connect_policy=skip`(默认)— 单 server 连接失败不阻断 Agent;`fail_fast` 保留给 CLI。 +8. **与 Hooks 分治、管线协作**:`MCPRuntime` 与 `HookRuntime`(F6)**不合并**;`ToolManager` 可持有 `hook_runtime`(已有),MCP 通过 `mcp_callable_check` / `mcp_failure_handler` 回调接入;MCP `degraded` 检查须在 PreToolUse **之前**(见 §7.4)。 + +--- + +## 4. 总体架构 + +```mermaid +flowchart TB + subgraph ConfigLayer["配置层 (F3)"] + CR[ConfigResolver] + MCM[MCPConfigManager] + CR --> MCM + end + + subgraph RuntimeLayer["运行时层 (F7)"] + MR[MCPRuntime] + MC[MCPClient] + MR --> MC + end + + subgraph AgentLayer["Agent 层"] + LA[LLMAgent] + TM[ToolManager] + HR[HookRuntime] + LA --> TM + LA --> HR + HR -->|Pre/PostToolUse| TM + end + + MR -->|bind + sync_tools| TM + TM --> MC + + subgraph UI["UI / API"] + WebUI[WebUI Backend] + CLI[CLI / TUI] + end + + MCM --> CR + CR -->|resolved mcpServers| MR + MR -->|mcp_client inject| LA + WebUI --> MCM + WebUI --> MR + CLI --> CR + CLI --> LA +``` + +### 4.1 模块职责 + +| 模块 | 路径(建议) | 职责 | +|------|-------------|------| +| `MCPConfigManager` | `ms_agent/config/mcp_manager.py` | 全局/项目 MCP 条目 CRUD、`enabled` 持久化、导入导出 | +| `ConfigResolver` | `ms_agent/config/resolver.py` | 五层合并,输出归一化后的 `ResolvedMCPConfig` | +| `MCPRuntime` | `ms_agent/mcp/runtime.py` | 连接状态机、`connect_policy`、enable/disable、reload、`sync_tools`(持 `_sync_lock`) | +| `MCPClient` | `ms_agent/tools/mcp_client.py` | **保持现有**,小幅增强 per-server 操作 | +| `ToolManager` | `ms_agent/tools/tool_manager.py` | 已有 `hook_runtime`;新增 `_clear_mcp_index_entries` + `sync_mcp_tools()`、`mcp_callable_check` / `mcp_failure_handler` 回调;**不**持有 `MCPRuntime` | +| `HookRuntime` | `ms_agent/hooks/runtime.py` | F6 已落地;Pre/PostToolUse 嵌入 `single_call_tool`;与 MCP 分治 | + +--- + +## 5. 配置模型 + +### 5.1 归一化 Schema + +所有来源最终合并为: + +```json +{ + "mcpServers": { + "": { + "enabled": true, + "transport": "stdio | sse | streamable_http | websocket", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path"], + "url": "https://...", + "env": {"API_KEY": ""}, + "headers": {}, + "timeout": 120, + "include": ["tool_a"], + "exclude": ["tool_b"], + "source": "global | project | agent_yaml | plugin", + "meta": { + "description": "", + "added_at": "2026-06-15T00:00:00Z" + } + } + } +} +``` + +**字段说明**: + +- `enabled`:**用户配置意图**,默认 `true`;`false` 时不出现在 tool 列表、不主动连接。不因连接/调用失败自动改写。 +- `mcp`(agent.yaml 遗留):合并时若 `mcp: false` 则该条目表示 **系统内置工具**(如 `filesystem`),**不进入** `mcpServers`,由 `ToolManager.extra_tools` 提供;若与下层(global)同名,**遮蔽** global 中的 MCP 定义(见 §5.4 用例 7)。 +- 同名 server:项目级覆盖全局级连接参数,但 `enabled` 以**最具体层**为准(session > project > global)。 + +**`enabled` vs 运行时 `status`(勿混淆)**: + +| 字段 | 含义 | 持久化 | 示例 | +|------|------|--------|------| +| `enabled` | 用户是否要启用此 server | 是 | 用户在设置页关闭 → `enabled=false` | +| `status` | 当前连接健康度 | 否(内存) | 调用超时 → `status=degraded`,`enabled` 仍为 `true` | + +### 5.2 存储布局 + +``` +~/.ms_agent/ + settings.json # 全局 MCP 条目(MCPConfigManager global) + mcp.json # 可选:兼容 Cursor/Claude Desktop 格式导入 + +/.ms-agent/ + project.json # 项目元信息 + mcp.json # 项目级 MCP 补丁(enabled + 新增 server) + hooks.json # F6 已占用:项目级 Hooks(与 mcp.json 并列) + hooks/ # F6 已占用:Hook 脚本目录 + +// + session.json # session 级 enabled 覆盖(可选,一期可不做) +``` + +### 5.3 合并规则(与 F3 对齐) + +``` +resolved = merge( + framework_defaults, # 空 + global_settings, # ~/.ms_agent/settings.json + agent_yaml_mcp, # Config.convert_mcp_servers_to_json + project_mcp_patch, # /.ms-agent/mcp.json + session_override # 可选 +) +``` + +- **并集**:按 `server_name` 合并,后者覆盖前者同名字段。 +- **enabled 继承**:未显式设置时继承上层;显式 `false` 在下层可重新 `true`(session 级)。 +- **agent.yaml 归一化(`ConfigResolver` 职责,不可依赖 `convert_mcp_servers_to_json` 原样输出)**: + + 现有 `Config.convert_mcp_servers_to_json` 对 `tools.*` 做 `deepcopy`,会带入 YAML 专有条目(如 `implementation`、`mcp` 标志位、`trust_remote_code` 等)。**合并层须显式归一化**,规则: + + | 处理 | 字段示例 | + |------|----------| + | **剔除**(不进 `mcpServers`) | `mcp: false` 的整个 tool 条目;`implementation`;`enabled` 以外的 agent 元数据 | + | **保留**(连接相关) | `command`、`args`、`url`、`transport` / `type`、`env`、`headers`、`timeout`、`include`、`exclude` | + | **补充** | `enabled`(默认 `true`)、`source`(`agent_yaml`) | + + 实现建议:在 `ConfigResolver.resolve_mcp()` 内对 `agent_yaml_mcp` 层调用 `normalize_mcp_server_entry(entry)`,**不修改** `convert_mcp_servers_to_json` 本身(CLI 路径保持兼容);Playground / `MCPRuntime` 只消费归一化后的 `ResolvedMCPConfig`。 + +- **内置工具遮蔽(agent.yaml `mcp: false`)**:`ConfigResolver` 在合并完成后,收集 agent.yaml 中所有 `mcp: false` 的 tool 名称(如 `filesystem`),从 `resolved.mcpServers` **移除** 同名条目。即使 global 层配置了同名 MCP server,agent 任务以内置实现为准。 + +### 5.4 合并用例(实现须覆盖) + +| # | global | agent_yaml | project | session | 期望 `fetch.enabled` | 期望 `fetch.command` | 说明 | +|---|--------|------------|---------|---------|---------------------|----------------------|------| +| 1 | `fetch: {command: A}` | — | — | — | `true` | `A` | 仅全局 | +| 2 | `fetch: {enabled: false}` | — | `project: {enabled: true}` | — | `true` | 继承 global | 项目级覆写 enabled | +| 3 | `fetch: {command: A}` | `fetch: {command: B}` | — | — | `true` | `B` | agent_yaml 覆盖连接参数 | +| 4 | `fetch: {command: A}` | — | `project: {command: C}` | — | `true` | `C` | 项目级覆盖连接参数 | +| 5 | `fetch: {command: A}` | — | `remove fetch`(项目级删除标记) | — | 不存在或 `enabled: false` | — | 项目级「删除」= 遮蔽 global,不删 global 文件 | +| 6 | `fetch: {enabled: false}` | — | — | `session: {enabled: true}` | `true` | 继承 global | session 可重新启用(Phase 3) | +| 7 | 同名 `filesystem` | `filesystem: {mcp: false}`(内置 tool) | — | — | 不存在 | — | 内置 tool 遮蔽 global MCP;`filesystem` 走 `extra_tools`,不进 `mcpServers` | + +**优先级小结**:同名字段以后层覆盖前层;`enabled` 取最具体层显式值,未设置则继承;项目级 `remove` 等价于在该层写入 `enabled: false` 或删除条目(实现二选一,须与 `MCPConfigManager.remove` 语义一致)。 + +### 5.5 `enabled` 双源语义(持久化 vs 运行时) + +系统中存在两个独立的 `enabled` 概念,**不可混用**: + +| 来源 | 存储 | 写入 API | 生命周期 | 用途 | +|------|------|----------|----------|------| +| **持久化 `enabled`** | `settings.json` / `mcp.json`(经 `MCPConfigManager`) | `set_enabled` / `update` / UI 设置页 | 跨 session 持久 | 用户长期开关意图 | +| **运行时 `enabled`** | `MCPServerState.enabled`(内存) | `MCPRuntime.enable_server` / `disable_server` | 当前 session / Agent 进程 | 临时调试、热开关(不写盘) | + +**协作规则**: + +1. **UI / 设置页改开关** → `MCPConfigManager.set_enabled` → `ConfigResolver.resolve_mcp` → `MCPRuntime.apply_config`(持久化为准)。 +2. **仅运行时开关**(调试)→ `MCPRuntime.enable_server` / `disable_server`;**不**调用 `MCPConfigManager`;session 结束或 `apply_config` 后丢失。 +3. **`apply_config` 优先级高于运行时覆盖**:从 Resolver 重载配置时,以合并后的 `enabled` 覆盖内存状态。 +4. **连接失败不改 `enabled`**:仅更新 `status`(`error` / `degraded`);`enabled` 仍反映用户配置意图。 + +--- + +## 6. MCPRuntime 详细设计 + +### 6.1 状态机 + +每个 server 维护独立状态: + +``` + ┌─────────────┐ + 配置新增 │ REGISTERED │ enabled=false 时停留 + ──────────► │ (已注册) │ + └──────┬──────┘ + │ enable + connect + ▼ + ┌─────────────┐ + │ CONNECTING │ + └──────┬──────┘ + 成功 │ │ 失败 + ▼ ▼ + ┌──────────┐ ┌────────┐ + │ CONNECTED│ │ ERROR │──► 可重试 connect + └────┬─────┘ └────────┘ + │ 运行中 call_tool / list_tools 失败(连接已断) + ▼ + ┌──────────┐ + │ DEGRADED │──► 仅记录失败;等用户手动「重连」或下次 session 启动 + └────┬─────┘ + disable │ + ◄──────────────┘ + (一期:索引移除,连接可保留) + (二期:disconnect_server) +``` + +### 6.2 类接口 + +```python +@dataclass +class MCPFailureRecord: + """单次失败快照(内存,供 UI 展示与诊断)。""" + at: str # ISO8601 + phase: Literal["connect", "call_tool", "list_tools"] + tool_name: str | None = None + message: str + + +@dataclass +class MCPServerState: + name: str + config: dict + enabled: bool + status: Literal[ + "registered", "connecting", "connected", + "degraded", # 曾连上,运行中不可用(不自动重连) + "error", # 初次连接失败 + "disabled", # enabled=false + ] + last_error: str | None = None # 最近一次失败摘要(UI 直接展示) + last_success_at: str | None = None + last_failure_at: str | None = None + consecutive_failures: int = 0 # 连续失败次数(成功调用后归零) + failure_history: list[MCPFailureRecord] # 环形缓冲,默认保留最近 20 条 + tool_count: int = 0 + cached_tools: list[dict] = field(default_factory=list) # 上次 list_tools 成功快照(见 §6.5.1) + connected_at: str | None = None + + +class MCPRuntime: + """MCP 运行时管理器。封装 MCPClient,对上提供配置驱动的启停与状态查询。""" + + def __init__( + self, + *, + mcp_client: MCPClient | None = None, + config: ResolvedMCPConfig | None = None, + owns_client: bool | None = None, + connect_policy: Literal["skip", "fail_fast"] = "skip", + ): ... + _sync_lock: asyncio.Lock # apply_config / sync_tools 互斥 + + # ── 生命周期 ── + async def start(self) -> None: + """连接所有 enabled server,幂等。 + + connect_policy: + - skip(默认):单 server 连接失败 → status=error,继续连接其余 server,不 raise。 + - fail_fast:任一失败即 raise(兼容现有 MCPClient.connect 行为,供 CLI 等场景)。 + """ + + async def stop(self) -> None: + """断开全部连接。仅 owns_client=True 时调用 client.cleanup()。""" + + # ── 运行时开关 ── + async def enable_server(self, name: str) -> MCPServerState: + """enabled=true → connect(若未连接)→ 通知 ToolManager 同步。""" + + async def disable_server(self, name: str) -> MCPServerState: + """enabled=false → 从 ToolManager 移除工具;二期再断开连接。""" + + async def reload_server(self, name: str) -> MCPServerState: + """一期:disable(软)+ sync_tools + enable(软重连索引);二期:硬 disconnect + connect。""" + + # ── 配置热更新 ── + async def apply_config(self, config: ResolvedMCPConfig) -> list[MCPServerState]: + """diff 新旧配置:新增 connect、删除 disable、变更 reload。持 _sync_lock,与 sync_tools 互斥。""" + + # ── 查询 ── + def list_servers(self) -> list[MCPServerState]: ... + def get_server(self, name: str) -> MCPServerState | None: ... + + # ── ToolManager 集成(单向依赖:Runtime → ToolManager)── + def bind_tool_manager(self, tool_manager: ToolManager) -> None: ... + async def sync_tools(self) -> None: + """根据 enabled + 可见性规则刷新 ToolManager._tool_index(见 §6.5、§7.1)。持 _sync_lock。""" + + # ── 失败记录(Phase 2)── + async def record_failure( + self, name: str, phase: str, message: str, *, tool_name: str | None = None + ) -> None: + """记录失败并置 status=degraded。不触发自动重连。""" + + async def reconnect_server(self, name: str) -> MCPServerState: + """用户手动重连:断开并重连单个 server(UI「重连」/ reload_server)。""" + + def is_callable(self, server_name: str) -> bool: + """server 是否允许发起 RPC(仅 status=connected)。""" +``` + +### 6.3 与 MCPClient 的关系 + +| 场景 | MCPRuntime 行为 | +|------|----------------| +| 未注入 `mcp_client` | 内部 `MCPClient(resolved_config)`,`owns_client=True` | +| 注入外部 `mcp_client` | 复用连接,`owns_client=False`,`stop()` 不 cleanup | +| 新增 server | 调用 `MCPClient.add_mcp_config({...})` | +| 禁用 server(一期) | 不调 `cleanup`,由 `sync_tools()` 过滤 `_tool_index` | +| 禁用 server(二期) | 新增 `MCPClient.disconnect_server(name)` | + +### 6.3.1 连接职责划分(避免重复 connect) + +Playground 推荐 **模式 A(Runtime 独占连接)**: + +| 角色 | 职责 | +|------|------| +| `MCPRuntime.start()` | 唯一连接入口;按 `connect_policy` 连接所有 `enabled` server | +| `LLMAgent` | `mcp_client=mcp_runtime.client`,**`mcp_config={}`**(或仅传 agent.yaml 中 Runtime 未管理的增量,一般为空) | +| `ToolManager.connect()` | 外部 client 分支:**不再**对已由 Runtime 连接的 server 调用 `add_mcp_config`;`self.servers = mcp_client` 后仅 `reindex` / 等待 Runtime `sync_tools` | + +**模式 B(CLI 兼容,无 Runtime)**:保持现状 — `ToolManager` 内部 `MCPClient.connect()` 或外部 client + `add_mcp_config`。 + +**为何必须区分**:现有 `MCPClient.connect()` / `add_mcp_config` 会 **原地 `pop` env/exclude/timeout**,导致配置对象被变异;若 Runtime `start()` 与 `ToolManager.connect()` 各连一次,去重比较 `servers[name] == server` 会失效,且配置变更时旧 session 可能残留。模式 A 下连接只发生一次。 + +```python +# 模式 A 推荐写法(Playground Session 层) +mcp_runtime = MCPRuntime(config=resolved_mcp, connect_policy="skip") +await mcp_runtime.start() + +agent = LLMAgent( + config=agent_config, + mcp_client=mcp_runtime.client, + mcp_config={}, + mcp_runtime=mcp_runtime, +) +await agent.prepare_tools() # 内部注入 hook_runtime + mcp_* 回调;见 §9.1 +# sync_tools 在 prepare_tools 末尾或此处显式调用 +``` + +**漏调 `start()` 时的行为(代码已补齐)**: + +| 场景 | 行为 | +|------|------| +| Session 层已 `await mcp_runtime.start()` | 正常路径;`prepare_tools` 内检测到 `is_started=True`,跳过 | +| 注入 `mcp_runtime` 但未调 `start()` | `LLMAgent.prepare_tools()` **自动调用** `await mcp_runtime.start()`(幂等) | +| 无 `mcp_runtime`(模式 B / CLI) | 不介入;`ToolManager.connect()` 按现有逻辑连接 | + +推荐仍由 Session 层**显式** `start()`,以便在创建 Agent 前完成连接、收集 `status=error` 并展示给用户;`prepare_tools` 的自动 `start()` 为兜底,避免静默无工具。 + +`prepare_tools()` 结束时若存在 `mcp_runtime`,`cleanup_tools()` 会调用 `mcp_runtime.stop()`(`owns_client=True` 时释放连接)。 + +### 6.4 MCPClient 增强(最小改动) + +```python +# ms_agent/tools/mcp_client.py 新增 + +async def disconnect_server(self, server_name: str) -> None: + """断开单个 server。需将 per-server 资源从 exit_stack 中独立管理(二期)。""" + +def list_connected_servers(self) -> list[str]: + return list(self.sessions.keys()) + +def is_connected(self, server_name: str) -> bool: + return server_name in self.sessions +``` + +> **实现说明**:`MCPClient` 已采用 per-server `AsyncExitStack`(`_server_stacks`),支持 `disconnect_server`;`async with` / `__aexit__` 调用 `cleanup()` 释放全部 server 连接。 + +### 6.5 运行中失败与历史记录(仅手动重连) + +**场景**:`initialize()` 成功(`status=connected`),但后续 `call_tool` / `list_tools` 因进程退出、网络断开、SSE 超时等失败。 + +**不做自动重连**:MCP server 挂掉可能是预期状态(进程崩溃、用户关停、配置错误)。自动重连会反复拉起子进程或打远程接口,浪费资源且拖慢 Agent。故障后**只记录、视严重程度上报 `degraded`,等用户决定**是否重连。 + +#### 失败分类与 degraded 策略 + +| 分类 | 示例 | 首次失败行为 | 进入 `degraded` 条件 | +|------|------|-------------|---------------------| +| **hard**(硬断开) | `BrokenPipeError`、`session closed`、`connection refused` | 记录失败 | **立即** `status=degraded` | +| **transient**(瞬时) | `asyncio.TimeoutError`、HTTP 502/503、消息含 `timeout` | 记录失败,`status` 保持 `connected` | 连续失败 **≥ 3 次**(`DEGRADED_FAILURE_THRESHOLD`) | +| **none**(业务) | 参数错误、工具返回 `isError` | 可选记入 `failure_history` | **不**改 `status` | + +单次 timeout 可能是网络抖动,**不得**因一次超时就 `degraded` 并从 LLM 工具列表移除。成功调用后 `consecutive_failures` 归零(`record_success`,**仅当 `status=connected` 时**)。 + +**`degraded` 恢复策略(定稿)**: + +- `status=degraded` 后 **不**因后续成功 RPC 自动恢复为 `connected`(`record_success` 只重置计数器,不改 `degraded` 状态)。 +- 恢复路径:**仅** `reconnect_server` / `reload_server`、或新建 session 重新 `start()`。 +- 外层 `asyncio.wait_for` 超时(`ToolManager` 层)与 MCP 传输层 `TimeoutError` 均视为 **transient**,计入 `consecutive_failures`,达阈值后 `degraded`。 + +实现:`classify_mcp_failure(exc)` 区分三类;`ToolManager` 将原始 `exc` 传入 `record_failure(..., exc=exc)`(含外层 `asyncio.TimeoutError` 与 `sync_mcp_tools` 的 `list_tools` 失败)。 + +#### 失败记录策略 + +| 项目 | 设计 | +|------|------| +| 存储位置 | **内存**(`MCPServerState.failure_history`),不写入 `enabled` | +| 保留条数 | 每 server 最近 **20** 条(环形缓冲) | +| 记录内容 | 时间、`phase`(connect / call_tool / list_tools)、`tool_name`、`message` | +| UI 展示 | `GET /api/mcp/servers` 返回 `last_error` + `failure_history`(最近 5 条)+ 操作按钮「重连」 | +| Session 日志 | 可选写入 `sessions//mcp_events.jsonl`(P1,便于排查长任务) | + +**不改 `enabled`**:运行失败只更新 `status`(可能为 `degraded`),并累加 `consecutive_failures`。 + +#### 各 `status` 下工具可见性与调用语义(产品定稿) + +**LLM 侧**(`ToolManager.get_tools()` → 模型 tool 列表)与 **UI 侧**(`GET /api/mcp/servers`)分离: + +| status | LLM `get_tools()` 可见 | 允许 `call_tool` | UI `GET /api/mcp/servers` | 行为 | +|--------|------------------------|------------------|---------------------------|------| +| `registered` | 否 | 否 | 是 | 已注册未连接 | +| `connecting` | 否 | 否 | 是 | 连接中 | +| `connected` | **是** | **是** | 是 | 正常 | +| `error` | 否 | 否 | 是 | 初次连接失败;`last_error` 可供诊断 | +| `disabled` | 否 | 否 | 是 | `enabled=false` | +| `degraded` | **否** | 否 | **是**(含 `last_error` / `failure_history`) | 不可调用;**不对 LLM 展示**;用户通过 UI「重连」恢复 | + +`sync_tools()` 索引规则:仅 `enabled=true` 且 `status=connected` 的 server 写入 `_tool_index`(供 LLM)。`degraded` / `error` / `disabled` 不出现在 LLM 工具列表;运维信息仅通过 MCP 状态 API 暴露。 + +`single_call_tool` 对仍在索引中但不可调用的 server(竞态窗口)在 RPC 前通过 `mcp_callable_check` 短路。 + +#### 6.5.1 per-server 工具列表与隔离 + +`MCPClient.get_tools()` 须 **按 server 隔离**:单个 server `list_tools` 失败不得导致其他 server 工具不可用。 + +| API | 行为 | +|-----|------| +| `get_tools_for_server(name)` | 仅拉取指定 server;失败 raise | +| `get_tools()` | 遍历各 session,单 server 失败记日志并返回空列表,**不** raise | + +`sync_tools()` / `ToolManager.sync_mcp_tools()` 对每个 `indexable` server 调用 `get_tools_for_server`,互不影响。 + +`MCPServerState.cached_tools` 仍在 `connect` / `list_tools` 成功时更新,供 UI 展示 `tool_count`;**不**用于向 LLM 展示 `degraded` 工具。 + +#### 失败后的行为(无自动动作) + +``` +1. call_tool / list_tools 传输类异常 → classify → record_failure(transient 可能保持 connected) +2. hard 或 transient 达阈值 → status=degraded → 从 LLM 工具索引移除(sync_tools) +3. 不重连、不自动重试 +4. 用户手动 reconnect_server / reload_server → 恢复 connected 并重新入索引 +5. 新建 session / Agent 重新 start → 按 enabled 正常 connect +``` + +若 MCP 调用已发起 RPC 后失败,返回的错误文本仍会作为 `tool_result` 触发 **PostToolUse**(与 hooks-design §8.5 一致);业务层 `isError` 不改 `status`。 + +**hard** 示例:`BrokenPipeError`、`session closed`、`connection refused`。 +**transient** 示例:`TimeoutError`、HTTP 502/503。 +业务错误(工具返回 `isError`、参数非法)只记入 `failure_history`(可选),**不**改 `status`。 + +#### 与 ToolManager 协作(单向依赖) + +`ToolManager` **不**持有 `MCPRuntime` 引用(Hooks 已持有 `_hook_runtime` 是既有设计,MCP 仍用回调避免第三个 runtime 引用)。失败上报与 callable 检查: + +```python +# MCPRuntime 注册到 ToolManager 的轻量回调(非双向引用) +# tool_manager.mcp_callable_check: Callable[[str], bool] | None +# tool_manager.mcp_failure_handler: Callable[[str, str, str, str | None], Awaitable[None]] | None + +# single_call_tool 内(见 §7.4 步骤 ⑥),MCP 工具连接类异常时: +if self.mcp_failure_handler and is_connection_error(exc): + await self.mcp_failure_handler(server_name, "call_tool", str(exc), tool_name) +``` + +#### 状态流转 + +``` +CONNECTED ──call_tool 失败(连接类)──► DEGRADED(停在这里,等用户) + ▲ │ + │ 用户手动 reconnect_server │ + └──────────────────────────────────────┘ +``` + +`degraded` 期间 Agent 可继续用其他工具;用户修好 MCP 后点「重连」即可恢复,无需改 `enabled`。 + +--- + +## 7. ToolManager 集成 + +### 7.1 改动点 + +**不可直接复用 `reindex_tool()`**:现有实现只向 `_tool_index` 追加条目,不清除旧 MCP key;重复 sync 会触发 `Tool name duplicated` assert。须新增专用路径。 + +```python +class ToolManager: + def __init__( + self, + ..., + hook_runtime=None, # 已有(Hooks F6) + mcp_callable_check=None, # 新增:MCPRuntime.is_callable 绑定 + mcp_failure_handler=None, # 新增:MCPRuntime.record_failure 绑定 + ): + self._hook_runtime = hook_runtime + self.mcp_callable_check = mcp_callable_check + self.mcp_failure_handler = mcp_failure_handler + self._tool_index = {} + self._mcp_index_keys: set[str] = set() + + def _clear_mcp_index_entries(self) -> None: + """仅移除 MCP 来源的 _tool_index 条目,不影响 extra_tools。""" + for key in self._mcp_index_keys: + self._tool_index.pop(key, None) + self._mcp_index_keys.clear() + + async def sync_mcp_tools( + self, + *, + visible_servers: set[str], + indexable_servers: set[str], # 写入索引的 server(仅 connected) + callable_servers: set[str], # 允许 RPC 的 server(仅 connected) + cached_tools_by_server: dict[str, list[dict]] | None = None, + # 保留参数兼容;LLM 索引不再使用 degraded 缓存 + ) -> None: + """重建 MCP 相关 _tool_index 条目。由 MCPRuntime.sync_tools() 调用。""" + async with self._sync_lock: + self._clear_mcp_index_entries() + if self.servers is None: + return + for server_name in indexable_servers: + # 按 server 调用 get_tools_for_server,单 server 失败不影响其他 + ... + + async def connect(self): + if self.mcp_client and isinstance(self.mcp_client, MCPClient): + self.servers = self.mcp_client + # 模式 A:mcp_config 为空时跳过 add_mcp_config(Runtime 已连接) + if self.mcp_config and self.mcp_config.get('mcpServers'): + await self.servers.add_mcp_config(self.mcp_config) + self.mcp_config = self.servers.mcp_config + elif ...: + ... # 保持现有逻辑(模式 B) + ... + # 索引构建:有 MCPRuntime → sync_tools();否则 reindex_tool() + if not getattr(self, '_skip_mcp_reindex', False): + await self.reindex_tool() +``` + +`connect()` 在模式 A 下由 `LLMAgent.prepare_tools()` 设置 `_skip_mcp_reindex=True`(有 `mcp_runtime` 时),避免与 `sync_tools()` 重复建索引。 + +`MCPRuntime.sync_tools()` 根据 §6.5 可见性表计算三个集合后调用 `ToolManager.sync_mcp_tools(...)`。 + +### 7.2 调用时校验(摘要) + +MCP 相关校验嵌入 §7.4 管线**步骤 ①**(在 SafetyGuard / PreToolUse 之前): + +1. `mcp_callable_check(server_name)` 为 `False`(`degraded` / `disabled` / `error`)→ 返回结构化错误(含 `last_error`),**不**跑 PreToolUse、不发起 RPC。 +2. `enabled=false` 的 server 不应出现在索引中(双重保险)。 + +### 7.3 并发与热更新 + +`apply_config()` / `sync_tools()` / `sync_mcp_tools()` 可能与运行中的 `parallel_call_tool()` 并发修改 `_tool_index`(Hooks 下 `parallel_call_tool` 对各工具独立调用 `single_call_tool`,见 hooks-design §11.2)。 + +- `ToolManager` 增加 `_sync_lock: asyncio.Lock`(与 `_init_lock` 分离)。 +- `_sync_lock` **仅保护** `sync_mcp_tools`,**不**包裹 PreToolUse 子进程(避免与 Hooks 串行化)。 +- `single_call_tool` 在步骤 ① 之前对 `_tool_index[tool_name]` 取**快照** `(tool_ins, server_name, tool)`,避免 await 期间索引被热更新替换。 +- `MCPRuntime.apply_config` 与 `sync_tools` 共用 `_sync_lock`。 + +### 7.4 与 Hooks 管线协作(`single_call_tool` 完整顺序) + +Hooks F6 已落地后,`ToolManager.single_call_tool()` 为 MCP 与 Hooks 的**唯一交汇点**。MCP 扩展在 hooks-design §10.2 既有管线**之前**插入步骤 0–1,整体顺序如下: + +``` +ToolManager.single_call_tool(tool_info) + │ + ├─ 0. _tool_index 快照 (tool_ins, server_name, tool) ← MCP 新增 + │ + ├─ 1. MCP callable 检查 ← MCP 新增(degraded/disabled/error) + │ └─ mcp_callable_check(server_name)==False → 返回 last_error JSON,短路 + │ + ├─ 2. SafetyGuard.check() ← hooks-design §10.2 步骤 1 + │ + ├─ 3. HookRuntime.run_pre_tool_use() ← §10.2 步骤 2 + │ + ├─ 4. resolve_hook_permission_decision() ← §10.2 步骤 3 + │ + ├─ 5. tool_ins.call_tool() ← §10.2 步骤 4 + │ └─ 连接类异常 → mcp_failure_handler → record_failure → degraded + │ + └─ 6. HookRuntime.run_post_tool_use() ← §10.2 步骤 5 +``` + +#### 7.4.1 与 hooks-design §10.2 的步骤关系 + +| 文档 | 第一步 | 说明 | +|------|--------|------| +| hooks-design §10.2 | SafetyGuard | 描述 **Hooks 落地时**的基线管线(无 MCP Runtime) | +| 本文 §7.4 | MCP callable → SafetyGuard → … | F7 在基线**之前**增加 MCP 可用性门禁 | + +**为何 MCP 检查在 SafetyGuard 之前**: + +- `degraded` / `error` 时**不发起 RPC**,亦无参数可审;提前短路可避免无意义的 SafetyGuard / PreToolUse 子进程开销 +- 不构成安全绕过:调用在步骤 ① 已被拒绝,与「SafetyGuard 不可绕过**已执行的 tool call**」不冲突 +- `connected` 的 MCP 工具仍完整走 SafetyGuard → PreToolUse → Permission → RPC + +实施时须在 [hooks-design.md §10.2](../zh/design/hooks-design.md#102-目标执行顺序插入-hooks-后) 补交叉引用(见该文档 §10.2.1)。 + +**关键语义**: + +| 场景 | PreToolUse | RPC | PostToolUse | +|------|------------|-----|-------------| +| `enabled=false` | 不触发(不在索引) | 否 | 否 | +| `degraded` | 不触发(不在 LLM 索引) | 否 | 否 | +| `connected` + Hook `deny` | 触发后拒绝 | 否 | 否 | +| `connected` + 调用成功 | 触发 | 是 | 是 | +| `connected` + 连接类失败 | 触发 | 是(失败) | 是(`tool_result` 为错误文本) | + +MCP 工具名格式 `fetch---tool_name` 与 Hooks matcher(`fetch---*`)及 Permission 白名单**共用** `server---tool` 约定。 + +```python +# single_call_tool 内步骤 ① 示意 +tool_ins, server_name, _ = index_snapshot +if tool_ins is self.servers and self.mcp_callable_check is not None: + if not self.mcp_callable_check(server_name): + return json.dumps({ + 'success': False, + 'error': 'mcp_unavailable', + 'server_name': server_name, + 'message': ..., # 含 last_error + }) +# 此后进入 SafetyGuard → PreToolUse → ... +``` + +--- + +## 8. MCPConfigManager 详细设计 + +```python +class MCPConfigManager: + """全局 / 项目两级 MCP 配置持久化。""" + + def __init__(self, global_root: Path, project_root: Path | None = None): ... + + # CRUD + def list(self, scope: Literal["global", "project", "merged"]) -> dict: ... + def get(self, name: str, scope: str = "merged") -> dict | None: ... + def add(self, name: str, server: dict, scope: str = "project") -> None: ... + def update(self, name: str, patch: dict, scope: str = "project") -> None: ... + def remove(self, name: str, scope: str = "project") -> None: ... + + # 开关(持久化) + def set_enabled(self, name: str, enabled: bool, scope: str = "project") -> None: ... + + # 导入导出 + def import_cursor_format(self, path: Path, merge: bool = True) -> int: ... + def export_mcp_json(self, path: Path, scope: str = "merged") -> None: ... + + # 环境变量 + def resolve_env(self, server: dict) -> dict: + """空字符串 env 值从 Env.load_env() 填充(与 MCPClient 一致)。""" +``` + +与 `ConfigResolver` 协作: + +```python +class ConfigResolver: + def resolve_mcp(self, project_id: str | None, session_id: str | None) -> ResolvedMCPConfig: + ... +``` + +--- + +## 9. 嵌入 Agent 的路径 + +### 9.1 Playground(推荐路径,模式 A + Hooks) + +Hooks 已在 `LLMAgent.prepare_tools()` 内构造 `hook_runtime` 并传入 `ToolManager`。Playground Session 层**先**创建 `MCPRuntime`,再交给 `LLMAgent`;MCP 回调在 `prepare_tools` 内与 Hooks 并列注入。 + +```python +resolver = ConfigResolver(global_config, project_manager) +resolved_mcp = resolver.resolve_mcp(project_id, session_id) + +mcp_runtime = MCPRuntime(config=resolved_mcp, connect_policy="skip") +await mcp_runtime.start() # 唯一 MCP 连接入口 + +agent = LLMAgent( + config=resolver.resolve_agent_config(...), + mcp_client=mcp_runtime.client, + mcp_config={}, # 模式 A:避免二次 add_mcp_config + mcp_runtime=mcp_runtime, +) +await agent.prepare_tools() # 内部:build_hook_runtime + ToolManager(..., hook_runtime=...) + +# prepare_tools 扩展(LLMAgent 内,当 mcp_runtime 非空时): +# 若未 start() → 自动 await mcp_runtime.start()(幂等兜底) +# tool_manager.mcp_callable_check = mcp_runtime.is_callable +# tool_manager.mcp_failure_handler = mcp_runtime.record_failure # 传入 exc= +# tool_manager.mcp_success_handler = mcp_runtime.record_success +# tool_manager._skip_mcp_reindex = True +# connect() 后不 reindex MCP;改由 sync_tools 建索引 +# cleanup_tools → mcp_runtime.stop()(owns_client 时释放连接) + +mcp_runtime.bind_tool_manager(agent.tool_manager) +await mcp_runtime.sync_tools() +``` + +`prepare_tools()` 目标形态(摘录,与 hooks-design §9.7 / §11.1 合并): + +```python +async def prepare_tools(self): + ... + hook_runtime = build_hook_runtime(self.config, session_id=session_id) + mcp_rt = self.mcp_runtime # 可选,Playground 注入 + + self.tool_manager = ToolManager( + self.config, + self.mcp_config if mcp_rt is None else {}, + self.mcp_client, + permission_enforcer=permission_enforcer, + safety_guard=safety_guard, + hook_runtime=hook_runtime, + mcp_callable_check=mcp_rt.is_callable if mcp_rt else None, + mcp_failure_handler=mcp_rt.record_failure if mcp_rt else None, + ... + ) + self._hook_runtime = hook_runtime + if mcp_rt is not None: + self.tool_manager._skip_mcp_reindex = True + await self.tool_manager.connect() + if mcp_rt is not None: + mcp_rt.bind_tool_manager(self.tool_manager) + await mcp_rt.sync_tools() +``` + +**依赖方向**:`LLMAgent` 持有 `_hook_runtime` + 可选 `_mcp_runtime`;`MCPRuntime → ToolManager` 单向 `bind`;`ToolManager` 持有 `hook_runtime`,MCP 仅回调。 + +配置热更新(开关 / CRUD)时: + +```python +resolved = resolver.resolve_mcp(project_id, session_id) +await mcp_runtime.apply_config(resolved) # 内部 diff + sync_tools(持锁) +``` + +### 9.2 CLI(保持兼容) + +```python +# 现有方式不变 +agent = LLMAgent(mcp_config=mcp_dict) +# 或 +agent = LLMAgent(mcp_server_file="mcp.json") + +# 可选增强:--mcp-runtime 开启运行时管理(P1) +``` + +### 9.3 多 Agent Workflow 共享 + +`ToolManager.connect()` 的外部 client 分支**具备**复用连接的能力,但现有 `ChainWorkflow` / `DagWorkflow` **未**创建或透传共享 `mcp_client`(仅透传 `mcp_server_file`),因此每个 Agent 仍会各自 `connect()`。 + +Playground **Session 层**须显式创建共享运行时: + +```python +mcp_runtime = MCPRuntime(config=shared_resolved_mcp, connect_policy="skip") +await mcp_runtime.start() + +for task in workflow_chains: + agent = LLMAgent( + config=..., + mcp_client=mcp_runtime.client, + mcp_config={}, + mcp_runtime=mcp_runtime, + ) + await agent.prepare_tools() + mcp_runtime.bind_tool_manager(agent.tool_manager) # 每 Agent 绑定当前 ToolManager + await mcp_runtime.sync_tools() + await agent.run(inputs) +# MCPClient 只 connect 一次;各 Agent 切换 bind_tool_manager 同步各自索引 +``` + +**HookRuntime 共享策略**(Workflow 多 Agent 须额外约定): + +| 资源 | 推荐 | 说明 | +|------|------|------| +| `MCPRuntime` / `MCPClient` | Session 级共享 | 避免重复 stdio 子进程 | +| `HookRuntime` | 默认 **per-Agent** | `build_hook_runtime(agent.yaml)` 可能因 task 不同而异 | +| `HookRuntime` | 可选 Session 级共享 | 同一 `hooks.json`、同一 `agent.yaml` 的 workflow 可复用 | + +Workflow 引擎后续可选 `shared_mcp_runtime` / `shared_hook_runtime`;一期由 Playground Session 编排层负责。 + +--- + +## 10. API 面(供 WebUI Backend) + +| Method | Path | 说明 | +|--------|------|------| +| GET | `/api/mcp/servers` | 列出合并后 server + 运行时状态 | +| POST | `/api/mcp/servers` | 新增 server(写 project scope) | +| PATCH | `/api/mcp/servers/{name}` | 更新配置 / enabled | +| DELETE | `/api/mcp/servers/{name}` | 删除 project 级条目 | +| POST | `/api/mcp/servers/{name}/reload` | 热重载单个 server | +| POST | `/api/mcp/import` | 导入 Cursor/Claude mcp.json | +| GET | `/api/mcp/status` | 连接健康检查汇总 | + +响应示例: + +```json +{ + "servers": [ + { + "name": "fetch", + "enabled": true, + "status": "degraded", + "tool_count": 1, + "tools": ["fetch---fetch"], + "source": "global", + "last_error": "call_tool timeout: fetch", + "last_failure_at": "2026-06-15T10:23:01Z", + "last_success_at": "2026-06-15T10:20:00Z", + "consecutive_failures": 2, + "failure_history": [ + { + "at": "2026-06-15T10:23:01Z", + "phase": "call_tool", + "tool_name": "fetch", + "message": "Connection closed" + } + ] + } + ] +} +``` + +现有 `webui/backend/config_manager.py` 的 `get_mcp_config` / `update_mcp_config` 应迁移为调用 `MCPConfigManager`,运行时状态从 `MCPRuntime` 读取。 + +### 10.1 WebUI API 迁移方案 + +**现状**:WebUI 使用 `GET/PUT /config/mcp`,读写 `~/.ms_agent/config.json` 中的 `mcp_servers`,无运行时状态、无项目级 `.ms-agent/mcp.json`、无热更新。 + +**目标**:Playground Session 持有 `MCPRuntime`;配置 CRUD 走 `MCPConfigManager`;运行时状态走 `MCPRuntime.list_servers()`。 + +**迁移步骤(兼容过渡)**: + +| 阶段 | 配置 API | 运行时 API | 说明 | +|------|----------|------------|------| +| **A(当前)** | `GET/PUT /config/mcp` | 无 | 旧 WebUI 只管理全局配置 | +| **B(并存)** | 保留 `/config/mcp`,内部委托 `MCPConfigManager` global scope | 新增 `GET /api/mcp/servers` 等 §10 路由 | Session 层注入 `mcp_runtime` 后 UI 可展示 `status` / `failure_history` | +| **C(收敛)** | 废弃 `/config/mcp`,统一 `/api/mcp/*` | 同上 | 项目级 CRUD、import/export、热重载全部走新 API | + +**职责划分**: + +- `MCPConfigManager`:持久化 `enabled`、连接参数(`settings.json` + `mcp.json` + `.ms-agent/mcp.json`) +- `ConfigResolver.resolve_mcp()`:供 Runtime `apply_config` 消费 +- `ConfigResolver.resolve_mcp_all_layers()`:供 UI 列表(含 disabled;**已**遮蔽 agent.yaml `mcp: false` 内置工具) +- `MCPRuntime`:`status` / `failure_history` / 热重载,**不**写盘 + +配置变更链路:`MCPConfigManager` 写盘 → `ConfigResolver.resolve_mcp()` → `MCPRuntime.apply_config()` → `sync_tools()`。 + +--- + +## 11. 与 Skill / Hooks 运行时管理的对比 + +| 维度 | Skill(F7) | MCP(F7) | Hooks(F6) | +|------|------------|-----------|-------------| +| 关闭语义 | 不注入 prompt,`/skill-name` 可手动触发 | `enabled=false` 移出索引,完全不可用 | PreToolUse `deny` 阻断**单次**调用 | +| 运行时层 | `SkillRuntime` | `MCPRuntime` | `HookRuntime` | +| 配置层 | `SkillsConfigManager` | `MCPConfigManager` | `HookRegistry` + 多源 Loader | +| 热更新 | 重新生成 prompt | `sync_tools()` 刷新索引 | 改配置后下次 `build_hook_runtime` | +| ToolManager 集成 | 无专用字段 | `mcp_*` 回调 | 直接持有 `_hook_runtime` | +| 工具名 | N/A | `server---tool` | matcher 同格式,**对 MCP 工具生效** | +| 连接成本 | 无 | stdio/HTTP 长连接 | 子进程脚本,按次触发 | + +Skill / MCP / Hooks 共享 `ConfigResolver` 合并思路,但运行时语义不同,**不应合并为同一个 Runtime 类**。MCP 与 Hooks 在 `single_call_tool` 交汇(§7.4)。 + +--- + +## 12. 实施分期 + +### Phase 1 — P0(SDK 核心,已落地) + +- [x] `MCPConfigManager`:全局/项目 CRUD + `enabled` 持久化 +- [x] `ConfigResolver.resolve_mcp()`:归一化合并(含 §5.3 `normalize_mcp_server_entry` + §5.4 用例) +- [x] `MCPRuntime`:start/stop、`connect_policy`(默认 skip)、`is_callable`、`enable_server` / `disable_server`(软禁用) +- [x] `ToolManager`:`sync_mcp_tools` + `mcp_callable_check` / `mcp_failure_handler`;§7.4 管线步骤 ①⑥ +- [x] `LLMAgent.prepare_tools()`:与 `hook_runtime` 并列注入 MCP 回调(§9.1) +- [x] 单元测试:§5.4 合并、软禁用、connect skip、Hooks 交叉(§14 `test_mcp_tool_pre_tool_use_deny`) +- [ ] **WebUI / Playground Session 接线**(§9.1、§10.1 阶段 B):仍使用旧 `/config/mcp` + +### Phase 2 — P0(MCP 运行时增强,SDK 已落地) + +- [x] `MCPClient.disconnect_server()`(per-server exit stack) +- [x] `reload_server` / `reconnect_server` 硬重连 +- [x] `apply_config` diff 自动增量 +- [x] **运行中失败记录**:`failure_history`、`last_error`、`consecutive_failures`(**不**自动重连) +- [x] 用户手动 `reconnect_server` +- [x] MCP 工具 `server---tool` + PreToolUse deny 回归(`test_mcp_tool_pre_tool_use_deny`) +- [x] 外层 `asyncio.TimeoutError` → `record_failure`(transient 计数) +- [x] `sync_mcp_tools` 的 `list_tools` 失败 → `record_failure` +- [x] `resolve_mcp_all_layers` 内置工具遮蔽 +- [x] `connection_params_for_client` 接入 `MCPRuntime._connect_server` + +### Phase 3 — P1(待做) + +- [x] Session 级 enabled 覆盖(合并逻辑 + `session_override` 参数;**未**接 `session.json` 持久化) +- [ ] 失败事件写入 `sessions//mcp_events.jsonl` +- [ ] Plugin `tools/` → MCP server 注册 +- [ ] 认证 OAuth 跳转(后端处理) +- [ ] WebUI API 迁移阶段 B/C(§10.1) + +--- + +## 13. 风险与对策 + +| 风险 | 对策 | +|------|------| +| 单一 `AsyncExitStack` 无法单 server 断开 | 一期软禁用;二期 per-server stack | +| stdio server 被多 Agent 重复拉起 | 文档明确推荐共享 `mcp_client`;Runtime 层 singleton 可选 | +| agent.yaml 与 mcp.json 同名冲突 | `ConfigResolver` 定义优先级并写测试 | +| 连接失败阻塞 Agent 启动 | **Phase 1 即实现** `connect_policy: skip`(默认);`fail_fast` 供 CLI;失败 server → `status=error`,不进入工具索引 | +| MCP server 运行中挂掉 | 标记 `degraded` + 记录历史;**不**自动重连,由用户点「重连」或新开 session | +| MCP 热更新与并行 PreToolUse 竞态 | `_sync_lock` 仅护 `sync_mcp_tools`;`single_call_tool` 索引快照;degraded 在 PreToolUse 前短路 | +| Hook 误拦 healthy MCP | matcher 显式配置;MCP `degraded` 不进入 PreToolUse | +| env 密钥泄露到 UI | `MCPConfigManager` 导出时脱敏;复用现有 key 泄露修复 | + +--- + +## 14. 测试策略 + +```python +# tests/mcp/test_mcp_runtime.py + +async def test_independent_client_injection(): + """验证用户记忆:MCPClient 可独立于 Agent 创建。""" + +async def test_disable_removes_tools_but_keeps_session(): + """软禁用后 get_tools 不包含该 server。""" + +async def test_config_merge_cases(): + """覆盖 §5.4 全部合并用例。""" + +async def test_sync_mcp_tools_clears_and_is_idempotent(): + """重复 sync 不 duplicate assert;disable 后索引清除。""" + +async def test_connect_skip_policy(): + """一个 server 失败,其余 enabled server 仍可连接并出现在工具列表。""" + +async def test_runtime_mode_a_no_double_connect(): + """模式 A:Runtime.start + mcp_config={} 不触发二次 connect_to_server。""" + +async def test_degraded_hidden_from_llm_tools(): + """degraded:不出现在 LLM get_tools;UI 仍可通过 MCP 状态 API 查询。""" + +async def test_transient_failure_not_immediately_degraded(): + """单次 timeout 保持 connected,工具仍对 LLM 可见。""" + +async def test_transient_failure_threshold_degrades(): + """连续 transient 失败达阈值后 degraded 并从 LLM 索引移除。""" + +async def test_get_tools_per_server_isolation(): + """单 server list_tools 失败不影响其他 server 入索引。""" + +async def test_mcp_tool_pre_tool_use_deny(): + """MCP 工具 fetch---* 可被 PreToolUse deny。""" + +async def test_sync_mcp_tools_during_parallel_hooks(): + """热更新 sync 与并行 single_call_tool 不死锁。""" +``` + +--- + +## 15. 总结 + +| 问题 | 结论 | +|------|------| +| MCP 能否独立初始化? | **能**。`MCPClient` 支持独立构造、`async with`、外部注入 `LLMAgent`/`ToolManager`。 | +| F7 要做什么? | 不是重写 MCP 连接,而是新增 **MCPConfigManager + MCPRuntime + ConfigResolver 集成 + ToolManager 同步**。 | +| 与 Hooks 关系? | **分治**;在 `single_call_tool` 协作(§7.4);`degraded` 检查在 PreToolUse 之前。 | +| 一期最小闭环? | 配置分层 + `connect_policy=skip` + 软禁用 + `sync_mcp_tools` + `prepare_tools` 与 Hooks 并列接线。 | +| 与现有代码关系? | `MCPClient` / 外部注入 / CLI 兼容不变;叠加于已落地的 Hooks + Permission 管线。 | + +--- + +## 附录 A:现有代码索引 + +| 文件 | 关键符号 | +|------|----------| +| `ms_agent/tools/mcp_client.py` | `MCPClient`, `add_mcp_config`, `connect`, `cleanup` | +| `ms_agent/tools/tool_manager.py` | `hook_runtime`, `single_call_tool`, `_managed_client`, `reindex_tool` | +| `ms_agent/agent/llm_agent.py` | `prepare_tools`, `_hook_runtime`, `mcp_client` / `mcp_runtime` kwargs | +| `ms_agent/hooks/runtime.py` | `HookRuntime`, `run_pre_tool_use`, `run_post_tool_use` | +| `ms_agent/hooks/factory.py` | `build_hook_runtime` | +| `ms_agent/hooks/permission_resolve.py` | `resolve_hook_permission_decision` | +| `ms_agent/utils/pattern_matcher.py` | MCP / Permission / Hooks 共用 matcher | +| `ms_agent/config/config.py` | `convert_mcp_servers_to_json` | +| `docs/zh/design/hooks-design.md` | F6 完整方案;§7.4 管线对齐其 §10.2 | +| `tests/tools/test_mcp_client.py` | `test_outside_init`, `test_add_config` | + +## 附录 B:与 playground F7 原文对照 + +原 F7 预期目标: + +- [x] 设计:`SkillRuntime` — 本文档第 11 节已区分,Skill 另文 +- [x] 设计:`MCPConfigManager` — 本文档第 8 节 +- [x] 设计:MCP 运行时 enable/disable — 本文档第 6 节(一期软禁用,二期硬断开) +- [x] 与 Hooks(F6)协作 — 本文档 §7.4、§9.1、§11;Hooks 管线已落地,MCP 叠加而非重写 diff --git a/docs/zh/design/permission-design.md b/docs/zh/design/permission-design.md new file mode 100644 index 000000000..4b7bfd4ff --- /dev/null +++ b/docs/zh/design/permission-design.md @@ -0,0 +1,1584 @@ +# 权限管控系统设计文档 + +> 参考 Claude Code 权限系统设计(`permissions.ts` + `bashPermissions.ts` + `pathValidation.ts` + `sedValidation.ts`) +> +> 本文档是权限管控模块的**完整可执行方案**,涵盖双层架构、外层用户意图管控、内层安全底线、Shell 路径级校验、前后端交互协议、现有代码迁移等全部内容。 + +--- + +## 目录 + +- [1. 现状分析](#1-现状分析) +- [2. 双层权限架构](#2-双层权限架构) +- [3. 外层:PermissionEnforcer(用户意图层)](#3-外层permissionenforcer用户意图层) +- [4. 内层:SafetyGuard(安全底线层)](#4-内层safetyguard安全底线层) +- [5. Shell 命令路径级校验](#5-shell-命令路径级校验) +- [6. 命令注册表:PATH_EXTRACTORS](#6-命令注册表path_extractors) +- [7. 路径校验流程](#7-路径校验流程) +- [8. 危险路径硬拦截](#8-危险路径硬拦截) +- [9. Safe Wrapper 剥离](#9-safe-wrapper-剥离) +- [10. 输出重定向与进程替换校验](#10-输出重定向与进程替换校验) +- [11. 共享基础设施](#11-共享基础设施) +- [12. 集成点与代码变更](#12-集成点与代码变更) +- [13. 现有代码迁移:WorkspacePolicyKernel](#13-现有代码迁移workspacepolicykernel) +- [14. YAML 配置格式(统一)](#14-yaml-配置格式统一) +- [15. 文件结构](#15-文件结构) +- [16. 与 Claude Code 的对比](#16-与-claude-code-的对比) +- [17. 验证方式](#17-验证方式) +- [18. 实现审查:已知问题与待办](#18-实现审查已知问题与待办) +- [附录 A:parse_pattern_command 通用实现](#附录-aparse_pattern_command-通用实现) +- [附录 B:完整命令操作类型对照表](#附录-b完整命令操作类型对照表) + +--- + +## 1. 现状分析 + +### 1.1 当前权限现状 + +- **统一拦截已就位**:`ToolManager.single_call_tool()` 中注入双层权限检查(SafetyGuard + PermissionEnforcer) +- **WorkspacePolicyKernel 已迁移删除**:安全职责归并到 `SafetyGuard`,功能职责(workspace cwd、deny_globs)提取为轻量 `WorkspaceContext`(`ms_agent/utils/workspace_context.py`) +- **覆盖完整**:SafetyGuard 校验 `shell_executor`、`read_file`、`write_file`、`edit_file`、`grep`、`glob` 六类工具的路径安全 + +### 1.2 原 WorkspacePolicyKernel 能力迁移对照 + +| 原调用者 | 原方法 | 迁移去向 | 状态 | +|--------|-----------|------|------| +| `filesystem_tool.py` | `resolve_under_roots()` | SafetyGuard + `validate_path()` | ✅ 已迁移 | +| `filesystem_tool.py` | `deny_globs` / `workspace_root` | `WorkspaceContext.deny_globs` / `.root` | ✅ 已迁移 | +| `filesystem_tool.py` | `path_is_allowed()` | SafetyGuard 在 tool_manager 层统一检查 | ✅ 已删除 | +| `local_code_executor.py` | `assert_shell_command_allowed()` | SafetyGuard → `ShellPathValidator.check()` | ✅ 已迁移 | +| `local_code_executor.py` | `workspace_root` (subprocess cwd) | `WorkspaceContext.root` | ✅ 已迁移 | +| 两处 | `_shell_looks_network()` | `PermissionConfig._DEFAULT_BLACKLIST` | ✅ 已迁移 | +| — | `iter_files_under()` | 已删除(无外部调用者) | ✅ 已删除 | + +### 1.3 设计目标 + +1. **统一入口**:所有工具调用在 `ToolManager.single_call_tool()` 中经过统一权限检查 +2. **双层分离**:用户可选择放行的"意图层" + 不可绕过的"安全底线层" +3. **交互式确认**:支持 CLI 和 Web 两种场景下的用户确认流程 +4. **精细化管控**:shell 命令做到参数级路径提取和操作类型区分 +5. **消除重复**:`WorkspacePolicyKernel` 的能力迁移到新体系后删除,不保留两套代码 + +--- + +## 2. 双层权限架构 + +``` +┌─────────────────────────────────────────────────────────┐ +│ PermissionEnforcer(外层 · 用户意图层) │ +│ 位置:ToolManager.single_call_tool() 入口 │ +│ 职责:用户是否允许这个操作? │ +│ 特点:可配置、可编辑、可被用户 allow_always 覆盖 │ +│ 规则来源:YAML whitelist/blacklist + PermissionMemory │ +├─────────────────────────────────────────────────────────┤ +│ SafetyGuard(内层 · 安全底线层) │ +│ 位置:ToolManager.single_call_tool(),权限检查最前面 │ +│ 职责:无论用户怎么选,这些操作绝对不允许 │ +│ 特点:不可被用户绕过,即使 mode=auto 也生效 │ +│ 规则来源:YAML safety_rules + 硬编码兜底 + 路径级校验 │ +│ 前身:WorkspacePolicyKernel(重构后纳入统一体系) │ +└─────────────────────────────────────────────────────────┘ +``` + +**关键区别:** +- 外层 `PermissionEnforcer`:用户选择 `allow_always` 后,后续匹配的调用不再询问 +- 内层 `SafetyGuard`:`rm -rf /`、访问 `/etc/passwd` 等操作,即使用户加了 `code_executor---*` 白名单也会被拦截 + +**共享基础设施:** +- 两层共用 `PermissionMatcher` 的通配符匹配逻辑,规则格式统一为 `server---tool:content_pattern` +- 两层的规则均从 YAML 统一加载,但标记不同的 `layer` 属性 +- `WorkspacePolicyKernel` 的现有逻辑(`deny_globs`、`assert_shell_command_allowed`、`resolve_under_roots`)重构为 `SafetyGuard` 的一部分,规则格式对齐 + +**判定流程:** +``` +工具调用进入 ToolManager.single_call_tool() + │ + ├─ 1. SafetyGuard.check() ← 内层先行,不可绕过 + │ ├─ 通用安全规则匹配(YAML safety_rules) + │ ├─ 工具特化检查: + │ │ ├─ code_executor---shell_executor → ShellPathValidator.check() + │ │ ├─ file_system---write_file/edit_file → validate_path(..., 'write') + │ │ ├─ file_system---read_file → validate_path(..., 'read') + │ │ └─ file_system---grep/glob → validate_path(..., 'read') + │ ├─ deny → 直接拒绝 + │ └─ ask → resolve_ask() 按模式解析(见下文) + │ + ├─ 1.5. resolve_ask() ← ask 模式解析层 + │ ├─ auto 模式 → 按 category 分类决策(allow/deny) + │ ├─ strict 模式 → 全部 deny + │ └─ interactive 模式 → 保持 ask,交给 enforcer/handler + │ + ├─ 2. PermissionEnforcer.check() ← 外层用户意图 + │ ├─ mode in (auto, strict) → allow(SafetyGuard 已做安全保障) + │ ├─ blacklist match → deny + │ ├─ whitelist match → allow + │ ├─ session memory match → allow + │ ├─ persistent memory match → allow + │ └─ 其余 → handler.ask()(询问用户) + │ + └─ 3. tool_ins.call_tool() ← 通过两层检查后执行 +``` + +**三种模式说明:** + +| 模式 | SafetyGuard `ask` 处理 | Enforcer 行为 | 适用场景 | +|------|----------------------|--------------|----------| +| `auto` | 按 category 分类:input 替换→allow, output 替换/解析失败/cd+write/变量展开→deny, 读超范围→看 read_policy | 直接 allow | 容器/沙箱/无人值守 | +| `strict` | 全部 → deny | 直接 allow | 高安全要求、无沙箱、无人值守 | +| `interactive` | 保持 ask → 交给 handler | 完整流程(blacklist→whitelist→memory→handler.ask) | 有人值守(CLI/Web/TUI) | + +--- + +## 3. 外层:PermissionEnforcer(用户意图层) + +### 3.1 PermissionConfig (`config.py`) + +从 agent YAML 或 settings 中解析 `permission` 段: + +```python +@dataclass(frozen=True) +class PermissionConfig: + mode: Literal['auto', 'strict', 'interactive'] # 兼容旧名 restricted → interactive + whitelist: tuple[str, ...] # 允许规则 + blacklist: tuple[str, ...] # 拒绝规则 + safety: SafetyConfig # 安全底线配置(传给 SafetyGuard) +``` + +- 白名单/黑名单格式:`server_name---tool_name`,支持 `*` 通配符 +- shell 命令级:支持 `code_executor---shell_executor:command_pattern` 格式 +- 示例:`file_system---read_*`、`web_search---*`、`code_executor---shell_executor:pip *` + +### 3.2 PermissionEnforcer (`enforcer.py`) + +```python +@dataclass(frozen=True) +class PermissionDecision: + action: Literal['allow', 'deny', 'ask'] + reason: str + updated_args: dict | None = None # action == 'allow' 且用户修改了参数时 + +class PermissionEnforcer: + def __init__(self, config: PermissionConfig, handler: PermissionHandler, memory: PermissionMemory) + async def check(self, tool_name: str, tool_args: dict) -> PermissionDecision +``` + +判定流程(参考 Claude Code 的 `hasPermissionsToUseToolInner` 多步管线): +1. `mode in ('auto', 'strict')` → 直接 `allow`(SafetyGuard + ask_resolver 已保障安全) +2. blacklist match → `deny`(不可绕过,参考 Claude Code 的 `alwaysDenyRules`) +3. whitelist match → `allow`(参考 `alwaysAllowRules`) +4. session memory match → `allow`(会话内 `allow_session` 记录) +5. persistent memory match → `allow`(`PermissionMemory` 持久化规则) +6. 其余 → 调用 `handler.ask()`,传入自动生成的 suggestions +7. 用户选择 `modify` 时,返回 `updated_args` 供 `ToolManager` 使用修改后的参数执行 + +### 3.3 PermissionHandler (`handler.py`) + +用户在 `interactive` 模式下遇到非白名单工具时,提供 5 种细粒度选择: + +```python +class PermissionAction(str, Enum): + ALLOW_ONCE = 'allow_once' # 仅允许本次调用 + ALLOW_SESSION = 'allow_session' # 本次会话中允许所有同类调用 + ALLOW_ALWAYS = 'allow_always' # 永久加入白名单(持久化) + DENY = 'deny' # 拒绝本次调用 + MODIFY = 'modify' # 用户修改工具参数后执行 + +@dataclass(frozen=True) +class PermissionResponse: + action: PermissionAction + updated_args: dict | None = None # action == MODIFY 时,用户修改后的参数 + pattern: str | None = None # action == ALLOW_ALWAYS 时,用户确认/编辑的通配符模式 + feedback: str | None = None # 用户附加的反馈信息 + +class PermissionHandler(Protocol): + async def ask(self, tool_name: str, tool_args: dict, + context: str, suggestions: list[str] | None = None) -> PermissionResponse +``` + +#### 3.3.1 AutoPermissionHandler + +直接返回 `action=ALLOW_ONCE`(auto 模式下不会被调用,但作为兜底)。 + +#### 3.3.2 CLIPermissionHandler + +交互式 CLI 菜单,对标 Claude Code 的 Select 组件: + +``` +╭─ Permission Required ──────────────────────────╮ +│ Tool: code_executor---shell_executor │ +│ Args: {"command": "pip install requests"} │ +│ │ +│ > [y] 允许本次 │ +│ [s] 本次会话中允许所有 shell_executor 调用 │ +│ [a] 以后都允许 code_executor---shell_executor │ +│ [e] 编辑命令后执行 │ +│ [n] 拒绝 │ +╰─────────────────────────────────────────────────╯ +``` + +**`allow_always` 可编辑模式**(参考 Claude Code 的 `yes-prefix-edited` 选项): + +用户选择 `a` 后,系统根据当前工具名和参数自动生成一个建议模式,用户可以编辑: + +``` +╭─ 添加白名单规则 ─────────────────────────────────╮ +│ 以后不再询问匹配以下模式的调用: │ +│ │ +│ > code_executor---shell_executor:ls *█ │ +│ │ +│ 回车确认 / 编辑后回车 / Esc 取消 │ +╰───────────────────────────────────────────────────╯ +``` + +建议模式生成示例: +- `shell_executor` + `{"command": "ls -la"}` → 建议 `code_executor---shell_executor:ls *` +- `shell_executor` + `{"command": "pip install requests"}` → 建议 `code_executor---shell_executor:pip *` +- `read_file` + `{"path": "/src/main.py"}` → 建议 `file_system---read_file` +- `web_search` + 任意参数 → 建议 `web_search---*` + +用户可以自由编辑模式来控制粒度: +- 更宽松:`code_executor---shell_executor` → 所有 shell 命令都放行 +- 更精确:`code_executor---shell_executor:pip install *` → 仅 pip install 放行 + +#### 3.3.3 WebPermissionHandler(前后端交互协议) + +后端 agent 在工具调用过程中需要暂停执行、等待前端用户确认,采用 **Future 挂起 + 事件推送 + REST 回调** 模式: + +```python +class WebPermissionHandler(PermissionHandler): + def __init__(self, event_emitter: EventEmitter, timeout: float = 120.0): + self._pending: dict[str, asyncio.Future[PermissionResponse]] = {} + self._event_emitter = event_emitter + self._timeout = timeout + + async def ask(self, tool_name: str, tool_args: dict, + context: str, suggestions: list[str] | None = None) -> PermissionResponse: + request_id = uuid4().hex + future: asyncio.Future[PermissionResponse] = asyncio.get_event_loop().create_future() + self._pending[request_id] = future + + # 1. 向流式通道推送权限请求事件 + self._event_emitter.emit(PermissionRequestEvent( + request_id=request_id, + tool_name=tool_name, + tool_args=tool_args, + suggestions=suggestions or [], + )) + + # 2. 挂起等待前端回传决策 + try: + return await asyncio.wait_for(future, timeout=self._timeout) + except asyncio.TimeoutError: + return PermissionResponse(action=PermissionAction.DENY, + feedback='Permission request timed out') + finally: + self._pending.pop(request_id, None) + + def resolve(self, request_id: str, response: PermissionResponse) -> None: + future = self._pending.get(request_id) + if future and not future.done(): + future.set_result(response) +``` + +**前后端交互流程:** + +``` +Backend (agent) Frontend + │ │ + │──── SSE/Stream: permission_request ───>│ ← 推送权限请求事件 + │ { │ + │ type: "permission_request", │ + │ request_id: "abc123", │ + │ tool_name: "code_executor--- │ + │ shell_executor", │ + │ tool_args: {command: "ls -la"}, │ + │ suggestions: [ │ + │ "code_executor--- │ + │ shell_executor:ls *" │ + │ ], │ + │ options: ["allow_once", │ + │ "allow_session", │ + │ "allow_always", "modify", │ + │ "deny"] │ + │ } │ + │ │ + │ (后端 await future,执行暂停) │ (前端渲染权限弹窗) + │ │ + │<─── POST /permission/respond ──────────│ ← 用户选择后回传 + │ { │ + │ request_id: "abc123", │ + │ action: "allow_always", │ + │ pattern: "code_executor--- │ + │ shell_executor:ls *" │ + │ } │ + │ │ + │ (future.set_result → await 返回) │ + │ (工具调用继续执行) │ +``` + +**需要预留的接口:** + +1. **流式事件类型** `permission_request`:在现有 agent 流式输出协议中新增此事件类型,前端需识别并渲染权限 UI +2. **REST 接口** `POST /permission/respond`:接收前端回传的用户决策,参数为 `{request_id, action, pattern?, updated_args?, feedback?}` +3. **`EventEmitter` 抽象**:`WebPermissionHandler` 通过此接口推送事件,不直接依赖具体的流式协议(SSE/WebSocket/StreamableHTTP),由上层注入具体实现 +4. **超时机制**:默认 120 秒未回复自动 deny,可配置 + +### 3.4 建议模式自动生成 (`suggestions.py`) + +```python +def generate_suggestions(tool_name: str, tool_args: dict) -> list[str]: + """根据工具名和参数自动生成通配符建议模式""" +``` + +- 参考 Claude Code 的 `PermissionUpdate.suggestions` 机制 +- 例如 `code_executor---shell_executor` + `{"command": "npm run build"}` → 建议模式 `code_executor---shell_executor:npm *` +- 建议展示在 `allow_always` 选项中,用户可编辑后确认 + +### 3.5 PermissionMemory (`memory.py`) + +持久化用户的 "always allow" 决策: + +```python +@dataclass(frozen=True) +class MemoryEntry: + pattern: str # 通配符模式 + scope: Literal['project', 'global'] + source: Literal['user', 'plugin', 'hook'] # 规则来源 + created_at: str +``` + +- 项目级 `.ms_agent/permission_memory.json` + 全局级 `~/.ms_agent/permission_memory.json` +- 合并时项目级优先 +- 提供 `add()` / `matches()` / `revoke()` / `list_all()` 接口 +- 会话级记忆(`allow_session`)仅存内存,不持久化 + +--- + +## 4. 内层:SafetyGuard(安全底线层) + +### 4.1 职责 + +无论用户配置了什么模式、加了什么白名单,以下操作**绝对不允许**自动放行: +- `rm -rf /`、`mkfs`、`dd if=` 等破坏性命令 +- 写入 `/etc/`、`/sys/`、`~/.ssh/` 等系统敏感路径 +- shell 命令操作的文件路径超出工作目录范围(write/create 操作) + +### 4.2 SafetyGuard 类设计 + +```python +@dataclass(frozen=True) +class SafetyConfig: + """内层安全配置(从 YAML permission.safety_rules 解析)""" + patterns: tuple[str, ...] # 通用工具级拦截规则 + sensitive_paths: tuple[str, ...] # 写入敏感路径拦截 + dangerous_removal_paths: tuple[str, ...] # rm/rmdir 危险路径 + read_policy: Literal['loose', 'strict'] = 'loose' # 读超范围时的兜底策略 + max_command_chars: int = 8192 + allowed_directories: tuple[str, ...] = () # 完全访问(读+写+create) + read_only_directories: tuple[str, ...] = () # 只读访问(读允许,写/create 拒绝) + +@dataclass(frozen=True) +class SafetyDecision: + action: Literal['allow', 'deny', 'ask'] + reason: str + category: str = '' # ask 时标记原因类别,供 resolve_ask 分类决策 + +class SafetyGuard: + def __init__(self, config: SafetyConfig, allowed_dirs: list[str], + read_only_dirs: list[str] = (), workspace_root: str | None = None): + self._config = config # YAML 中加载的通用安全规则 + self._allowed_dirs = list(allowed_dirs) + self._read_only_dirs = list(read_only_dirs) # 只读目录(读允许,写/create 拒绝) + self._workspace_root = workspace_root # 相对路径解析基目录(与工具端 output_dir 统一) + self._shell_validator = ShellPathValidator( # shell 命令专用校验器 + allowed_dirs=self._allowed_dirs, + safety_config=PathSafetyConfig( + max_command_chars=config.max_command_chars, + allowed_directories=tuple(self._allowed_dirs), + read_only_directories=tuple(self._read_only_dirs), + ), + ) + self._matcher = PermissionMatcher() # 共用的通配符匹配 + + def check(self, tool_name: str, tool_args: dict) -> SafetyDecision: + # 1. 通用安全规则匹配(YAML safety_rules 中的 server---tool:pattern) + for rule in self._config.patterns: + if self._matcher.match_with_content(rule, tool_name, tool_args): + return SafetyDecision(action='deny', reason=f'Blocked by safety rule: {rule}') + + # 2. 工具特化检查 + if tool_name.endswith('---shell_executor'): + return self._shell_validator.check(tool_args.get('command', '')) + elif tool_name.endswith('---write_file') or tool_name.endswith('---edit_file'): + return self._check_file_path(tool_args.get('path', ''), 'write') + elif tool_name.endswith('---read_file'): + return self._check_file_path(tool_args.get('path', ''), 'read') + elif tool_name.endswith('---grep') or tool_name.endswith('---glob'): + return self._check_file_path(tool_args.get('path', '.'), 'read') + + # 3. 未匹配 → 放行 + return SafetyDecision(action='allow', reason='No safety rule matched') + + def _check_file_path(self, path: str, op_type: str) -> SafetyDecision: + cwd = self._workspace_root or os.getcwd() + result = validate_path(path, cwd, self._allowed_dirs, op_type, + read_only_dirs=self._read_only_dirs) + if not result.allowed: + return SafetyDecision(action=result.action, reason=result.reason, category=result.category) + return SafetyDecision(action='allow', reason='Path validation passed') +``` + +### 4.3 ask 分类解析:resolve_ask (`ask_resolver.py`) + +SafetyGuard 返回 `ask` 时携带 `category` 字段,`resolve_ask()` 根据当前模式决定最终动作: + +```python +def resolve_ask(decision: SafetyDecision, mode: str, read_policy: str = 'loose') -> SafetyDecision: + """auto: 按 category 分类; strict: 全 deny; interactive: 保持 ask""" +``` + +**auto 模式策略表:** + +| category | 解析为 | 理由 | +|----------|--------|------| +| `process_input_sub` | allow | `<(...)` 是读操作,风险低 | +| `process_output_sub` | deny | `>(...)` 可能绕过路径校验写入 | +| `parse_failure` | deny | 无法验证即不信任 | +| `cd_write_compound` | deny | cd 改变 cwd,静态路径验证不可靠 | +| `command_validator` | deny | 验证器明确发现可疑模式 | +| `shell_expansion` | deny | `$VAR` 路径无法静态解析 | +| `read_outside_dirs` | 由 `read_policy` 决定 | `loose`→allow, `strict`→deny | + +**`read_outside_dirs` 触发条件**:读取路径不在 `allowed_dirs` 且不在 `read_only_dirs` 范围内时触发。如果路径在 `read_only_dirs` 中,`validate_path` 直接返回 allow,不会产生 `ask`。`read_policy` 是对"两个目录列表都未覆盖的读取"的兜底策略。 + +设计意图:auto 模式下**永不弹出交互**,所有不确定性在 SafetyGuard 返回后立即解析为确定性决策。 + +### 4.4 SafetyGuard 与 WorkspacePolicyKernel 的关系(已完成迁移) + +`SafetyGuard` 已**完全替代** `WorkspacePolicyKernel`,后者已从代码库中删除。WPK 的职责被拆分为两部分: + +**安全职责 → SafetyGuard(在 tool_manager 层统一拦截):** + +| 原 WPK 能力 | 迁移去向 | 说明 | +|-------------|---------|------| +| `resolve_under_roots()` | `SafetyGuard._check_file_path()` → `validate_path()` | 支持更多校验步骤 | +| `path_is_allowed()` | `SafetyGuard._check_file_path()` → `validate_path()` | 合并 | +| `assert_shell_command_allowed()` | `SafetyGuard.check()` → `ShellPathValidator.check()` | 精细化为 36 命令注册表 | +| `_shell_looks_network()` | `PermissionConfig._DEFAULT_BLACKLIST` | 升级为可配置的 enforcer 层策略 | +| `_shell_looks_mutating_or_network()` | `ShellPathValidator` 操作类型分类 | 精细化 | + +**功能职责 → WorkspaceContext(轻量 frozen dataclass):** + +| 原 WPK 能力 | 迁移去向 | 说明 | +|-------------|---------|------| +| `workspace_root` (subprocess cwd) | `WorkspaceContext.root` | 仅提供工作目录,不含安全检查 | +| `deny_globs` (文件遍历过滤) | `WorkspaceContext.deny_globs` | 仅提供过滤模式,不含安全检查 | +| `iter_files_under()` | 已删除 | 确认无外部调用者 | + +--- + +## 5. Shell 命令路径级校验 + +### 5.1 问题 + +`shell_executor` 工具允许 agent 执行任意 shell 命令。仅靠命令前缀匹配(如 `rm *`)无法覆盖以下风险: + +- **路径越权**:`cat /etc/shadow`、`rm -rf /` — 命令本身合法,但操作的路径超出工作目录 +- **路径伪装**:`rm -- -/../.claude/settings.json` — 利用 `-` 开头路径绕过 flag 过滤 +- **包装器绕过**:`timeout 10 rm -rf /` — wrapper 命令遮蔽真实操作命令 +- **输出重定向**:`echo "malicious" > /etc/passwd` — 命令是 `echo`,但写入了敏感路径 +- **复合命令**:`cd .claude/ && mv test.txt settings.json` — 通过 `cd` 改变工作目录后操作 +- **shell 展开**:`rm $HOME/.ssh/*` — 变量展开导致验证时路径和执行时路径不一致 + +### 5.2 ShellPathValidator 架构 + +``` +shell 命令字符串进入 ShellPathValidator.check() + │ + ├─ 1. 进程替换检查(区分 input/output) + │ >(cmd) → ask(category='process_output_sub') + │ <(cmd) → ask(category='process_input_sub') + │ + ├─ 2. 复合命令拆分 + │ && / || / ; / | → 拆分为独立子命令 + │ 记录是否包含 cd(影响后续路径解析) + │ + ├─ 3. 输出重定向校验(每个子命令,在 wrapper 剥离前) + │ > / >> / &> / &>> → 提取目标路径,校验是否在允许范围 + │ /dev/null 始终放行 + │ 变量展开 ($VAR) 在重定向目标中 → deny + │ + ├─ 4. Safe Wrapper 剥离 + │ timeout / nice / nohup / time / stdbuf / env → 去掉包装 + │ + ├─ 5. 命令路径校验(核心) + │ ├─ 识别 base command(第一个 token) + │ ├─ PATH_EXTRACTORS[command](args) → 提取路径列表 + │ ├─ 危险路径硬拦截(rm -rf / 等) + │ └─ validate_path(path, allowed_dirs, op_type, read_only_dirs) → 逐一校验 + │ + └─ 6. 返回决策 + allow(非路径命令/校验通过)/ ask(需确认)/ deny(硬拦截) +``` + +```python +@dataclass(frozen=True) +class PathSafetyConfig: + """ShellPathValidator 的配置,由 SafetyGuard 从 SafetyConfig 构建后注入。""" + max_command_chars: int = 8192 + allowed_directories: tuple[str, ...] = () + read_only_directories: tuple[str, ...] = () # 只读目录(读允许,写/create 拒绝) + +class ShellPathValidator: + """shell_executor 工具的路径级安全校验""" + + def __init__(self, allowed_dirs: list[str], safety_config: PathSafetyConfig): + self._allowed_dirs = allowed_dirs + self._config = safety_config + self._read_only_dirs = list(safety_config.read_only_directories) + self._extractors = build_extractor_registry() # 36 个命令的提取器 + + def check(self, command: str) -> SafetyDecision: + # 1. 进程替换检查 + # 2. 拆分复合命令 + # 3. 逐子命令:重定向校验 → 剥离 wrapper → 提取路径 → 校验路径 + # 4. cd + write/create 复合检测 + ... +``` + +`PathSafetyConfig` 是 `SafetyConfig`(YAML 级)与 `ShellPathValidator`(运行时)之间的桥接类型。`SafetyGuard.__init__` 从 `SafetyConfig` 提取字段构建此对象,避免 `ShellPathValidator` 直接依赖上层配置结构。 + +--- + +## 6. 命令注册表:PATH_EXTRACTORS + +### 6.1 设计原则 + +- **每个命令一个提取器**:不存在"通用提取",每个命令按自身参数语法提取路径 +- **安全优先**:未注册的命令不做路径校验(passthrough),由外层权限管控覆盖 +- **`--` 分隔符感知**:POSIX 标准中 `--` 表示"选项结束",之后所有参数均为位置参数,即使以 `-` 开头 + +注册表条目类型: + +```python +CommandExtractor = Callable[[list[str]], list[str]] +CommandValidator = Callable[[list[str]], str | None] + +@dataclass(frozen=True) +class ExtractorEntry: + extractor: CommandExtractor # 从命令参数中提取路径列表 + op_type: Literal['read', 'write', 'create'] # 操作类型,决定路径校验策略 + description: str # 人类可读描述(用于错误消息) + command_validator: CommandValidator | None = None # 可选的命令级校验器(如 mv/cp) +``` + +`build_extractor_registry()` 构建完整的 36 条命令映射 `dict[str, ExtractorEntry]`,`ShellPathValidator` 在初始化时调用一次并缓存。 + +### 6.2 路径提取策略分类 + +根据命令的参数解析方式,将 36 个命令分为 **5 类提取策略**: + +#### 策略 A:过滤 flags 取剩余参数(`filter_out_flags`) + +最常见的模式。跳过所有以 `-` 开头的参数(flags),将剩余视为路径。正确处理 `--` 分隔符。 + +```python +def filter_out_flags(args: list[str]) -> list[str]: + result = [] + after_double_dash = False + for arg in args: + if after_double_dash: + result.append(arg) + elif arg == '--': + after_double_dash = True + elif not arg.startswith('-'): + result.append(arg) + return result +``` + +安全关键:`rm -- -/../.claude/settings.json` 中 `-/../...` 以 `-` 开头,朴素过滤会丢弃它,但 `--` 之后应当保留。 + +#### 策略 B:模式命令解析(`parse_pattern_command`) + +用于 grep/rg 类命令,参数格式为 `command [flags] pattern [files...]`。 +第一个非 flag 参数是 pattern(跳过),后续是文件路径。如果通过 `-e`/`-f` 显式指定了 pattern,则所有非 flag 参数都是路径。 + +#### 策略 C:特殊参数跳过 + +用于 sed/jq 等命令,需要跳过"表达式"参数(非路径),仅提取文件参数。 + +#### 策略 D:搜索起点收集 + +用于 find 命令,收集位于 flags 之前的参数作为搜索起点。 + +#### 策略 E:子命令分发 + +用于 git 等有子命令体系的命令,根据子命令决定是否需要路径校验。 + +### 6.3 完整命令注册表 + +#### `cd` — 切换目录 | `read` | 特殊处理 + +- 无参数 → `[home_dir]` +- 有参数 → 所有参数拼接为一个路径 +- 安全考量:`cd` 本身是 read,但在复合命令中影响后续命令的工作目录(详见第 10.3 节) + +#### `ls` — 列出文件 | `read` | A + 默认值 + +- `filter_out_flags(args)`,无路径时默认 `['.']` + +#### `find` — 搜索文件 | `read` | D(搜索起点收集) + +- 跳过全局选项 `-H`/`-L`/`-P` +- 收集首个非全局 flag 之前的位置参数作为搜索起点 +- 某些 flag 值也是路径:`-newer`、`-anewer`、`-cnewer`、`-mnewer`、`-samefile`、`-path`、`-wholename`、`-ilname`、`-lname`、`-ipath`、`-iwholename` + `-newer[acmBt][acmtB]` 正则 +- `--` 之后所有参数强制为路径,无路径时默认 `['.']` + +```python +def extract_find(args): + paths = [] + path_flags = {'-newer', '-anewer', '-cnewer', '-mnewer', '-samefile', + '-path', '-wholename', '-ilname', '-lname', '-ipath', '-iwholename'} + newer_pattern = re.compile(r'^-newer[acmBt][acmtB]$') + found_non_global_flag = False + after_double_dash = False + + i = 0 + while i < len(args): + arg = args[i] + if after_double_dash: + paths.append(arg); i += 1; continue + if arg == '--': + after_double_dash = True; i += 1; continue + if arg.startswith('-'): + if arg in ('-H', '-L', '-P'): + i += 1; continue + found_non_global_flag = True + if arg in path_flags or newer_pattern.match(arg): + if i + 1 < len(args): + paths.append(args[i + 1]); i += 1 + i += 1; continue + if not found_non_global_flag: + paths.append(arg) + i += 1 + return paths if paths else ['.'] +``` + +#### 策略 A 命令组(27 个)| `filter_out_flags(args)` + +| 命令 | 操作类型 | 附加校验 | 描述 | +|------|---------|---------|------| +| `mkdir` | create | - | create directories in | +| `touch` | create | - | create or modify files in | +| `rm` | write | 危险删除路径检查 | remove files from | +| `rmdir` | write | 危险删除路径检查 | remove directories from | +| `mv` | write | 命令校验器:拒绝所有带 flag 的调用(`--target-directory=PATH` 绕过) | move files to/from | +| `cp` | write | 命令校验器:同 mv | copy files to/from | +| `cat` | read | - | concatenate files from | +| `head` | read | - | read the beginning of files from | +| `tail` | read | - | read the end of files from | +| `sort` | read | - | sort contents of files from | +| `uniq` | read | - | filter duplicate lines from files in | +| `wc` | read | - | count lines/words/bytes in files from | +| `cut` | read | - | extract columns from files in | +| `paste` | read | - | merge files from | +| `column` | read | - | format files from | +| `file` | read | - | examine file types in | +| `stat` | read | - | read file stats from | +| `diff` | read | - | compare files from | +| `awk` | read | - | process text from files in | +| `strings` | read | - | extract strings from files in | +| `hexdump` | read | - | display hex dump of files from | +| `od` | read | - | display octal dump of files from | +| `base64` | read | - | encode/decode files from | +| `nl` | read | - | number lines in files from | +| `sha256sum` | read | - | compute SHA-256 checksums for files in | +| `sha1sum` | read | - | compute SHA-1 checksums for files in | +| `md5sum` | read | - | compute MD5 checksums for files in | + +#### `tr` — 字符转换 | `read` | 特殊处理 + +- 跳过 1-2 个字符集参数(SET1、SET2),`-d`/`--delete` 时仅 SET1 +- 剩余为文件路径 + +```python +def extract_tr(args): + has_delete = any(a == '-d' or a == '--delete' or + (a.startswith('-') and 'd' in a) for a in args) + non_flags = filter_out_flags(args) + return non_flags[1 if has_delete else 2:] +``` + +#### `grep` — 文本搜索 | `read` | B(模式命令解析) + +带值 flags:`-e`, `--regexp`, `-f`, `--file`, `--exclude`, `--include`, `--exclude-dir`, `--include-dir`, `-m`, `--max-count`, `-A`/`-B`/`-C` + 长形式 + +特殊:`-r`/`-R`/`--recursive` 且无路径 → `['.']` + +#### `rg` (ripgrep) — 文本搜索 | `read` | B(模式命令解析) + +带值 flags:`-e`, `--regexp`, `-f`, `--file`, `-t`, `--type`, `-T`, `--type-not`, `-g`, `--glob`, `-m`, `--max-count`, `--max-depth`, `-r`, `--replace`, `-A`/`-B`/`-C` + 长形式 + +默认路径:`['.']` + +#### `sed` — 流编辑器 | `write`(可降级为 `read`) | C(特殊参数跳过) + +提取逻辑: +1. `-f`/`--file` 值 → 脚本文件路径,加入路径列表 +2. `-e`/`--expression` 值 → 表达式,跳过 +3. 第一个非 flag 参数 → 表达式(如未通过 `-e`/`-f` 指定),跳过 +4. 之后 → 文件路径 + +```python +def extract_sed(args): + paths, skip_next, script_found, after_dd = [], False, False, False + for i, arg in enumerate(args): + if skip_next: skip_next = False; continue + if not after_dd and arg == '--': after_dd = True; continue + if not after_dd and arg.startswith('-'): + if arg in ('-f', '--file'): + if i + 1 < len(args): paths.append(args[i + 1]); skip_next = True + script_found = True + elif arg in ('-e', '--expression'): + skip_next = True; script_found = True + elif 'e' in arg or 'f' in arg: script_found = True + continue + if not script_found: script_found = True; continue + paths.append(arg) + return paths +``` + +**操作类型降级**:`-n` + 仅打印表达式(`^(\d+(,\d+)?)?p$`)+ 无 `-i` → `read` + +**表达式安全检查(防御纵深)**:即使路径合法,以下表达式模式仍需拦截。 + +检查结果类型: + +```python +@dataclass(frozen=True) +class SedSafetyResult: + safe: bool + reason: str +``` + +`check_sed_expression_safety(expression) → SedSafetyResult` 对每个 sed 表达式逐一检查,任一不安全则整条命令被 deny。 + +拦截规则: + +| 危险模式 | 说明 | +|----------|------| +| `w`/`W` command | 写入文件 | +| `e`/`E` command | 执行 shell 命令 | +| `s......[flags]` 中含 `w`/`e` flag | 替换结果写文件/执行(支持任意分隔符,如 `s\|x\|y\|w file`) | +| 非 ASCII 字符 | Unicode 同形字攻击 | +| `{}` 花括号 | 块命令,无法静态分析 | +| 换行符 | 多行命令注入 | +| `!` 取反 | 增加分析复杂度 | + +**任意分隔符检测**:sed 的 `s` 命令允许使用任意字符作为分隔符(如 `s|foo|bar|w file`、`s#foo#bar#e`)。`_has_dangerous_sub_flags()` 通过解析实际分隔符字符、跳过转义分隔符、定位 flags 区段来检测危险 flag,而非假定 `/` 为分隔符。 + +#### `jq` — JSON 处理器 | `read` | C(特殊参数跳过) + +- 带值 flags:`-e`, `-f`, `--arg`, `--argjson`, `--slurpfile`, `--rawfile`, `-L`, `--indent` 等 +- 第一个非 flag 参数是 filter → 跳过,后续为文件路径 +- 无文件参数 → 从 stdin 读取,无需校验 + +#### `git` — 版本控制 | `read` | E(子命令分发) + +- **`git diff --no-index`**:提取 `diff` 之后的非 flag 参数,取前 2 个 +- **其他子命令**:在 git 仓库上下文内,受 git 自身安全模型约束 → 返回空列表 + +```python +def extract_git(args): + if args and args[0] == 'diff' and '--no-index' in args: + return filter_out_flags(args[1:])[:2] + return [] +``` + +### 6.4 分类汇总 + +| 策略 | 命令数 | 命令列表 | +|------|-------|---------| +| A: filter_out_flags | 27 | mkdir, touch, rm, rmdir, mv, cp, cat, head, tail, sort, uniq, wc, cut, paste, column, file, stat, diff, awk, strings, hexdump, od, base64, nl, sha256sum, sha1sum, md5sum | +| B: parse_pattern_command | 2 | grep, rg | +| C: 特殊参数跳过 | 2 | sed, jq | +| D: 搜索起点收集 | 1 | find | +| E: 子命令分发 | 1 | git | +| 特殊处理 | 3 | cd, ls, tr | + +### 6.5 命令级校验器 + +某些命令有 flag 可绕过路径提取(路径藏在 flag 值中),需额外校验: + +| 命令 | 规则 | 原因 | +|------|------|------| +| `mv` | 拒绝所有带 flag 的调用 → ask | `--target-directory=PATH` | +| `cp` | 拒绝所有带 flag 的调用 → ask | `--target-directory=PATH` | + +### 6.6 操作类型分类 + +| 类型 | 策略 | 命令(数量) | +|------|------|------------| +| `read` | 范围可放宽 | cd, ls, find, cat, head, tail, sort, uniq, wc, cut, paste, column, tr, file, stat, diff, awk, strings, hexdump, od, base64, nl, grep, rg, git, jq, sha256sum, sha1sum, md5sum (29) | +| `write` | 严格限制在工作目录内 | rm, rmdir, mv, cp, sed (5) | +| `create` | 严格限制在工作目录内 | mkdir, touch (2) | + +动态降级:`sed` 在只读条件下(`-n` + 仅打印 + 无 `-i`)从 `write` 降级为 `read`。 + +--- + +## 7. 路径校验流程 + +### 7.1 `validate_path(path, cwd, allowed_dirs, op_type, *, read_only_dirs=())` + +对单个路径做完整校验,返回 `PathValidationResult(allowed, resolved_path, action, reason)`。 + +**步骤:** + +1. **去引号 + 波浪号展开** + - 去除包裹的单/双引号 + - `~` → `home_dir`,`~/xxx` → `home_dir/xxx` + - `~username`、`~+`、`~-` → 拒绝(TOCTOU 风险:验证时无法知道 shell 实际展开结果) + +2. **拒绝 Shell 展开语法** + - 包含 `$` → 拒绝(`$VAR`、`${VAR}`、`$(cmd)`) + - 包含 `%` → 拒绝(Windows `%VAR%`) + - 以 `=` 开头 → 拒绝(Zsh `=cmd` 展开) + +3. **Glob 模式处理** + - write/create 操作中含 glob(`*?[]{}` 字符) → 拒绝(无法确定实际写入路径) + - read 操作中含 glob → 提取 glob 基础目录进行校验 + +4. **路径解析** + - 相对路径 → `resolve(cwd, path)` 转为绝对路径 + - 解析符号链接(但危险路径检查在解析前进行,防止 `/tmp` → `/private/tmp` 逃逸) + +5. **目录范围检查** + - 路径是否在 `allowed_dirs` 中某个目录的子树内 + - write/create 操作 → 必须在 `allowed_dirs` 范围内,否则 deny + - read 操作 → 先查 `allowed_dirs`,再查 `read_only_dirs`,都不在则 ask(交由 `read_policy` + `resolve_ask` 决定最终结果) + +### 7.2 目录白名单 + +**`allowed_dirs`**(读 + 写 + create)合并来源: +- 项目根目录(agent 启动时确定) +- YAML 配置的 `allowed_directories` +- 会话中动态添加的目录(用户确认后) + +**`read_only_dirs`**(仅读取)来源: +- YAML 配置的 `read_only_directories` + +路径校验优先级:`allowed_dirs`(完全访问) → `read_only_dirs`(只读) → 其余(ask/deny,由 `read_policy` 决定)。写入操作只查 `allowed_dirs`,`read_only_dirs` 中的路径不允许写入。 + +### 7.3 Glob 基础目录提取 + +```python +def get_glob_base_directory(pattern: str) -> str: + glob_chars = set('*?[]{}') + first_glob = len(pattern) + for i, c in enumerate(pattern): + if c in glob_chars: + first_glob = i; break + base = pattern[:first_glob] + last_sep = base.rfind('/') + if last_sep < 0: return '.' + return base[:last_sep] or '/' +``` + +--- + +## 8. 危险路径硬拦截 + +### 8.1 危险删除路径 (`is_dangerous_removal_path`) + +适用于 `rm` 和 `rmdir`,即使在工作目录范围内也**不可自动放行**: + +| 模式 | 示例 | 说明 | +|------|------|------| +| 通配符 `*` | `rm *` | 删除当前目录所有文件 | +| 尾部 `/*` | `rm /tmp/*` | 清空目录 | +| 根目录 `/` | `rm -rf /` | 系统根目录 | +| 家目录 `~` | `rm -rf ~` | 用户全部数据 | +| 根直接子目录 | `rm -rf /usr` | 系统关键目录(不含 `/usr/local`) | +| Windows 驱动器根 | `rm -rf C:\` | Windows 根 | +| Windows 驱动器直接子目录 | `rm -rf C:\Windows` | Windows 系统目录 | + +路径规范化:连续 `\` 和 `/` 压缩为单个 `/`。 + +```python +def is_dangerous_removal_path(path: str) -> bool: + normalized = re.sub(r'[/\\]+', '/', path) + if normalized == '*': return True + if normalized.endswith('/*') or normalized.endswith('\\*'): return True + if normalized == '/': return True + if normalized == os.path.expanduser('~').replace('\\', '/'): return True + if re.match(r'^/[^/]+$', normalized): return True + if re.match(r'^[A-Za-z]:/?$', normalized): return True + if re.match(r'^[A-Za-z]:/[^/]+$', normalized): return True + return False +``` + +### 8.2 系统敏感路径 + +对任何 write 操作均需特别警惕(配置在 YAML `safety_rules.sensitive_paths` 中): + +| 路径 | 说明 | +|------|------| +| `/etc/*` | 系统配置 | +| `/sys/*`, `/boot/*`, `/dev/*`, `/proc/*` | 内核/设备/进程 | +| `~/.ssh/*`, `~/.gnupg/*` | 密钥 | +| `~/.bashrc`, `~/.zshrc`, `~/.profile` | Shell 配置 | +| `.git/config`, `.git/hooks/*` | Git 配置和钩子 | + +--- + +## 9. Safe Wrapper 剥离 + +### 9.1 问题 + +包装命令遮蔽真实操作命令:`timeout 10 rm -rf /` → base command 是 `timeout`(非路径命令)→ passthrough。 + +### 9.2 支持剥离的 Wrapper + +| Wrapper | 剥离示例 | +|---------|---------| +| `timeout` | `timeout 10 rm file` → `rm file` | +| `time` | `time ls -la` → `ls -la` | +| `nice` | `nice -n 10 rm file` → `rm file` | +| `nohup` | `nohup rm file` → `rm file` | +| `stdbuf` | `stdbuf -o0 cat file` → `cat file` | +| `env` | `env KEY=val rm file` → `rm file` | + +**不剥离**:`sudo`、`su`、`doas`、`bash -c`、`sh -c` — 改变执行上下文,不能安全剥离。 + +### 9.3 两阶段剥离算法 + +``` +阶段 1:剥离安全环境变量 + 循环直到无变化: + - 检查是否以 VAR=value 开头 + - VAR 在安全变量白名单中 → 剥离 + - VAR 不在白名单 → 停止 + +阶段 2:剥离 wrapper 命令 + 循环直到无变化: + - 匹配 5 个 wrapper 的正则 → 剥离前缀 + - 此阶段不剥离环境变量(wrapper 用 execvp 执行子命令,VAR=val 是命令名不是赋值) +``` + +### 9.4 安全环境变量白名单 + +| 分类 | 变量 | +|------|------| +| Go | `GOEXPERIMENT`, `GOOS`, `GOARCH`, `CGO_ENABLED`, `GO111MODULE` | +| Rust | `RUST_BACKTRACE`, `RUST_LOG` | +| Node | `NODE_ENV`(不含 `NODE_OPTIONS`) | +| Python | `PYTHONUNBUFFERED`, `PYTHONDONTWRITEBYTECODE` | +| Pytest | `PYTEST_DISABLE_PLUGIN_AUTOLOAD`, `PYTEST_DEBUG` | +| 语言/编码 | `LANG`, `LANGUAGE`, `LC_ALL`, `LC_CTYPE`, `LC_TIME`, `CHARSET` | +| 终端/显示 | `TERM`, `COLORTERM`, `NO_COLOR`, `FORCE_COLOR`, `TZ` | +| 颜色配置 | `LS_COLORS`, `LSCOLORS`, `GREP_COLOR`, `GREP_COLORS`, `GCC_COLORS` | +| 显示格式 | `TIME_STYLE`, `BLOCK_SIZE`, `BLOCKSIZE` | + +**不安全(不可剥离)**:`HOME`, `TMPDIR`, `SHELL`(影响路径);`BASH_ENV`, `PYTHONPATH`(代码注入);`GOFLAGS`, `RUSTFLAGS`, `NODE_OPTIONS`(影响运行时) + +### 9.5 timeout 剥离细节(最复杂) + +| Flag | 类型 | +|------|------| +| `--foreground`, `--preserve-status`, `--verbose`/`-v` | 无值 | +| `--kill-after=N`/`-k N`/`-kN`, `--signal=SIG`/`-s SIG`/`-sSIG` | 有值 | + +Flag 值安全校验:必须匹配 `[A-Za-z0-9_.+-]+`,拒绝 `$()` `` ` `` `|;&` 等。 + +### 9.6 nice 的三种形式 + +- `nice cmd`(无参数) +- `nice -N cmd`(传统,如 `nice -10 ls`) +- `nice -n N cmd`(POSIX,如 `nice -n 10 ls`) + +### 9.7 env 的安全/不安全 flag + +- 安全:`-i`(清空环境)、`-0`(NUL 分隔)、`-v`(详细)、`-u NAME`(删除变量) +- 不安全(遇到则停止剥离):`-S`(字符串拆分 → 注入参数)、`-C`(改 cwd)、`-P`(改 PATH) + +--- + +## 10. 输出重定向与进程替换校验 + +### 10.1 输出重定向 + +| 运算符 | 校验 | +|--------|------| +| `>`、`>|`、`&>` | 目标路径校验,操作类型 `create` | +| `>>`、`&>>` | 同上 | +| `>&N`(如 `2>&1`) | **不校验**(fd 复制) | +| `>&file` | 同 `>` | + +- `/dev/null` 始终放行 +- 目标含 `$VAR`/`%VAR%` → 拒绝(无法确定实际路径) + +### 10.2 进程替换 + +```bash +echo secret > >(tee .git/config) # 输出替换:写入目标不在重定向列表中 +diff <(sort a.txt) <(sort b.txt) # 输入替换:只读操作,风险低 +``` + +区分输入/输出替换,分别标记 category: +- `>(cmd)` → ask(category=`process_output_sub`):可能绕过路径校验写入未知位置 +- `<(cmd)` → ask(category=`process_input_sub`):本质是读操作 + +auto 模式下:输出替换 → deny,输入替换 → allow + +### 10.3 复合命令中的 cd 安全问题 + +复合命令(`&&`/`;`)包含 `cd` + write/create 操作 → 强制 ask。 + +原因:路径校验基于原始 cwd,但 `cd` 在运行时改变了工作目录。 +攻击:`cd .claude/ && mv test.txt settings.json` → 校验看到 `settings.json`(相对原始 cwd),实际写入 `.claude/settings.json`。 + +--- + +## 11. 共享基础设施 + +### 11.1 PermissionMatcher (`matcher.py`) + +两层共用的通配符匹配逻辑: + +```python +class PermissionMatcher: + def match(self, pattern: str, tool_call: str) -> bool: + """使用 fnmatch 做通配符匹配""" + + def match_with_content(self, pattern: str, tool_name: str, tool_args: dict) -> bool: + """支持 server---tool:content_pattern 格式""" +``` + +- 工具名格式:`{server_name}---{tool_name}`(与 `ToolManager.TOOL_SPLITER = '---'` 一致) +- 支持 `*` / `?` 通配符,`|` 分隔多模式 +- 支持 `server---tool:content_pattern` 格式(content 从 tool_args 中提取) + +--- + +## 12. 集成点与代码变更 + +### 12.1 `tool_manager.py` 注入权限检查 + +在 `ToolManager.single_call_tool()` 中,解析 tool_name/tool_args 之后、`tool_ins.call_tool()` 之前: + +```python +# --- 权限检查注入点 (tool_manager.py ~L294) --- +args_dict = dict(tool_args) if isinstance(tool_args, dict) else {} + +# 内层:安全底线检查(不可绕过) +if self._safety_guard is not None: + from ms_agent.permission.ask_resolver import resolve_ask + safety_decision = self._safety_guard.check(tool_name, args_dict) + if safety_decision.action == 'deny': + return f'Blocked by safety policy: {safety_decision.reason}' + if safety_decision.action == 'ask': + resolved = resolve_ask(safety_decision, self._permission_mode, self._read_policy) + if resolved.action == 'deny': + return f'Blocked by safety policy: {resolved.reason}' + if resolved.action == 'ask': + if self._permission_enforcer is None: + return f'Blocked by safety policy (requires confirmation): {resolved.reason}' + # interactive 模式:fall through 到 enforcer/handler + +# 外层:用户意图检查(可被用户覆盖) +if self._permission_enforcer is not None: + perm_decision = await self._permission_enforcer.check(tool_name, args_dict) + if perm_decision.action == 'deny': + return f'Tool call denied: {perm_decision.reason}' + if perm_decision.updated_args is not None: + tool_args = perm_decision.updated_args + tool_info['arguments'] = tool_args + +# ... 继续现有的 tool_ins.call_tool() 逻辑 ... +``` + +`ToolManager.__init__()` 增加可选参数: +- `permission_enforcer: PermissionEnforcer | None = None` +- `safety_guard: SafetyGuard | None = None` + +两者由上层(LLMAgent 或 Server)根据配置注入。 + +### 12.2 初始化链路 + +```python +# llm_agent.py: _build_permission_objects() +raw = dict(self.config.permission) if self.config.permission else {} +project_root = os.getcwd() +perm_config = PermissionConfig.from_dict(raw, project_root=project_root) + +# workspace_root = output_dir,保证 SafetyGuard 和工具端使用相同基目录解析相对路径 +output_dir = str(Path(getattr(self.config, 'output_dir', './output')).expanduser().resolve()) + +# 创建 SafetyGuard(内层) +allowed_dirs = [project_root] + list(perm_config.safety.allowed_directories) +read_only_dirs = list(perm_config.safety.read_only_directories) +safety_guard = SafetyGuard( + config=perm_config.safety, + allowed_dirs=allowed_dirs, + read_only_dirs=read_only_dirs, + workspace_root=output_dir, # 统一路径解析基目录 +) + +# 创建 PermissionEnforcer(外层) +handler = AutoPermissionHandler() # 或 CLIPermissionHandler() / WebPermissionHandler(emitter) +memory = PermissionMemory(project_path=project_root) +enforcer = PermissionEnforcer(config=perm_config, handler=handler, memory=memory) + +# 注入到 ToolManager(含 mode 和 read_policy,供 ask_resolver 使用) +tool_manager = ToolManager( + ..., + permission_enforcer=enforcer, + safety_guard=safety_guard, + permission_mode=perm_config.mode, + read_policy=perm_config.safety.read_policy, +) +``` + +--- + +## 13. 已完成迁移:WorkspacePolicyKernel → SafetyGuard + WorkspaceContext + +> **状态:✅ 已完成** +> +> `ms_agent/utils/workspace_policy.py` 已从代码库中删除。安全职责统一到 SafetyGuard,功能职责提取为 `WorkspaceContext`。 + +### 13.1 迁移结果 + +| 原 WPK 代码 | 迁移去向 | 实现位置 | +|-------------|---------|---------| +| `resolve_under_roots()` | `SafetyGuard._check_file_path()` → `validate_path()` | `ms_agent/permission/safety.py` | +| `path_is_allowed()` | SafetyGuard 在 tool_manager 层统一检查,工具层不再需要 | 已删除 | +| `deny_globs` | `WorkspaceContext.deny_globs`(功能用途:文件遍历过滤) | `ms_agent/utils/workspace_context.py` | +| `workspace_root` | `WorkspaceContext.root`(功能用途:subprocess cwd、相对路径显示) | `ms_agent/utils/workspace_context.py` | +| `assert_shell_command_allowed()` | `SafetyGuard.check()` → `ShellPathValidator.check()` | `ms_agent/permission/shell_validator.py` | +| `_shell_looks_network()` | `PermissionConfig._DEFAULT_BLACKLIST`(enforcer 层策略) | `ms_agent/permission/config.py` | +| `_shell_looks_mutating_or_network()` | `ShellPathValidator` 操作类型 + 路径校验 | `ms_agent/permission/shell_validator.py` | +| `iter_files_under()` | 已删除(确认无外部调用者) | — | + +### 13.2 调用点变更(实际代码) + +**`filesystem_tool.py`:** +```python +# 之前 +from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel +self._fs_policy = WorkspacePolicyKernel(output_dir, extra_allow_roots=roots, deny_globs=deny) +root = self._fs_policy.resolve_under_roots(path) # grep/glob 路径解析 +cwd=str(self._fs_policy.workspace_root) # subprocess cwd +deny = self._fs_policy.deny_globs # 文件遍历过滤 +self._fs_policy.path_is_allowed(rp) # glob 结果路径检查 + +# 之后 +from ms_agent.utils.workspace_context import WorkspaceContext +self._ws = WorkspaceContext.from_config(config) +root = (self._ws.root / raw).resolve() # 路径解析(安全检查在 SafetyGuard 完成) +cwd=str(self._ws.root) # subprocess cwd +deny = self._ws.deny_globs # 文件遍历过滤 +# path_is_allowed 已删除 — SafetyGuard 在 tool_manager 层已检查 path 参数 +``` + +**`local_code_executor.py`:** +```python +# 之前 +from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel +self._policy = WorkspacePolicyKernel(output_dir, ...) +self._policy.assert_shell_command_allowed(command) # 命令安全检查 +cwd=str(self._policy.workspace_root) # subprocess cwd + +# 之后 +from ms_agent.utils.workspace_context import WorkspaceContext +self._ws = WorkspaceContext.from_config(config) +# assert_shell_command_allowed 已删除 — SafetyGuard 在 tool_manager 层已检查 +cwd=str(self._ws.root) # subprocess cwd +``` + +### 13.3 关键设计决策 + +**1. WorkspaceContext 不含任何安全逻辑** + +```python +@dataclass(frozen=True) +class WorkspaceContext: + root: Path # workspace cwd(原 output_dir) + deny_globs: tuple[str, ...] = ('**/.git/**',) # 文件遍历过滤模式 + + @classmethod + def from_config(cls, config) -> WorkspaceContext: ... +``` + +WorkspaceContext 是纯功能性的——提供 subprocess 的 cwd 和 grep/glob 的文件过滤模式。所有安全校验(路径白名单、敏感路径、命令检查)统一在 SafetyGuard 层完成。 + +**2. SafetyGuard 接受 workspace_root 统一路径解析** + +迁移前存在路径解析不一致:SafetyGuard 用 `os.getcwd()`,WPK 用 `output_dir`。迁移后 `SafetyGuard.__init__` 接受 `workspace_root` 参数,`llm_agent.py` 传入 `output_dir`,两端使用相同的基目录。 + +**3. SafetyGuard 覆盖 grep/glob 工具** + +迁移前 grep/glob 的路径安全完全依赖 WPK 的 `resolve_under_roots()`。迁移后 SafetyGuard 的 `check()` 新增 `---grep` 和 `---glob` 分支,复用 `_check_file_path(path, 'read')`。 + +**4. 网络命令检测迁移到 enforcer 层** + +WPK 的 `_shell_looks_network()` 硬编码阻止 curl/wget/ssh 等命令。这不属于安全底线(用户可以选择允许),因此迁移到 `PermissionConfig._DEFAULT_BLACKLIST`: +- auto 模式下被 enforcer deny +- interactive 模式下 ask 用户确认 +- 用户可通过 whitelist 覆盖 + +### 13.4 兼容性保证 + +| 原 WPK 行为 | 新系统对应 | 保证方式 | +|-------------|----------|---------| +| `deny_globs` 默认 `('**/.git/**',)` | `WorkspaceContext` 默认值 | 代码中硬编码 | +| `shell_network_enabled = False` | `_DEFAULT_BLACKLIST` 包含 curl/wget/ssh/scp/rsync/nc/netcat | 默认 blacklist | +| `max_command_chars` 限制 | `ShellPathValidator.check()` 入口检查 | 配置传递链保留 | +| `shell_default_mode = 'read_only'` | 暂未迁移 | 标记为后续优化 | + +--- + +## 14. YAML 配置格式(统一) + +```yaml +permission: + # --- 外层:用户意图 --- + mode: auto # auto | strict | interactive(兼容旧名 restricted → interactive) + + whitelist: + - "file_system---read_file" + - "file_system---grep" + - "file_system---glob" + - "web_search---*" + + blacklist: # 以下为内置默认值,用户配置会追加合并 + - "code_executor---shell_executor:curl *" # 默认 + - "code_executor---shell_executor:wget *" # 默认 + - "code_executor---shell_executor:ssh *" # 默认 + - "code_executor---shell_executor:scp *" # 默认 + - "code_executor---shell_executor:rsync *" # 默认 + - "code_executor---shell_executor:nc *" # 默认 + - "code_executor---shell_executor:netcat *" # 默认 + # - "custom---tool" # 用户自定义追加 + + # --- 内层:安全底线 --- + safety_rules: + # 通用工具级拦截(不可被用户覆盖) + patterns: + - "code_executor---shell_executor:rm -rf *" + - "code_executor---shell_executor:mkfs *" + - "code_executor---shell_executor:dd if=*" + - "file_system---write_file:/etc/*" + - "file_system---write_file:/sys/*" + + # 危险删除路径(rm/rmdir 专用,不可被用户覆盖) + dangerous_removal_paths: + - "*" + - "/*" + - "/" + - "~" + + # 系统敏感路径(write/create 操作拦截) + sensitive_paths: + - "/etc/*" + - "/sys/*" + - "/boot/*" + - "/dev/*" + - "/proc/*" + - "~/.ssh/*" + - "~/.gnupg/*" + - "~/.bashrc" + - "~/.zshrc" + - "~/.profile" + - ".git/config" + - ".git/hooks/*" + - "**/.git/**" # 兼容原 WorkspacePolicyKernel 默认 deny_globs + + # --- 路径校验配置 --- + allowed_directories: # 完全访问(读 + 写 + create) + - "${PROJECT_ROOT}" + - "/tmp/ms-agent-workspace" + + read_only_directories: # 只读访问(读允许,写/create 拒绝) + - "/data/models" + - "/usr/local/lib" + + path_validation: + read_policy: loose # loose: 读超出 allowed_dirs ∪ read_only_dirs 时 auto 模式放行 + # strict: 读超出范围时 auto 模式拒绝 + max_command_chars: 8192 # 兼容原 WorkspacePolicyKernel +``` + +--- + +## 15. 文件结构 + +``` +ms_agent/permission/ +├── __init__.py # 导出 PermissionEnforcer, SafetyGuard, resolve_ask 等 +├── config.py # PermissionConfig + SafetyConfig + _DEFAULT_BLACKLIST — 解析 YAML +├── ask_resolver.py # resolve_ask() — ask 模式解析(auto/strict/interactive) +├── matcher.py # PermissionMatcher — 通配符匹配逻辑(两层共用) +├── enforcer.py # PermissionEnforcer — 外层判定入口 +├── handler.py # PermissionHandler 协议 + Auto/CLI/Web 三种实现 +├── memory.py # PermissionMemory — "以后都允许" 持久化 +├── suggestions.py # generate_suggestions() — 自动建议模式生成 +├── safety.py # SafetyGuard — 内层安全底线(含 workspace_root + grep/glob 覆盖) +├── path_validator.py # validate_path() — 单路径校验函数 +├── shell_validator.py # ShellPathValidator — shell 命令路径级校验 + SafetyDecision +├── path_extractors.py # PATH_EXTRACTORS — 36 个命令的路径提取器注册表 +├── wrapper_strip.py # strip_safe_wrappers() — Safe Wrapper 剥离 +└── sed_validator.py # sed 表达式安全校验(防御纵深) + +ms_agent/utils/ +├── workspace_context.py # WorkspaceContext — 轻量功能上下文(替代 WPK 的 root/deny_globs) +├── ... # 其他 utils +└── [已删除] workspace_policy.py # WorkspacePolicyKernel 已迁移删除 +``` + +--- + +## 16. 与 Claude Code 的对比 + +| 特性 | Claude Code | ms-agent 方案 | +|------|-------------|--------------| +| 权限行为 | allow / deny / ask | allow / deny / ask | +| 模式 | default / acceptEdits / bypassPermissions / dontAsk / plan / auto | auto / strict / interactive | +| auto 模式 ask 处理 | AI 分类器(Haiku)判断 + denial tracking | 规则策略表按 category 自动解析(无额外 LLM 调用) | +| dontAsk/strict 等效 | ask → deny | ask → deny | +| interactive 等效 | default mode: ask → 弹出用户确认 | ask → handler.ask() | +| 询问选项 | Yes / Yes (session) / No / Tab to amend | allow_once / allow_session / allow_always / modify / deny | +| 规则持久化 | settings.json (user/project/local) | permission_memory.json (project/global) | +| 参数修改 | updatedInput + userModified | updated_args (via modify action) | +| 建议生成 | PermissionUpdate[] suggestions | generate_suggestions() → pattern list | +| 规则格式 | ToolName(ruleContent) | server---tool:content_pattern | +| 会话级规则 | session source in ToolPermissionContext | 内存中的 session_memory dict | +| 路径提取 | PATH_EXTRACTORS (34 命令) | PATH_EXTRACTORS (36 命令,含 cd/ls/tr 特殊处理) | +| 操作类型 | read / write / create | read / write / create | +| 危险路径 | isDangerousRemovalPath() | is_dangerous_removal_path() | +| Wrapper 剥离 | stripSafeWrappers() (2 阶段) | strip_safe_wrappers() (2 阶段) | +| 命令解析 | tree-sitter AST + shell-quote | shlex + 正则(一期),可扩展 AST | +| sed 安全 | 多层允许列表 + 拒绝列表 | 多层允许列表 + 拒绝列表 + 任意分隔符感知 | +| 目录权限 | additionalDirectories(读写不分离) | allowed_directories(读+写) + read_only_directories(只读) + read_policy 兜底 | +| 进程替换 | 不区分 input/output,统一 ask | 区分:`<(` → allow, `>(` → deny(auto 模式) | +| 前端交互 | React UI 组件 | Future + SSE + REST 回调 | +| 现有代码 | 无遗留 | WorkspacePolicyKernel 已迁移删除,安全→SafetyGuard,功能→WorkspaceContext | + +--- + +## 17. 验证方式 + +### 17.1 单元测试 + +| 模块 | 测试要点 | +|------|---------| +| PermissionMatcher | 通配符匹配、`|` 分隔、content pattern、边界情况 | +| PermissionEnforcer | auto/strict/interactive 模式、黑白名单优先级、session/persistent memory | +| ask_resolver | auto 模式 7 个 category 解析、strict 全 deny、interactive 保持 ask、read_policy | +| PermissionMemory | 项目级/全局级持久化、合并优先级、add/revoke/list | +| SafetyGuard | safety_rules 匹配、工具特化分发、deny 不可覆盖 | +| ShellPathValidator | 完整流水线(进程替换→拆分→剥离→提取→校验→重定向) | +| PATH_EXTRACTORS | 36 个命令各自的提取逻辑、`--` 处理、edge case | +| validate_path | 波浪号展开、shell 展开拒绝、glob 处理、目录范围检查、read_only_dirs 读写分离 | +| is_dangerous_removal_path | 7 种危险模式、路径规范化 | +| strip_safe_wrappers | 6 个 wrapper、2 阶段算法、安全环境变量白名单 | +| sed_validator | 只读降级、危险表达式拦截 | +| CLIPermissionHandler | 5 种操作的交互流程、allow_always 可编辑模式 | +| WebPermissionHandler | Future 挂起/resolve、超时 deny | +| generate_suggestions | 各工具类型的建议模式生成 | + +### 17.2 集成测试 + +| 场景 | 验证内容 | +|------|---------| +| SafetyGuard → ToolManager | mock SafetyGuard,验证 deny 时 single_call_tool 返回拒绝消息 | +| PermissionEnforcer → ToolManager | mock PermissionEnforcer,验证 deny/modify 行为 | +| WorkspacePolicyKernel 兼容 | 原测试用例在新系统中全部通过 | +| 端到端权限流程 | interactive 模式下工具调用→ask→用户选择→执行/拒绝 | +| auto 模式无交互 | auto 模式下 SafetyGuard ask → resolve_ask → allow/deny(无 hang) | +| strict 模式保守 | strict 模式下所有不确定命令被 deny | + +### 17.3 安全回归测试 + +针对已知攻击向量逐一验证: + +| 攻击 | 预期结果 | +|------|---------| +| `rm -rf /` | deny(危险路径) | +| `timeout 10 rm -rf /` | deny(剥离 wrapper 后识别) | +| `rm -- -/../.claude/settings.json` | deny(`--` 后正确提取路径) | +| `echo "x" > /etc/passwd` | deny(重定向写入超出 allowed_dirs) | +| `cd .claude/ && mv test settings.json` | auto→deny / interactive→ask(cd + write) | +| `rm $HOME/.ssh/*` | auto→deny / interactive→ask(shell 展开) | +| `env HOME=/tmp rm -rf ~` | 不剥离 HOME(不安全变量) | +| `echo secret > >(tee .git/config)` | auto→deny / interactive→ask(输出进程替换) | +| `diff <(sort a.txt) <(sort b.txt)` | auto→allow(输入进程替换,只读) | +| `mv --target-directory=/etc test.txt` | auto→deny / interactive→ask(命令校验器) | +| `sed -e 's/x/y/w /etc/passwd' file` | deny(sed 表达式安全检查) | +| `sed -e 's\|x\|y\|w /etc/passwd' file` | deny(sed 任意分隔符表达式安全检查) | + +--- + +## 18. 实现审查:已知问题与待办 + +> **审查日期**:2026-06-09 +> +> 对照实现路径:`ms_agent/permission/`、`ms_agent/tools/tool_manager.py`、`ms_agent/agent/llm_agent.py` +> +> **测试现状**:`tests/permission/` 共 248 项单元测试通过;§17.2 集成测试与部分 Handler 单测尚未落地。 + +### 18.1 实现完成度概览 + +| 维度 | 状态 | 说明 | +|------|------|------| +| 双层架构(SafetyGuard + PermissionEnforcer) | ✅ 已完成 | `ToolManager.single_call_tool()` 统一入口 | +| Shell 路径级校验(36 命令注册表) | ✅ 已完成 | 进程替换、重定向、wrapper 剥离、sed 纵深防御 | +| WorkspacePolicyKernel 迁移 | ✅ 已完成 | 安全→SafetyGuard,功能→WorkspaceContext | +| YAML 配置解析与默认规则 | ✅ 已完成 | `PermissionConfig.from_dict()` | +| auto / strict 模式 | ✅ 可用 | `resolve_ask` 规则表消歧,无交互 hang | +| interactive 模式端到端 | ⚠️ 未完成 | Handler 未按场景接入,见 §18.2 | +| Web 前后端权限协议 | ⚠️ 未完成 | `WebPermissionHandler` 已实现但未接通 webui | +| 集成测试 | ⚠️ 未完成 | §17.2 所列场景尚无对应用例 | + +### 18.2 已知问题(按优先级) + +#### P0 — 安全正确性 + +| # | 问题 | 设计预期 | 当前实现 | 状态 | +|---|------|---------|---------|------| +| 1 | ~~Shell 路径校验 cwd 与执行 cwd 不一致~~ | §13.2:统一使用 workspace root | `resolve_workspace_root()` 统一解析;`ShellPathValidator` 通过 `PathSafetyConfig.workspace_root` 校验;Agent / 文件工具 / SafetyGuard 共用同一根目录 | ✅ 已修复(2026-06-09) | +| 2 | **interactive 模式 Handler 未接入** | §12.2:按运行环境注入 Handler | `llm_agent._build_permission_objects()` 始终使用 `AutoPermissionHandler()` | 待修复 | + +#### P1 — 交互与安全策略缺口 + +| # | 问题 | 设计预期 | 当前实现 | 涉及文件 | +|---|------|---------|---------|---------| +| 3 | **SafetyGuard `ask` 可被白名单/memory 绕过** | §2 判定流程:interactive 模式下 SafetyGuard `ask` 应交给 handler 确认 | `ToolManager` 在 SafetyGuard 返回 `ask` 后直接 fall through 到 `PermissionEnforcer.check()`;若命中 whitelist 或 memory 则直接 `allow`,跳过了安全疑点确认 | `tool_manager.py:305-310`、`enforcer.py` | +| 4 | **Web 前后端集成缺失** | §3.3.3:`permission_request` 事件 + `POST /permission/respond` | `WebPermissionHandler` 类已实现,但 `webui/` 无事件处理与 REST 回调,`resolve()` 无调用链 | `handler.py`、`webui/` | +| 5 | **`sensitive_paths` 未覆盖 shell 写路径** | §8.2:对任何 write 操作均需警惕系统敏感路径 | `SafetyGuard._check_file_path()` 对 `write_file`/`edit_file` 做 fnmatch 检查;shell 重定向和命令路径仅走 `validate_path` 目录范围检查,不查 `sensitive_paths` | `safety.py`、`shell_validator.py` | +| 6 | **`dangerous_removal_paths` YAML 配置未生效** | §14:`safety_rules.dangerous_removal_paths` 可配置 | `SafetyConfig.dangerous_removal_paths` 已解析,但 `is_dangerous_removal_path()` 逻辑硬编码,未读取配置 | `config.py`、`path_validator.py` | + +#### P2 — 细节与文档 + +| # | 问题 | 说明 | 涉及文件 | +|---|------|------|---------| +| 7 | ~~`generate_suggestions` 未剥离 wrapper~~ | 复用 `strip_safe_wrappers()` | `suggestions.py` | ✅ 已修复(2026-06-09) | +| 8 | ~~`allowed_dirs` 与 `workspace_root` 来源不一致~~ | 统一 workspace root | `resolve_workspace_root()` + `allowed_dirs=[workspace_root, …]` | ✅ 已修复(2026-06-09) | +| 9 | ~~`PermissionConfig(mode='restricted')` 测试绕过别名~~ | 测试统一走 `from_dict()` | `tests/permission/test_enforcer.py` | ✅ 已修复(2026-06-09) | +| 10 | ~~§12.1 伪代码过时~~ | 更新为 `resolve_ask` 流程 | 本文档 §12.1 | ✅ 已修复(2026-06-09) | +| 11 | **`path_extractors.py` 注释写 34 命令** | 实际注册 36 条(与设计 §6.4 一致) | `path_extractors.py:319` | + + +### 18.3 测试覆盖缺口 + +对照 §17.1 / §17.2,以下测试尚未实现: + +| 缺失项 | 优先级 | +|--------|--------| +| `CLIPermissionHandler` 交互流程单测 | P1 | +| `WebPermissionHandler` Future 挂起/resolve/超时单测 | P1 | +| ~~`generate_suggestions` 各工具类型单测~~ | ✅ 已补充 | +| `ToolManager` + SafetyGuard / Enforcer 集成测试 | P1 | +| interactive 模式端到端流程测试 | P1 | + +### 18.4 修复路线图 + +``` +Phase 1(安全正确性) + ├─ ✅ ShellPathValidator 接受 workspace_root,与 SafetyGuard / subprocess cwd 对齐 + ├─ ✅ resolve_workspace_root:未配置 output_dir 时默认 cwd;Agent / 工具 / 权限层统一 + ├─ sensitive_paths 扩展到 shell 重定向与 write 路径校验 + └─ dangerous_removal_paths 配置化(替换硬编码 is_dangerous_removal_path) + +Phase 2(交互可用) + ├─ _build_permission_objects 按 mode + 运行环境选择 Handler + ├─ SafetyGuard ask 时 enforcer 跳过 whitelist/memory,强制走 handler + └─ Web:permission_request 事件 + POST /permission/respond 接通 webui + +Phase 3(完善) + ├─ ToolManager 集成测试 + Handler 单测 + ├─ ✅ generate_suggestions 增加 wrapper 剥离 + └─ ✅ 更新 §12.1 伪代码 +``` + +--- + +## 附录 A:parse_pattern_command 通用实现 + +```python +def parse_pattern_command( + args: list[str], + flags_with_args: set[str], + defaults: list[str] | None = None, +) -> list[str]: + """解析 grep/rg/jq 类命令的参数,提取文件路径""" + paths = [] + pattern_found = False + after_double_dash = False + + i = 0 + while i < len(args): + arg = args[i] + if arg is None: + i += 1; continue + if not after_double_dash and arg == '--': + after_double_dash = True + i += 1; continue + if not after_double_dash and arg.startswith('-'): + flag = arg.split('=')[0] + if flag in ('-e', '--regexp', '-f', '--file'): + pattern_found = True + if flag in flags_with_args and '=' not in arg: + i += 1 + i += 1; continue + if not pattern_found: + pattern_found = True + i += 1; continue + paths.append(arg) + i += 1 + return paths if paths else (defaults or []) +``` + +--- + +## 附录 B:完整命令操作类型对照表 + +| 命令 | 操作类型 | 提取策略 | 附加校验 | 描述 | +|------|---------|---------|---------|------| +| cd | read | 特殊 | cd+write 检查 | change directories to | +| ls | read | A+默认 | - | list files in | +| find | read | D | - | search files in | +| mkdir | create | A | - | create directories in | +| touch | create | A | - | create or modify files in | +| rm | write | A | 危险删除路径 | remove files from | +| rmdir | write | A | 危险删除路径 | remove directories from | +| mv | write | A | 命令校验器 | move files to/from | +| cp | write | A | 命令校验器 | copy files to/from | +| cat | read | A | - | concatenate files from | +| head | read | A | - | read the beginning of files from | +| tail | read | A | - | read the end of files from | +| sort | read | A | - | sort contents of files from | +| uniq | read | A | - | filter duplicate lines from files in | +| wc | read | A | - | count lines/words/bytes in files from | +| cut | read | A | - | extract columns from files in | +| paste | read | A | - | merge files from | +| column | read | A | - | format files from | +| tr | read | 特殊 | - | transform text from files in | +| file | read | A | - | examine file types in | +| stat | read | A | - | read file stats from | +| diff | read | A | - | compare files from | +| awk | read | A | - | process text from files in | +| strings | read | A | - | extract strings from files in | +| hexdump | read | A | - | display hex dump of files from | +| od | read | A | - | display octal dump of files from | +| base64 | read | A | - | encode/decode files from | +| nl | read | A | - | number lines in files from | +| grep | read | B | `-r` 默认 `.` | search for patterns in files from | +| rg | read | B | 默认 `.` | search for patterns in files from | +| sed | write/read | C | 降级+表达式检查 | edit files in | +| git | read | E | 仅 `diff --no-index` | access files with git from | +| jq | read | C | - | process JSON from files in | +| sha256sum | read | A | - | compute SHA-256 checksums for files in | +| sha1sum | read | A | - | compute SHA-1 checksums for files in | +| md5sum | read | A | - | compute MD5 checksums for files in | diff --git a/docs/zh/design/plugins-design.md b/docs/zh/design/plugins-design.md new file mode 100644 index 000000000..b295cf195 --- /dev/null +++ b/docs/zh/design/plugins-design.md @@ -0,0 +1,1763 @@ +# Plugins 兼容系统设计文档 + +> 基于 [`playground_prototype_design.md`](../../../playground_prototype_design.md) F9(Plugins 兼容);与已落地模块对齐: +> - [`hooks-design.md`](hooks-design.md) F6 / F9 Plugin hooks +> - [`mcp_runtime_management.md`](mcp_runtime_management.md) F7 Plugin MCP(`.mcp.json` / `tools/mcp.json`)→ MCPRuntime +> - Skill 体系(PR#907:`SkillCatalog` + `SkillRuntime` + `SkillsConfigManager`) +> - [`permission-design.md`](permission-design.md) 双层权限(Plugin 贡献的 MCP/Tool 调用进入 ToolManager 后受约束;Plugin hook command 脚本本身是独立子进程,见 §10.2) +> +> 状态:方案设计 v0.4 | 2026-06-23(修正安全边界、已实现状态、PluginRuntime/HookRuntime 职责、Phase 0/1 验收与链接) + +--- + +## 目录 + +- [1. 背景与目标](#1-背景与目标) +- [2. 现状分析](#2-现状分析) +- [3. 总体架构](#3-总体架构) +- [4. Plugin 包格式与 Manifest](#4-plugin-包格式与-manifest) + - [4.4 组件能力注册表(Component Registry)](#44-组件能力注册表component-registry) +- [5. 发现、安装与配置分层](#5-发现安装与配置分层) +- [6. PluginLoader — 分发注册](#6-pluginloader--分发注册) +- [7. 子资源加载语义](#7-子资源加载语义)(skills / commands / agents / hooks / mcp / bin / settings / userConfig) +- [8. PluginRuntime — 运行时管理](#8-pluginruntime--运行时管理) +- [9. 环境变量与路径变量](#9-环境变量与路径变量) +- [10. 与 Command / Permission 的协作](#10-与-command--permission-的协作) +- [11. 集成点与代码变更](#11-集成点与代码变更) +- [12. API 与 UI 数据模型](#12-api-与-ui-数据模型) +- [13. 文件结构](#13-文件结构) +- [14. 兼容矩阵](#14-兼容矩阵) +- [15. 分阶段交付与验收](#15-分阶段交付与验收) +- [16. 多生态兼容:OpenClaw 与 Hermes](#16-多生态兼容openclaw-与-hermes) +- [17. 风险与对策](#17-风险与对策) +- [18. 测试策略](#18-测试策略) +- [19. 社区 Plugin 组件全景(调研)](#19-社区-plugin-组件全景调研) +- [附录 D:黄金测例 — hookify](#附录-d黄金测例--hookify) +- [附录 A:plugins.json 示例](#附录-apluginsjson-示例) +- [附录 B:plugin.json 字段对照(Claude Code)](#附录-bpluginjson-字段对照claude-code) +- [附录 C:跨文档约定](#附录-c跨文档约定) + +--- + +## 1. 背景与目标 + +### 1.1 产品背景 + +Claude Code / Codex 社区已沉淀大量 **Plugin 包**:在单一目录内打包 manifest + 多种可加载组件(skills、agents、commands、hooks、MCP、settings 等)。MS-Agent 实验场(Playground)需要 **复用该生态**,避免重复造轮子。 + +Playground 原型(F9)定义的核心诉求(**已扩展为完整组件集**,详见 §4.4 / §19): + +| 能力 | 说明 | 优先级 | +|------|------|--------| +| Manifest 解析 | 多生态 manifest 路径 + 安装时 `format` / `manifest_path` 锁定 | P0 | +| Skills 分发 | `skills/`、根 `SKILL.md` → `SkillCatalog` | P0 | +| Commands 分发 | `commands/*.md` → Skill 或 `CommandRouter` | P1 | +| Agents 分发 | `agents/*.md` → 子 agent 模板 / `AgentDelegate` | P1 | +| Hooks 分发 | `hooks/hooks.json`、`hooks/hermes.yaml` → `HookRegistry` | P0/P1 | +| MCP 分发 | `.mcp.json` / `tools/mcp.json` → `MCPRuntime` | P1 | +| 运行时辅助 | `bin/` PATH、`settings.json` 补丁、`userConfig` 表单 | P1 | +| 元数据 / UI | `assets/`、`interface.*`、`dependencies` | P1 UI | +| 环境变量桥接 | `PLUGIN_ROOT` / `PLUGIN_DATA` / `user_config.*` | P0 | +| 安装来源 | 本地 / `github://` / `modelscope://` / marketplace | P0–P1 | + +### 1.2 设计原则 + +1. **Plugin 是容器,不是新子系统**:不重复实现 Skill / Hook / MCP 逻辑,只做发现、安装、enabled 管理、环境桥接与向各 Runtime 分发。 +2. **与分层配置一致**:全局 → 项目 → session 的 `plugins.json` 与 `mcp.json` / `skills.json` 同级;Plugin 内子资源的 enabled 语义遵循各子系统既有规则。 +3. **Gateway 无关**:同一套 `ms_agent/plugins/` 供 WebUI、TUI、CLI 共用。 +4. **多生态并存**:Claude Code Plugin 为主路径;**OpenClaw / Hermes 的「检测 + 可复用子资源加载」并入 P1**(见 §16);进程内 hook(OpenClaw `handler.ts`、Hermes Python plugin)不原生执行。 +5. **安全默认**:Plugin 安装不自动 `trust_remote_code`;Plugin hooks 默认不启用。MCP / 内置 tool 调用进入 `ToolManager.single_call_tool()` 后受 SafetyGuard + Permission 约束;**hook command 脚本本身由 HookExecutor 直接启动,不经过 ToolManager,不能宣称受 PreToolUse/Permission 二次拦截**。 + +### 1.3 与已落地三个模块的关系 + +当前分支 `feat/new_playground_part` 已落地: + +| 模块 | 文档 | Plugin 依赖点 | +|------|------|---------------| +| 权限管控 F4 | `permission-design.md` | Agent 发起的 MCP / tool 调用仍经 `ToolManager.single_call_tool()`;Plugin hook command 脚本自身按 hook 安全策略治理 | +| Hooks F6 | `hooks-design.md` | **`PluginHooksLoader` 已实现**;缺统一 Manifest 与 `PluginRuntime` | +| MCP 运行时 F7 | `mcp_runtime_management.md` | Phase 3 待做:Plugin `mcp` capability → MCP server | + +``` + ┌─────────────────────────────────────┐ + │ PluginRuntime │ + │ (install / enabled / hot-reload) │ + └──────────┬──────────────────────────┘ + │ PluginLoader.load_all() + ┌───────────┬───────────┼───────────┬───────────┬──────────────┐ + ▼ ▼ ▼ ▼ ▼ ▼ + skills/ commands/ agents/ hooks/ .mcp.json bin/ settings + │ │ │ │ │ │ + ▼ ▼ ▼ ▼ ▼ ▼ + SkillRuntime CommandRouter AgentRegistry HookRuntime MCPRuntime Executor/Config + + Catalog (P1) + Registry + ToolMgr patch (P1) + │ │ │ │ │ │ + └───────────┴───────────┴───────────┴───────────┴──────────────┘ + │ + ToolManager.single_call_tool() + (SafetyGuard → PermissionEnforcer → Hooks) +``` + +**安全边界**:上图描述的是 Agent 侧 tool 调用链。`hooks/hooks.json` 中 `type=command` 的脚本由 `HookExecutor` 直接以子进程执行,当前不会再作为一次 shell tool call 进入 `PermissionEnforcer`。因此 Plugin hooks 必须通过 `hooks.enabled_sources` 显式开启,并在安装 / 启用 UI 中提示风险。 + +**不经过 PluginLoader 独立加载、仅扫描/report 的组件**:LSP、output-styles、themes、monitors、channels、OpenClaw `handler.ts`、Hermes Python plugin 等(§4.4 `unsupported` / `detect-only`)。 + +--- + +## 2. 现状分析 + +### 2.1 已实现(F9 局部) + +| 组件 | 位置 | 状态 | +|------|------|------| +| `PluginHooksLoader` | `ms_agent/hooks/loaders/plugin.py` | ✅ 读取 `/hooks/hooks.json`,委托 `ClaudeSettingsLoader` | +| Plugin 根目录发现 | `ms_agent/hooks/factory.py::_discover_plugin_roots` | ✅ 扫描 `.ms-agent/plugins/*` + `agent.yaml` 的 `plugins:` 列表 | +| Hook 环境变量 | `ms_agent/hooks/executors/command.py::build_hook_env` | ⚠️ executor 已预留 `MS_AGENT_PLUGIN_ROOT` / `CLAUDE_PLUGIN_ROOT` / `MS_AGENT_PLUGIN_DATA`;执行期 `HookRuntime._ctx()` 尚未传入 plugin root/data | +| 路径变量展开 | `ms_agent/hooks/loaders/claude.py::_expand_path_vars` | ✅ `${CLAUDE_PLUGIN_ROOT}` 等 | + +### 2.2 缺口 + +| 缺口 | 影响 | +|------|------| +| 无 `ms_agent/plugins/` 模块 | 无 Manifest 解析、无安装器、无 CRUD | +| Plugin `skills/` 未自动挂载 | 用户需手动把 plugin 路径写入 `skills.json` | +| Plugin `mcp` 未实现 | 见 `mcp_runtime_management.md` Phase 3(`.mcp.json` / `tools/mcp.json`) | +| Plugin `commands/`、`agents/` 未挂载 | 需 `PluginLoader` 分发至 SkillCatalog / AgentRegistry | +| 无 `plugins.json` 持久化 | 无法 UI 级 enable/disable / 版本管理 | +| `PluginHooksLoader` 未注入 `plugin_data_dir` | `MS_AGENT_PLUGIN_DATA` 仅在 executor 层预留,loader 未关联 plugin id | +| Hook command 安全边界未在产品层显式表达 | command 脚本是 HookExecutor 子进程,不经过 ToolManager 的 Permission / SafetyGuard;需默认关闭 + 启用提示 | +| 无安装 URI | 不能 `github://org/repo` 一键安装 | +| Command 扩展未接入 Plugin | F5 预留的注册 API 未与 Plugin manifest 联动 | + +### 2.3 现有发现逻辑(待收敛) + +```python +# ms_agent/hooks/factory.py — 当前临时实现 +def _discover_plugin_roots(config, project_path) -> list[str]: + # 1. /.ms-agent/plugins// (安装目标目录) + # 2. config.plugins[] 中的相对/绝对路径 (agent.yaml 显式声明) +``` + +**问题**:无 manifest 校验、无 enabled 过滤、与全局 `~/.ms_agent/plugins/` 不同步。新设计将 discovery 收敛到 `PluginRegistry`。 + +### 2.4 术语冲突(实现时必须消歧) + +当前代码里已有三处名字接近但语义不同的 “plugin”: + +| 名称 | 当前位置 | 语义 | 与本文关系 | +|------|----------|------|------------| +| `config.plugins[]` | `ms_agent/hooks/factory.py` | 临时声明 Plugin hooks 根目录 | Phase 0 迁移到 `plugins.json` / `PluginRegistry` | +| `tools.plugins[]` | `ms_agent/tools/tool_manager.py` | Python `ToolBase` 外部工具插件,需 `trust_remote_code=True` | 不是本文的容器 Plugin;文档和 UI 需避免混称 | +| `plugins.json` | 本文新增 | 容器 Plugin 安装索引与 enabled 状态 | F9 正式配置入口 | + +--- + +## 3. 总体架构 + +### 3.1 模块职责 + +``` +ms_agent/plugins/ +├── manifest.py # PluginManifest 解析与校验 +├── registry.py # 已安装 Plugin 索引(内存 + 磁盘) +├── installer.py # 本地 / github / modelscope 安装 +├── config_manager.py # plugins.json CRUD(对标 MCPConfigManager) +├── loader.py # PluginLoader:按 manifest 分发到各子系统 +├── runtime.py # PluginRuntime:enabled、热重载、聚合 list_all() +└── types.py # PluginRecord、InstallSource、PluginStatus +``` + +### 3.2 数据流 + +```plaintext +安装/配置 + plugins.json (global / project) + │ + ▼ + PluginConfigManager.load_merged() + │ + ▼ + PluginRegistry.resolve() ──→ list[PluginManifest] + │ + ▼ + PluginLoader.load_all(manifests, ctx) + ├─ skills/ + 根 SKILL.md → SkillCatalog + ├─ commands/*.md → SkillLoader / CommandRouter + ├─ agents/*.md → AgentRegistry(P1) + ├─ hooks/hooks.json + yaml → HookRegistry + ├─ .mcp.json / tools/mcp.json → MCPRuntime + ├─ bin/ → code_executor PATH(P1) + ├─ settings.json → ConfigResolver 补丁(P1) + ├─ userConfig (manifest) → plugins/data + 变量展开(P1) + ├─ assets/ + interface → UI 元数据(不进入 Runtime) + └─ scan unsupported → lsp / themes / monitors / … + │ + ▼ + PluginRuntime + ├─ toggle(plugin_id, enabled) + ├─ reload(plugin_id) + └─ list_all() → UI +``` + +**唯一来源约定**:Phase 0 迁移完成后,Plugin 子资源发现以 `PluginRegistry` / `PluginLoader` 为准;`build_hook_runtime()` 不再自行扫描 `.ms-agent/plugins/*`。迁移期可保留 `_discover_plugin_roots()` 作为兼容层,但必须保证同一 plugin hook 不会被 legacy path 和 `plugins.json` 双重加载。 + +### 3.3 与 ConfigResolver 的关系 + +`ConfigResolver` 在 `resolve()` 末尾已有 `_merge_mcp` / `_merge_skills`。Plugin 合并作为 **第 6 步**(在 session overrides 之后、fill_missing_fields 之前),仅用于 Playground / Server / TUI 这类分层配置入口;CLI 直读 `Config.from_task()` 的兼容路径不强制接入: + +```python +# 伪代码 — config/resolver.py 扩展 +def resolve(...): + merged = self._merge_layers(layers) + merged = self._merge_mcp(merged, project_path) + merged = self._merge_skills(merged, project_path) + merged = self._merge_plugins(merged, project_path) # 新增 + return Config.fill_missing_fields(merged) +``` + +`_merge_plugins` 职责: + +- 读取 `PluginConfigManager.load_merged(project_path)` +- 将 **enabled** 的 plugin 根路径写入 `merged.plugins`(`List[str]`,兼容现有 `agent.yaml` 字段) +- 将 plugin 衍生的 MCP server 条目 **合并进** MCP 层(见 §7.3) + +--- + +## 4. Plugin 包格式与 Manifest + +### 4.1 目录布局(Claude Code / Codex 兼容) + +Manifest 位置因生态而异(**均需识别**): + +| 生态 | Manifest 路径 | +|------|---------------| +| Claude Code | `.claude-plugin/plugin.json`(组件目录在 plugin **根**) | +| Codex | `.codex-plugin/plugin.json` | +| Cursor | `.cursor-plugin/plugin.json`(bundle 检测) | +| OpenClaw native | `openclaw.plugin.json`(包根,TS 进程内) | +| MS-Agent 原生 | `plugin.json` 或 `.ms-agent-plugin/plugin.json` | + +完整目录(Claude Code 官方 reference + 社区常见布局): + +```plaintext +my-plugin/ +├── .claude-plugin/ # Claude:仅 manifest 在此 +│ └── plugin.json +├── .codex-plugin/ # Codex:同上 +│ └── plugin.json +├── README.md +├── skills/ # Skill 目录(每子目录含 SKILL.md) +│ └── commit-helper/SKILL.md +├── SKILL.md # 可选:无 skills/ 时根目录单 skill +├── commands/ # 遗留 slash command(flat .md) +├── agents/ # Subagent 定义(.md + frontmatter) +├── hooks/ +│ ├── hooks.json # Claude/Codex plugin 格式(含 hooks 包装层) +│ └── hermes.yaml # Hermes shell hooks(包内,P1) +├── .mcp.json # MCP(Claude/Codex 惯例文件名) +├── tools/ # MS-Agent 别名:tools/mcp.json +│ └── mcp.json +├── .app.json # Codex App Connectors(OAuth 应用) +├── .lsp.json # LSP 语言服务配置 +├── output-styles/ # Claude 输出风格 +├── themes/ # Claude 颜色主题(experimental) +├── monitors/ # Claude 后台监视器(experimental) +├── bin/ # 加入 Bash tool PATH 的可执行文件 +├── settings.json # Plugin 启用时的默认 settings 片段 +├── scripts/ # hook/MCP 引用的辅助脚本(非独立组件) +├── assets/ # Codex UI:icon/logo/screenshots +└── rules/ # 部分社区包携带 .claude/rules 片段 +``` + +**原设计遗漏项**见 [§19](#19-社区-plugin-组件全景调研)。 + +### 4.2 plugin.json Schema + +MS-Agent **超集** Claude Code manifest,未知字段忽略: + +```json +{ + "name": "commit-helper", + "version": "1.2.0", + "description": "Conventional commit assistant", + "author": { "name": "Alice" }, + "homepage": "https://github.com/org/commit-helper", + "license": "MIT", + "keywords": ["git", "commit"], + + "ms_agent": { + "min_version": "1.0.0", + "capabilities": [ + "skills", "commands", "agents", "hooks", "mcp", + "settings", "bin", "user_config" + ] + }, + + "skills": "./skills/", + "commands": "./commands/", + "agents": ["./agents/reviewer.md"], + "hooks": "./hooks/hooks.json", + "mcpServers": "./.mcp.json", + "lspServers": "./.lsp.json", + "outputStyles": "./output-styles/", + "dependencies": [{ "name": "base-plugin", "version": "~1.0.0" }], + "userConfig": { "...": "见 §19.2" }, + "defaultEnabled": true +} +``` + +| 字段 | 必填 | 说明 | +|------|------|------| +| `name` | ✅ | Plugin 稳定 id(目录名默认与此一致) | +| `version` | 推荐 | semver;用于升级与冲突检测 | +| `description` | 推荐 | UI 展示 | +| `ms_agent.min_version` | 可选 | SDK 版本门禁 | +| `ms_agent.capabilities` | 可选 | 声明包含的子资源,便于 UI 图标 | + +**Plugin id 规则**:`manifest.name` 规范化(小写、`/` → `-`)作为 `plugin_id`;安装目录名必须一致。 + +#### Manifest 发现 vs 安装域(必读) + +这是两个**正交**问题,原先 §4.2 列表易误解为「MS-Agent 去读 Claude 的全局缓存」——**不是**。 + +| 维度 | MS-Agent 行为 | 不做什么 | +|------|---------------|----------| +| **安装域 / 缓存** | 仅 `~/.ms_agent/plugins//`(global)或 `/.ms-agent/plugins//`(project) | **默认不**扫描 `~/.claude/plugins/cache/`、`~/.codex/plugins/cache/`、`~/.openclaw/` | +| **配置索引** | `~/.ms_agent/plugins.json` / `.ms-agent/plugins.json` 列出 enabled + **path** | 不读 Claude `enabledPlugins`、Codex `config.toml` plugins 段 | +| **可变数据** | `~/.ms_agent/plugins/data//` | 与 Claude `CLAUDE_PLUGIN_DATA` 目录**物理隔离** | +| **Manifest 解析** | 在**已落入 MS-Agent 安装目录的那一份拷贝**上,识别其生态格式 | 不在「用户同时开了 Claude/Codex」时跨工具抢目录 | + +用户本机同时装 Claude Code + Codex + MS-Agent **不会**导致 MS-Agent 加载错包,只要 MS-Agent 只消费自己的 `plugins.json` 条目。 +只有当用户用 **`--link` 开发模式** 把 `plugins.json.path` 指到 Claude 缓存里的同一路径时,才可能与 Claude 并发写同一目录——此时为显式 opt-in,文档警告。 + +#### Manifest 路径解析(安装时探测 + 持久化锁定) + +解析发生在 **`PluginInstaller.install()` 的 staging 阶段**,结果写入 `plugins.json` 的 `format` + `manifest_path`,**运行时不再按全局优先级重猜**。 + +**探测顺序**(仅当安装源未声明 `format` 且目录内存在多个 manifest 时): + +1. `.ms-agent-plugin/plugin.json` — **MS-Agent 原生,显式优先** +2. 根目录 `plugin.json` 且含 `ms_agent` 段 +3. `.claude-plugin/plugin.json` +4. `.codex-plugin/plugin.json` +5. `.cursor-plugin/plugin.json` +6. `openclaw.plugin.json` +7. 根目录 `plugin.json`(无 `ms_agent` 段的通用/遗留包) + +**冲突规则**(同一目录多个 manifest): + +为避免「装一次后运行时格式漂移」,安装 staging 阶段必须一次性锁定 `format` + `manifest_path`。无 `format_hint` 时: + +- 只有一个 manifest:直接采用; +- 同时存在 MS-Agent 原生 manifest(`.ms-agent-plugin/plugin.json` 或根 `plugin.json` 且含 `ms_agent` 段)和其他宿主 manifest:采用 MS-Agent 原生,并记录 warning; +- 其他多 manifest 并存:报 `AmbiguousPluginManifest`,要求用户显式 `--format claude|codex|...`。 + +```python +# 伪代码 +def detect_manifest(root: Path, *, format_hint: str | None) -> tuple[Path, PluginFormat]: + if format_hint: + return _resolve_by_hint(root, format_hint) # 安装 URI / CLI 指定 + candidates = _scan_all_manifests(root) + if len(candidates) == 1: + return candidates[0] + native = _pick_ms_agent_native(candidates) + if native is not None: + return native + if len(candidates) > 1: + raise AmbiguousPluginManifest(candidates) # 要求用户 --format claude +``` + +`PluginRegistry` / `PluginLoader` **只读** `plugins.json` 里已锁定的 `manifest_path`,避免用户后来在磁盘上多加 `.codex-plugin/` 导致运行时格式漂移。 + +#### 安装 URI 与「指定装进 MS-Agent」 + +社区 Plugin **没有** Claude 式的跨宿主自动发现;要通过 MS-Agent 安装器写入 MS-Agent 缓存,使用下列入口之一: + +| 方式 | 示例 | 行为 | +|------|------|------| +| **MS-Agent URI**(推荐) | `ms-agent://plugin/install?source=github://anthropics/claude-plugins-official@main#plugins/hookify` | 明确目标宿主为 MS-Agent | +| **GitHub 子路径** | `github://org/repo@ref#plugins/foo` | fetch → **copy** 到 `~/.ms_agent/plugins/foo/` | +| **Marketplace** | `ms-agent plugin install hookify --marketplace anthropics/claude-plugins-official` | 读 marketplace.json 的 `source.path`,仍安装到 MS-Agent 目录 | +| **本地目录** | `ms-agent plugin install /path/to/plugin` 或 `file:///...` | copy/link 到 MS-Agent 目录;**不**注册到 Claude | +| **显式 format** | `... --format claude` | 多 manifest 冲突时指定解析 | +| **显式 link** | `... --link` | path 指向外部目录(开发);与 Claude 共享目录时用户自负 | + +`plugins.json` 条目扩展(安装后持久化): + +```json +{ + "id": "hookify", + "enabled": true, + "managed_by": "ms-agent", + "format": "claude", + "manifest_path": ".claude-plugin/plugin.json", + "source": { + "type": "github", + "uri": "github://anthropics/claude-plugins-official@main#plugins/hookify", + "resolved_sha": "abc123..." + }, + "path": "/Users/me/.ms_agent/plugins/hookify", + "installed_at": "2026-06-18T12:00:00Z" +} +``` + +- `managed_by: "ms-agent"`:声明此拷贝由 MS-Agent 生命周期管理;Claude/Codex **不会**自动读取该条目。 +- `format` + `manifest_path`:安装时锁定,消除多工具歧义。 +- `resolved_sha`:供应链 pin,与 Claude marketplace 的 commit pin 类似但**独立存储**。 + +**与 Claude 并存时的推荐做法**:同一份社区包各装一份——Claude 走 `/plugin install` → `~/.claude/plugins/cache/...`;MS-Agent 走 `ms-agent plugin install` → `~/.ms_agent/plugins/...`。内容可相同,**缓存互不干扰**。 + +### 4.3 Manifest 解析与校验 + +```python +@dataclass(frozen=True) +class PluginManifest: + plugin_id: str + name: str + version: str + description: str + root: Path + format: PluginFormat # claude | codex | ms-agent | mixed | ... + manifest_path: str # 相对 root,安装时锁定 + capabilities: frozenset[str] # 见 §4.4 capability id + components: dict[str, ComponentScan] # 每组件:path、count、status + source: InstallSource + installed_at: str + enabled: bool = True + + # 约定路径(manifest 可覆盖,见 plugin.json 各 *Servers / skills / agents 字段) + def resolve_path(self, kind: str) -> Path | None: ... +``` + +#### 4.3.1 Manifest 文件校验 + +1. **manifest 存在**:`plugins.json` 锁定的 `manifest_path` 指向的 JSON 合法可读 +2. **`name`**:匹配 `^[a-z0-9][a-z0-9._-]{0,63}$`(kebab-case) +3. **`version`**:若存在则为 semver;用于升级与 `dependencies` 约束 +4. **未知顶层字段**:忽略(兼容 Claude/Codex 超集);`ms_agent plugin validate --strict` 可报 warning + +#### 4.3.2 可加载组件存在性(安装门槛) + +Plugin **至少须含下列「可加载组件」之一**(§4.4 `loadable=true`)。 +仅含 `scripts/`、`assets/`、`README` **不能**单独构成可安装 Plugin。 + +| capability id | 判定信号(任一命中即可) | +|---------------|-------------------------| +| `skills` | `skills/` 下含 `SKILL.md`;或 manifest `skills` 路径;或根 `SKILL.md` | +| `commands` | `commands/*.md`;或 manifest `commands` | +| `agents` | `agents/*.md`(或 `agents/*/AGENT.md`,warning 非标准) | +| `hooks` | `hooks/hooks.json`;manifest 内联 `hooks`;`hooks/hermes.yaml` | +| `mcp` | `.mcp.json`;`tools/mcp.json`;manifest `mcpServers` 内联 | +| `settings` | 根 `settings.json`(非空) | +| `bin` | `bin/` 下至少一个可执行文件 | +| `user_config` | manifest `userConfig` 非空 | + +**不记入安装门槛、仅扫描上报**:`lsp`、`output_styles`、`themes`、`monitors`、`apps`、`channels`、`rules`、OpenClaw/Hermes 进程内扩展等(§4.4)。 + +#### 4.3.3 分组件内容校验 + +| 组件 | 规则 | 失败级别 | +|------|------|----------| +| **skills** | 每个 skill 目录须含 `SKILL.md`;frontmatter `name` 可选 | 缺 `SKILL.md` → **error**(该 skill);其余 skill 仍可加载 | +| **commands** | 每个 `.md` 须可解析 YAML frontmatter | 单文件 **warning**,跳过 | +| **agents** | 每个 `.md` 须含 `description`;推荐 `name` | 缺 `description` → **warning** | +| **hooks** | `hooks.json` 符合 plugin 包装或 settings 格式;Hermes yaml 含 `hooks:` | 解析失败 → 该源 **error**,另一格式可并存 | +| **mcp** | JSON 合法;server 条目经 `normalize_mcp_server_entry` | 非法 server → **warning**,跳过该 server | +| **settings** | JSON 合法;仅 merge 白名单键(§7.8) | 未知键 → **warning** | +| **bin** | 文件存在;非 Windows 下检查 `+x` 或 shebang | 不可执行 → **warning** | +| **userConfig** | schema 字段 `type`/`title`/`description` 合法 | 非法项 → **error**(阻止启用表单) | +| **dependencies** | 引用的 plugin id 可解析;semver 合法 | 缺失依赖 → 安装时 **error**(可先装依赖) | + +#### 4.3.4 扫描期自动发现 `capabilities` + +安装 staging 阶段执行 `scan_components(root)`,不依赖 manifest 自声明(manifest `ms_agent.capabilities` 仅用于 UI 图标,**以扫描结果为准**): + +```python +def scan_components(root: Path) -> frozenset[str]: + found: set[str] = set() + if _has_skills(root): found.add("skills") + if _has_commands(root): found.add("commands") + if _has_agents(root): found.add("agents") + if _has_hooks(root): found.add("hooks") + if _has_mcp(root): found.add("mcp") + if (root / "settings.json").is_file(): found.add("settings") + if _has_bin(root): found.add("bin") + if _manifest_user_config(root): found.add("user_config") + # detect-only(写入 components.*.status,不计入 capabilities 门槛) + if (root / ".lsp.json").is_file(): _mark(root, "lsp", "detect_only") + ... + if not found & LOADABLE_CAPABILITIES: + raise EmptyPluginError(root) + return frozenset(found) +``` + +--- + +### 4.4 组件能力注册表(Component Registry) + +MS-Agent 用统一 **capability id** 描述 Plugin 内各组件,避免全文只写 skills/hooks/tools 三类。 + +| capability id | 目录 / 文件 | MS-Agent 运行时 | 阶段 | 说明 | +|---------------|-------------|-----------------|------|------| +| `skills` | `skills/*/SKILL.md`、根 `SKILL.md` | `SkillCatalog` | **P0** | 模型注入 + `/skill-name` | +| `commands` | `commands/*.md` | `SkillCatalog` 或 `CommandRouter` | **P1** | Claude 遗留 slash;2026 常与 skill 合并语义 | +| `agents` | `agents/*.md` | `AgentRegistry` → `AgentDelegate` | **P1** | 子 agent 模板;frontmatter:`model`、`tools`、`skills` | +| `hooks` | `hooks/hooks.json`、`hooks/hermes.yaml` | `HookRegistry` | **P0/P1** | command hook 为主;prompt/http P2 | +| `mcp` | `.mcp.json`、`tools/mcp.json` | `MCPRuntime` | **P1** | 与 F7 对齐;server 名冲突加 `plugin..` 前缀 | +| `settings` | `settings.json` | `ConfigResolver` 补丁 | **P1** | 白名单键:`agent`、`subagentStatusLine` 等 | +| `bin` | `bin/*` | `LocalCodeExecutor` 扩展 PATH | **P1** | plugin enable 时注入,disable 移除 | +| `user_config` | manifest `userConfig` | `plugins/data//config.json` | **P1** | `${user_config.KEY}`、`CLAUDE_PLUGIN_OPTION_*` | +| `dependencies` | manifest `dependencies[]` | `PluginInstaller` 拓扑 | **P1** | 安装顺序,非运行时组件 | +| `assets` | `assets/*` | UI only | P1 | Codex `interface.logo` 等 | +| `apps` | `.app.json` | AppConnector(OAuth) | P2 | Codex 专有 | +| `rules` | `rules/`、包内 instruction | `PersonalizationInjector` | P2 | 并入 project 指令 | +| `lsp` | `.lsp.json` | detect-only | P3 | Playground 非 IDE 核心 | +| `output_styles` | `output-styles/` | ignore / P3 | P3 | 终端呈现 | +| `themes` | `themes/` | ignore | P3 | | +| `monitors` | `monitors/monitors.json` | detect-only / P3 | P3 | 对标 Monitor tool | +| `channels` | manifest `channels[]` | P3 | P3 | MCP 消息注入 | +| `hooks_openclaw_internal` | `HOOK.md`+`handler.ts` | unsupported | detect | §16 | +| `hooks_hermes_python` | `register(ctx)` | unsupported | detect | §16 | +| `scripts` | `scripts/` | — | — | **非独立组件**;供 hooks/MCP 引用 | + +**`list_all()` / `capabilities_status` 键名**与上表 `capability id` 一致,例如: + +```json +"capabilities_status": { + "skills": { "count": 1, "status": "ready" }, + "commands": { "count": 4, "status": "ready" }, + "agents": { "count": 1, "status": "ready" }, + "hooks": { "count": 4, "status": "ready" }, + "mcp": { "count": 0, "status": "skipped" }, + "settings": { "status": "skipped" }, + "bin": { "status": "skipped" }, + "lsp": { "status": "detect_only", "hint": "Playground 不加载 LSP" } +} +``` + +--- + +## 5. 发现、安装与配置分层 + +### 5.1 存储布局(MS-Agent 独立安装域) + +| scope | 配置索引 | 安装目录(缓存) | 可变数据 | +|-------|----------|------------------|----------| +| global | `~/.ms_agent/plugins.json` | `~/.ms_agent/plugins//` | `~/.ms_agent/plugins/data//` | +| project | `/.ms-agent/plugins.json` | `/.ms-agent/plugins//` | 同上或 project 子目录(P2) | + +对比其他宿主(**MS-Agent 默认不读取**): + +| 宿主 | 典型缓存 | MS-Agent 关系 | +|------|----------|---------------| +| Claude Code | `~/.claude/plugins/cache////` | 仅当 `plugin install` **copy** 社区包后内容可相同;路径独立 | +| Codex | `~/.codex/plugins/cache/...` | 同上 | +| OpenClaw | `~/.openclaw/extensions/`、`plugins/installs.json` | bundle 可复用布局;安装域独立 | + +**staging 目录**(安装中间态):`~/.ms_agent/plugins/.staging//` → 校验通过后原子 `rename` 到 `/`。 + +`plugins.json` 格式见 [附录 A](#附录-apluginsjson-示例)。 + +### 5.2 合并规则 + +与 MCP / Skills 一致: + +- **并集**:global + project 均列出时,project 同 id **覆盖** global 的 `enabled` 与 `path` +- **enabled: false**:Plugin 不参与任何子系统加载;已安装文件保留磁盘 +- **path 解析**:条目可为 `{ "id": "commit-helper", "source": "local", "path": "/abs/path" }` 或安装后的默认路径 + +### 5.3 安装来源 + +| 来源 | URI 示例 | Phase | 行为 | +|------|----------|-------|------| +| **MS-Agent 显式** | `ms-agent://plugin/install?source=...` | P0 | 目标宿主固定为 MS-Agent 缓存 | +| 本地目录 | `/path/to/plugin` 或 `file:///...` | P0 | **copy**(默认)或 `--link` 到 MS-Agent 目录 | +| 本地 tarball | `file:///path/plugin.tgz` | P1 | 解压 → staging → 落入 MS-Agent 目录 | +| GitHub | `github://org/repo[@ref][/subdir]` | P1 | shallow clone 子路径 → **copy** 到 MS-Agent 目录 | +| ModelScope | `modelscope://org/pack[@rev]` | P1 | 下载 → MS-Agent 目录 | +| Claude marketplace 名 | `hookify@claude-plugins-official`(CLI 糖) | P1 | 解析 marketplace.json → 同 GitHub 流程 → **装入 MS-Agent**,不调用 Claude CLI | + +安装流程: + +```plaintext +PluginInstaller.install(source, scope, project_path?) + → fetch 到 staging/ + → PluginManifest.parse(staging) + → 冲突检测(同 id 高版本 / 强制 --force) + → 原子移动到 plugins// + → PluginConfigManager.upsert(record) + → PluginRuntime.reload(plugin_id) +``` + +**默认 copy**;开发模式可选 `--link` symlink。 + +### 5.4 PluginConfigManager + +对标 `MCPConfigManager` / `SkillsConfigManager`: + +```python +class PluginConfigManager: + def list(scope: Literal['global','project','merged']) -> list[PluginRecord] + def get(plugin_id: str, scope=...) -> PluginRecord | None + def upsert(record: PluginRecord, scope=...) -> None + def set_enabled(plugin_id: str, enabled: bool, scope=...) -> None + def remove(plugin_id: str, scope=...) -> None # 仅删配置;--purge 删目录 + def load_merged(project_path: str | None) -> list[PluginRecord] +``` + +--- + +## 6. PluginLoader — 分发注册 + +### 6.1 接口 + +```python +@dataclass +class PluginLoadContext: + project_path: str + session_id: str + enabled_executors: frozenset[str] + plugin_data_root: Path # ~/.ms_agent/plugins/data + +@dataclass(frozen=True) +class PluginHookContribution: + plugin_id: str + registry: HookRegistry + plugin_root: Path + plugin_data_dir: Path + +class PluginLoadResult: + skill_sources: list[SkillSource] + hook_registries: list[PluginHookContribution] + mcp_servers: dict[str, dict] + command_defs: list[CommandDef] # commands/*.md + agent_defs: list[AgentDef] # agents/*.md + settings_patch: dict[str, Any] # settings.json 片段 + bin_paths: list[Path] + user_config_schema: dict[str, Any] + ui_metadata: dict[str, Any] # assets + interface + unsupported: list[UnsupportedCapability] # lsp, themes, monitors, ... + +class PluginLoader: + @staticmethod + def load(manifest: PluginManifest, ctx: PluginLoadContext) -> PluginLoadResult: ... + + @staticmethod + def load_all(manifests: list[PluginManifest], ctx: PluginLoadContext) -> PluginLoadResult: + # 按 plugin_id 排序保证确定性;hook merge 顺序 = 安装顺序 +``` + +**Hook metadata 必须是 per-handler 级别**:`HookRegistry.merge()` 合并后只保留 `HookHandlerConfig`,不能只依赖外层 `(plugin_id, registry)` tuple。`PluginLoader` 在产出 hook registry 时必须给每个 handler 标注只读来源元数据(例如 `source_plugin_id`、`source_plugin_root`、`source_plugin_data_dir`),`HookExecutor` 执行单个 handler 时据此构造 `HookExecutionContext`。否则多 Plugin hooks 合并后无法稳定注入 `MS_AGENT_PLUGIN_ROOT` / `MS_AGENT_PLUGIN_DATA`,也无法精确热重载某个 Plugin 的 hook 段。 + +### 6.2 与现有 Hook factory 的迁移 + +**当前**:`build_hook_runtime` 内联 `_discover_plugin_roots` + 循环 `PluginHooksLoader`。 + +**目标**:`build_hook_runtime` 只负责把各来源 `HookRegistry` 合并成 `HookRuntime`,不负责 Plugin manifest / enabled / 安装域解析。Plugin 来源由 `PluginRegistry` 解析,`PluginLoader` 产出 `hook_registries` 后注入 hook factory。 + +```python +# hooks/factory.py — 重构后 +def build_hook_runtime(config, *, session_id=None, plugin_hook_registries=None): + ... + if 'plugin' in enabled_sources: + for contrib in (plugin_hook_registries or []): + # contrib.registry 内的 handler 已带 source_plugin_* metadata + loaders.append((f'plugin:{contrib.plugin_id}', contrib.registry)) +``` + +`PluginHooksLoader` **保留**为薄封装,不删除,供 `PluginLoader` 内部调用。Phase 0 允许保留 `_discover_plugin_roots()` 兼容 `agent.yaml plugins:`,但兼容路径必须在发现到同 id 的 `plugins.json` 记录时跳过,避免重复注册同一 hooks。 + +### 6.3 Skills 与 Commands 挂载 + +**Skills**(`skills/`、manifest `skills`、根 `SKILL.md`): + +```python +for skills_path in manifest.resolve_paths("skills"): + # SkillSource 需扩展 origin/plugin_id/capability;当前 sources.py 尚无这些字段。 + sources.append(SkillSource( + type=SkillSourceType.LOCAL_DIR, + path=str(skills_path), + origin="plugin", + plugin_id=manifest.plugin_id, + capability="skills", + )) +if (manifest.root / "SKILL.md").is_file(): + sources.append(...) # 单 skill 包 +``` + +**Commands**(`commands/*.md`,P1): + +- **策略 A(推荐)**:flat `.md` 经 `SkillLoader` 单文件模式并入 `SkillCatalog` +- **策略 B**:`CommandRouter.register` + `SUBMIT_PROMPT` +- UI 命名空间:`/plugin-id:command-name`(对齐 Claude) + +`SkillSource` / `SkillSchema` / `SkillRuntime.list_all()` 需补齐来源元数据:`origin`, `plugin_id`, `capability: "skills"|"commands"`。当前代码中的 `SkillSource` 尚未包含这些字段,`SkillRuntime.list_all()` 也未返回来源信息,因此这是 Phase 0 的显式 API 扩展,而不是现有接口。**优先级**:plugin **高于** builtin,**低于** workspace sources(tier 2.5);热重载需新增 source-level reload,或先移除该 plugin source 再重新加载。 + +### 6.4 Agents 挂载(P1) + +扫描 `agents/*.md`(`agents/*/AGENT.md` 兼容但 deprecation warning)→ `AgentDef` → `AgentRegistry` / `AgentDelegate`。P1 可先 **list 不执行**。 + +### 6.5 Hooks 挂载 + +- `hooks/hooks.json` → `PluginHooksLoader` +- `hooks/hermes.yaml` → `HermesShellLoader`(§16.4) +- manifest 内联 `hooks` → 与文件 merge + +### 6.6 MCP 挂载(P1) + +探测:manifest `mcpServers` → `.mcp.json` → `tools/mcp.json`。详见 §7.5。 + +### 6.7 辅助组件(P1) + +| 组件 | Loader 输出 | +|------|-------------| +| `settings.json` | `settings_patch` | +| `bin/` | `bin_paths` → `WorkspaceContext` | +| `userConfig` | `user_config_schema` + data 目录 | +| `assets/` + `interface` | `ui_metadata` | +| LSP / themes / monitors 等 | `unsupported`(detect-only) | + +--- + +## 7. 子资源加载语义 + +> 各节 capability id 与 §4.4 注册表一致。 + +### 7.1 Skills + +| 维度 | 语义 | +|------|------| +| Plugin `enabled=false` | 整个 Plugin 不加载;skills / commands 均不可见 | +| Skill 级 `disabled` | `SkillsConfigManager.disabled`;同名冲突时 plugin 来源优先 | +| Slash command | disabled skill 仍可通过 `/skill-name` 触发 | +| 根 `SKILL.md` | 无 `skills/` 时整包视为单 skill | +| 热重载 | `reload` → `SkillCatalog.reload_source` → `SkillRuntime.version++` | + +### 7.2 Commands + +| 维度 | 语义 | +|------|------| +| 与 skills 关系 | Claude 2026 统一为 skill 语义;MS-Agent P1 优先并入 `SkillCatalog` | +| Slash | `/plugin-id:cmd` 或 `/cmd`(项目内无冲突时) | +| frontmatter | `allowed-tools`、`argument-hint` 影响 Command 执行上下文(P2) | +| 注册 | `PluginLoader` 解析;优先 SkillCatalog,备选 `CommandRouter.register` | + +示例 frontmatter: + +```markdown +--- +name: deploy +description: Deploy current project +priority: 50 +--- +Run deployment using scripts in ${MS_AGENT_PLUGIN_ROOT}/scripts/ +``` + +### 7.3 Agents(Subagents) + +| 维度 | 语义 | +|------|------| +| 文件 | `agents/*.md`;frontmatter:`name`、`description`、`model`、`tools`、`disallowedTools`、`skills` | +| 运行时 | P1:`list_all` 展示;P2:`AgentDelegate` 按模板 spawn | +| Plugin disable | agent 从 registry 移除 | +| 安全 | plugin agent **不可**声明 `hooks` / `mcpServers` / `permissionMode`(对齐 Claude 限制) | + +### 7.4 Hooks + +| 维度 | 语义 | +|------|------| +| 格式 | Claude `hooks/hooks.json`(与 `.claude/settings.json` 的 `hooks` 段同构) | +| enabled_sources | 需在 `agent.yaml` / settings 中 `hooks.enabled_sources` 含 `plugin`(**默认不含**,避免静默执行第三方脚本) | +| plugin_data_dir | `~/.ms_agent/plugins/data//` 传入 `HookExecutionContext` | +| 安全 | command hook 是独立子进程;不经过 ToolManager 的 Permission + SafetyGuard。Plugin 不能 bypass Agent tool 调用权限,但 hook 脚本自身需按 hook 风险治理 | + +**推荐默认配置**(安全默认,CLI / Playground 均适用): + +```yaml +hooks: + enabled_sources: [native] + enabled_executors: [command] + fail_closed: false +``` + +Playground 可以在**用户显式确认**或企业内置 trusted profile 中开启 `plugin` source,例如 `enabled_sources: [native, plugin]`。UI 必须在首次启用含 hooks 的第三方 Plugin 时提示:`type=command` hook 可执行任意本地命令,风险不等同于一次受控 shell tool call。未确认前,即使 Plugin 已安装且 `enabled=true`,其 hooks 也不应加载。 + +### 7.5 MCP(`.mcp.json` / `tools/mcp.json`) + +Plugin MCP 配置示例(`.mcp.json` 惯例文件名): + +```json +{ + "mcpServers": { + "commit-helper": { + "command": "node", + "args": ["${MS_AGENT_PLUGIN_ROOT}/tools/server/index.js"], + "env": { + "PLUGIN_CONFIG": "${MS_AGENT_PLUGIN_DATA}/config.json" + } + } + } +} +``` + +处理规则: + +1. **Server 命名**:默认使用 manifest 中的 key;若与全局 MCP 冲突,加前缀 `plugin..` +2. **路径变量**:Loader 阶段展开 `${MS_AGENT_PLUGIN_ROOT}` / `${MS_AGENT_PLUGIN_DATA}` / Claude 别名 +3. **合并**:注入 `ConfigResolver.resolve_mcp()` 的 project 层,携带 `source: "plugin"`, `plugin_id` +4. **enabled**:随 Plugin enabled;MCP 级 `enabled: false` 可在 `tools/mcp.json` 内 per-server 设置 +5. **生命周期**:`MCPRuntime.reload_server` / `disable_server`;Plugin disable 时 disconnect 该 plugin 贡献的全部 server + +详见 `mcp_runtime_management.md` §Phase 3 第一条。 + +### 7.6 bin/ PATH 注入(P1) + +- enable:将 `/bin` 追加到 `LocalCodeExecutor` / shell 的 `PATH`(plugin 作用域) +- disable:移除;不影响系统 PATH +- 与 Claude「Bash tool 可裸调 bin 内命令」语义对齐 + +### 7.7 settings.json 补丁(P1) + +- 白名单键(一期):`agent`、`subagentStatusLine` 及 Playground 已支持字段 +- enable:merge 进 session/project resolved config;disable:revert 该 plugin 贡献的键 +- OpenClaw bundle 的 Claude `settings.json` 默认值同此路径 + +### 7.8 userConfig(P1) + +- 启用 plugin 时 UI 收集 manifest `userConfig` 字段 +- 持久化:`~/.ms_agent/plugins/data//config.json`;敏感项走 keychain / credentials 文件 +- 展开:`${user_config.KEY}`、`${CLAUDE_PLUGIN_OPTION_KEY}` 用于 MCP env、hook command、monitor command + +### 7.9 detect-only / unsupported 组件 + +| capability | 行为 | +|------------|------| +| `lsp` | `capabilities_status.lsp=detect_only`;文档引导 IDE 场景 | +| `output_styles` / `themes` | ignore 或 CLI P3 | +| `monitors` | P3;需 Monitor tool | +| `apps` | P2 OAuth | +| `channels` | P3 | +| `hooks_openclaw_internal` / Hermes Python | `unsupported` + `migration_hints` | + +--- + +## 8. PluginRuntime — 运行时管理 + +对标 `MCPRuntime` / `SkillRuntime`: + +```python +class PluginRuntime: + def __init__( + self, + config_manager: PluginConfigManager, + *, + skill_runtime: SkillRuntime | None = None, + hook_runtime_factory: Callable[..., HookRuntime] | None = None, + mcp_runtime: MCPRuntime | None = None, + ): ... + + async def start(self, project_path: str, session_id: str) -> None: + """加载全部 enabled plugin 并分发。""" + + def list_all(self) -> list[dict]: + """UI:id, name, version, enabled, capabilities, status, path""" + + async def toggle(self, plugin_id: str, enabled: bool, scope=...) -> None: + """写盘 + 增量 reload 子系统。""" + + async def reload(self, plugin_id: str) -> None: + """单 Plugin 热重载。""" + + async def install(self, source: str, scope=..., **opts) -> PluginManifest: ... + + async def uninstall(self, plugin_id: str, scope=..., purge: bool = False) -> None: ... +``` + +### 8.1 热重载矩阵 + +| capability | reload 行为 | +|------------|-------------| +| skills / commands | 移除旧 source → rescan → refresh system prompt | +| agents | 重建 `AgentRegistry` 中该 plugin 条目 | +| hooks | 替换 `HookRegistry` 中该 plugin 段 | +| mcp | `MCPRuntime.apply_config` diff → disconnect 移除的 server | +| settings | revert 旧补丁 → apply 新 `settings.json` | +| bin | 更新 PATH 快照 | +| user_config | 重读 data 目录;不自动弹表单 | +| ui_metadata | 刷新 Plugin 列表缓存 | + +### 8.2 状态机 + +```plaintext +installed → disabled → enabled → loading → ready + ↘ error (manifest / MCP connect / skill parse) +``` + +`error` 状态:Plugin 内 **其他子资源仍可用**(例如 hooks 失败但 skills 成功),UI 展示 per-capability 状态。 + +--- + +## 9. 环境变量与路径变量 + +### 9.1 脚本运行时(Hook command executor 已实现部分) + +| 变量 | 含义 | Claude 别名 | +|------|------|-------------| +| `MS_AGENT_PROJECT_DIR` | 项目根 | `CLAUDE_PROJECT_DIR` | +| `MS_AGENT_PLUGIN_ROOT` | 当前 plugin 根 | `CLAUDE_PLUGIN_ROOT` | +| `MS_AGENT_PLUGIN_DATA` | `~/.ms_agent/plugins/data//` | — | +| `MS_AGENT_SESSION_ID` | 当前 session | — | + +当前实现状态: + +- `build_hook_env()` 已支持上述变量; +- `PluginHooksLoader` 已在加载阶段把 `${MS_AGENT_PLUGIN_ROOT}` / `${CLAUDE_PLUGIN_ROOT}` 展开到 command 字符串; +- 执行阶段 `HookRuntime._ctx()` 尚未携带 `plugin_root` / `plugin_data_dir`,因此 `MS_AGENT_PLUGIN_ROOT` / `MS_AGENT_PLUGIN_DATA` 需要由 `PluginLoader` 给 **每个 handler** 标注来源后才能稳定注入;这属于 Phase 0 必做项,不能只在 registry 外层保存 plugin id。 + +### 9.2 配置/template 展开(Loader 阶段) + +在 `tools/mcp.json`、`.mcp.json`、hook `command`、MCP `env` 中展开: + +- `${MS_AGENT_PLUGIN_ROOT}` / `${CLAUDE_PLUGIN_ROOT}` +- `${MS_AGENT_PLUGIN_DATA}` / `${CLAUDE_PLUGIN_DATA}` +- `${MS_AGENT_PROJECT_DIR}` / `${CLAUDE_PROJECT_DIR}` +- `${user_config.KEY}` / `${CLAUDE_PLUGIN_OPTION_KEY}`(P1) + +### 9.3 待补全 + +`HookExecutionContext.plugin_data_dir` 当前未由 handler 元数据稳定携带;需在 `PluginLoader` → `HookRegistry` 的 handler 上标注 `plugin_id` / `plugin_root` / `plugin_data_dir`,executor 在执行每个 handler 时生成对应 ctx。建议给 `HookHandlerConfig` 增加只读 metadata(如 `source_plugin_id`, `source_plugin_root`, `source_plugin_data_dir`),避免从 command 字符串反推来源,也避免多个 Plugin hooks merge 后丢失来源。 + +--- + +## 10. 与 Command / Permission 的协作 + +### 10.1 Slash Command 与 Agents + +- Plugin **skills** 自动进入 `SkillCommandBridge` 拦截链(`/skill-id`) +- Plugin **commands**(P1)在 `SkillCommandBridge` **之后**注册 interceptor,命名空间 `/plugin-id:cmd` +- Plugin **agents**(P1 list / P2 execute)由 `AgentRegistry` 暴露;不经过 Slash 链,由 `AgentDelegate` 或 UI 子 agent 选择器触发 + +### 10.2 Permission + +Plugin 不改变权限模型: + +- MCP tools:`server---tool` 格式进入 whitelist/blacklist +- Agent 发起的 tool/MCP 调用:仍按 `SafetyGuard → PreToolUse → PermissionEnforcer → call_tool → PostToolUse` +- Hook command 脚本自身:由 `HookExecutor` 直接启动,**不经过** `ToolManager.single_call_tool()`,因此不会被 PermissionEnforcer 按 shell 命令逐条确认;只能通过 hook source 默认关闭、安装来源信任、timeout、fail_closed、执行器白名单等机制治理 +- Hook 脚本如果只是影响后续 Agent tool 调用(例如返回 `deny` / `updated_args`),后续 tool 调用仍按原权限链处理 + +### 10.3 Hooks enabled_sources 安全默认 + +CLI 和 Playground 默认保持 `enabled_sources: [native]`。Playground 可在用户显式确认某个含 hooks 的 Plugin 后开启 `plugin` source,或由受信任的企业 profile 预置开启。文档需警告:开启 plugin hooks = 允许已安装且 enabled 的 Plugin 执行任意 command hook。 + +--- + +## 11. 集成点与代码变更 + +### 11.1 LLMAgent 启动链 + +```plaintext +LLMAgent.__init__ + → ConfigResolver.resolve() # 含 plugins merge + → PluginRuntime.start() / load_all() # 新增;产出 skills/hooks/mcp/settings/bin 等贡献 + → build_hook_runtime(plugin_hook_registries) # 不再自行扫描 plugin 根目录 + → SkillCatalog.load_from_config() # 含 plugin skill sources + → MCPRuntime.start() / apply_config() # 含 plugin mcp servers + → prepare_tools() +``` + +实现上可先在 `prepare_tools()` 前 lazy 初始化 `PluginRuntime`,但必须保证: + +1. `PluginLoader.load_all()` 在 `build_hook_runtime()` 之前完成,才能注入 plugin hooks; +2. plugin MCP server 在 `MCPRuntime.start()` / `sync_tools()` 前进入 resolved MCP config; +3. plugin skills 在 `SkillCatalog.load_from_config()` 前变成 `SkillSource`; +4. legacy `_discover_plugin_roots()` 与新 `plugins.json` 不双重加载。 + +### 11.2 建议接线顺序 + +| 步骤 | 组件 | 变更 | +|------|------|------| +| 1 | `plugins/manifest.py`, `config_manager.py` | 新增 | +| 2 | `plugins/installer.py` | 本地安装 P0 | +| 3 | `plugins/loader.py` | 统一分发;迁移 factory 内 discovery | +| 4 | `skill/catalog.py` | plugin source 元数据 + 优先级 | +| 5 | `config/resolver.py` | `_merge_plugins` | +| 6 | `plugins/runtime.py` | 聚合 API | +| 7 | `mcp/runtime.py` | 消费 plugin 来源 server(本文 Phase 2;对齐 `mcp_runtime_management.md` Phase 3) | +| 8 | WebUI Session | `PluginRuntime` 注入 | + +### 11.3 不改动 + +- `ToolManager` 核心调用链(仅 MCP sync 回调扩展 metadata) +- `PermissionEnforcer` / `SafetyGuard` 规则 +- CLI `Config.from_task()` 直读 YAML 路径(Playground 才走 `ConfigResolver`) + +可新增但不改变语义的安全增强:在 `PluginInstaller` / `PluginRuntime.toggle()` 层做来源提示、签名 / hash 校验、hook executor 白名单和 UI 风险确认;不要把这些包装成 `PermissionEnforcer` 对 hook subprocess 的逐命令拦截。 + +--- + +## 12. API 与 UI 数据模型 + +### 12.1 REST API(Playground 后端) + +以下为新增 Playground 后端接口;当前代码库尚无 `/api/plugins` 路由,需随 `PluginRuntime` 一起落地。 + +| 方法 | 路径 | 说明 | +|------|------|------| +| GET | `/api/plugins` | `PluginRuntime.list_all()` | +| POST | `/api/plugins/install` | body: `{ "source": "...", "scope": "global\|project" }` | +| DELETE | `/api/plugins/{id}` | `?purge=true` | +| PATCH | `/api/plugins/{id}` | `{ "enabled": true/false }` | +| POST | `/api/plugins/{id}/reload` | 热重载 | + +### 12.2 list_all 响应示例 + +```json +{ + "plugins": [ + { + "plugin_id": "commit-helper", + "name": "commit-helper", + "version": "1.2.0", + "description": "Conventional commit assistant", + "enabled": true, + "scope": "global", + "path": "/Users/me/.ms_agent/plugins/commit-helper", + "capabilities": ["skills", "commands", "agents", "hooks"], + "status": "ready", + "capabilities_status": { + "skills": { "count": 1, "status": "ready" }, + "commands": { "count": 4, "status": "ready" }, + "agents": { "count": 1, "status": "ready" }, + "hooks": { "count": 4, "status": "ready" }, + "mcp": { "count": 0, "status": "skipped" }, + "settings": { "status": "skipped" }, + "bin": { "status": "skipped" }, + "user_config": { "status": "skipped" }, + "lsp": { "status": "detect_only" } + }, + "source": { "type": "github", "uri": "github://org/commit-helper@v1.2.0" }, + "installed_at": "2026-06-18T10:00:00Z" + } + ] +} +``` + +### 12.3 设置页联动 + +实验场「智能体设置 → Plugin」与「MCP / Skill / Hooks」并列: + +- 安装/卸载/开关 Plugin +- 展开查看 Plugin 内 skills 列表(跳转 Skill 开关页,只读展示 plugin 来源) +- Hooks 总开关仍在上级 `enabled_sources` + +--- + +## 13. 文件结构 + +```plaintext +ms_agent/plugins/ +├── __init__.py +├── types.py +├── manifest.py +├── registry.py +├── config_manager.py +├── installer.py +├── loader.py +└── runtime.py + +ms_agent/hooks/loaders/plugin.py # 保留;由 PluginLoader 调用 + +tests/plugins/ +├── test_manifest.py +├── test_config_manager.py +├── test_installer_local.py +├── test_loader_skills.py +├── test_loader_hooks.py +├── test_loader_mcp.py +└── fixtures/ + ├── hookify/ # 黄金测例 vendor 快照(附录 D) + └── sample-plugin/ # 最小 synthetic(开发期) +``` + +--- + +## 14. 兼容矩阵 + +| 来源 | Manifest | skills | commands | agents | hooks | mcp | 其他 | +|------|----------|--------|----------|--------|-------|-----|------| +| Claude Code | `.claude-plugin/` | ✅ | ✅ | ✅ | hooks.json | `.mcp.json` | bin, settings, LSP, monitors | +| Codex | `.codex-plugin/` | ✅ | — | — | hooks.json | `.mcp.json` | apps, assets, interface | +| OpenClaw bundle | 多格式 | ✅ | ✅ | 部分 | 部分 | ✅ | settings, LSP detect | +| Hermes 包 | yaml/config | ✅ | — | — | shell yaml | 独立 MCP | Python plugin ✗ | +| MS-Agent 原生 | `.ms-agent-plugin/` | ✅ | ✅ | ✅ | native | mcp.json | 全 §4.4 | + +--- + +## 15. 分阶段交付与验收 + +### Phase 0 — Manifest + 本地安装 + Skills + Hooks 迁移(P0) + +| 交付项 | 验收 | +|--------|------| +| `PluginManifest.parse` | 多组件 `scan_components()`;空包拒绝;`hookify` fixture | +| `PluginConfigManager` CRUD | global/project merge 单测 | +| `PluginInstaller.install(local)` | 复制到 `~/.ms_agent/plugins//`;写入 `format` + `manifest_path` | +| `PluginLoader` skills | `skills/` + 根 `SKILL.md` 分发为 `SkillSource`;不要求 commands/agents 执行 | +| `PluginLoader` hooks | 复用已实现 `PluginHooksLoader`,产出 `hook_registries` 注入 `build_hook_runtime()` | +| 迁移 `_discover_plugin_roots` | `build_hook_runtime` 不再自行扫描 plugin 根;legacy `config.plugins[]` 仅兼容且不双加载 | +| Hook env 元数据 | `HookHandlerConfig` 或等价 metadata 可携带 `plugin_id/root/data_dir`;command 执行期能拿到 `MS_AGENT_PLUGIN_ROOT` / `MS_AGENT_PLUGIN_DATA` | +| `PluginRuntime.list_all` / `toggle` | CLI 或单元测试;`capabilities_status` 含 §4.4 全键 | + +### Phase 1 — Commands / Agents 列表 + 远程安装 + 扩展组件基础(P1) + +| 交付项 | 验收 | +|--------|------| +| `PluginLoader` commands / agents(list) | `hookify`:4 commands + 1 agent 可见;agents 可先不执行 | +| `plugin_data_dir` 扩展使用 | hook 脚本可读 `MS_AGENT_PLUGIN_DATA`,userConfig / 状态文件写入同一 data 目录 | +| `commands/*.md` | `/hookify:help` 或并入 skill slash 可识别 | +| `github://` 安装 | 集成测试 mock git | +| `modelscope://` 安装 | 复用 skill 下载 | +| Playground API §12.1 | UI 可列表/开关 | +| 文档:enabled_sources 需含 `plugin` | 示例 agent.yaml | +| **`PluginFormatDetector`** | 识别 claude / openclaw / hermes / ms-agent | +| **OpenClaw 部分加载** | skills + Claude hooks.json + MCP 可用;handler.ts 标 `unsupported` | +| **Hermes shell in bundle** | `hooks/hermes.yaml` → `HermesShellLoader` | +| `bin/`、`settings.json`、`userConfig`(基础) | 扫描、状态展示、变量展开;PATH / 配置补丁按白名单落地 | + +### Phase 2 — MCP capability(优先级 P1,交付 Phase 2;对齐 `mcp_runtime_management.md` Phase 3) + +| 交付项 | 验收 | +|--------|------| +| `.mcp.json` + `tools/mcp.json` 解析 | server 出现在 `MCPRuntime.list_servers()` | +| Plugin disable 断开 MCP | 该 plugin 贡献的 server 从 LLM 工具列表消失 | +| 命名冲突前缀 | 单测 | +| 路径变量展开 | node args 含绝对路径 | + +### Phase 3 — Agents 执行 + 高级生态桥接(P2) + +| 交付项 | 验收 | +|--------|------| +| `agents/*.md` → `AgentDelegate` | 按模板 spawn 子 agent | +| `settings.json`、`userConfig` 高级闭环 | 补丁回滚、敏感配置存储、keychain / credentials 集成 | +| Codex / OpenClaw `HOOK.md` **metadata-only** 导入 | 文档 + UI 提示,不执行 handler.ts | +| OpenClaw `handler.ts` Node 子进程桥(可选) | 仅 side-effect 类 hook;不承诺 PreToolUse 等价 | +| Plugin 签名/校验 | 可选 minisign | + +--- + +## 16. 多生态兼容:OpenClaw 与 Hermes + +> 与 [`hooks-design.md`](hooks-design.md) §3.6 / §15 / 附录 B 对齐。本节回答:**OpenClaw bundle 检测难不难?能否与 Hermes 一并兼容?** + +### 16.1 结论(先说) + +| 能力 | 难度 | 能否并入 P1 | 说明 | +|------|------|-------------|------| +| **格式识别**(Claude / OpenClaw / Hermes / ms-agent) | 低 | ✅ 是 | `PluginFormatDetector`,无新运行时 | +| **子资源复用**(skills、MCP、Claude `hooks.json`) | 低 | ✅ 是 | 已有 Loader 直接吃 | +| **Hermes shell hooks**(包内或全局 config) | 低 | ✅ 是 | `HermesShellLoader` **已实现** | +| **OpenClaw HOOK.md 元数据 + 文档展示** | 低 | ✅ 是 | 解析 frontmatter,UI 列出 | +| **OpenClaw `handler.ts` 原样执行** | 高 | ❌ 否(P2 可选桥接) | TS 进程内 API,事件模型不同 | +| **Hermes Python plugin `register_hook()`** | 高 | ❌ 否 | Hermes 进程内 API | +| **Hermes Gateway hook**(`HOOK.yaml` + `handler.py`) | 中 | ❌ 否(P2 文档) | 仅 Gateway 生命周期 | + +**可以一并兼容的部分**:安装/发现/开关/Skills/MCP/Shell hooks —— 与 Claude Plugin 共用 `PluginLoader` 分发链。 +**不应承诺一并兼容的部分**:在 ms-agent 内嵌 OpenClaw Gateway 或 Hermes 的 **进程内 hook 虚拟机**。 + +### 16.2 为何 OpenClaw 曾被标 P2 + +OpenClaw 实际有 **两套 hook**,与 ms-agent(对齐 Claude Code)的 hook **不是同一类产品**: + +```plaintext +OpenClaw 内部 hook(HOOK.md + handler.ts) + 事件:command:new, gateway:startup, message:received, agent:bootstrap ... + 模型:Gateway 侧效应 / 消息通道 / 会话生命周期 + 执行:TypeScript 进程内,handler 接收 OpenClaw event 对象 + +OpenClaw Typed Plugin Hook(api.on(...)) + 事件:before_tool_call, before_agent_reply, session_end ... + 模型:有序中间件 / 策略门 + 执行:TS 进程内 Plugin SDK + +ms-agent / Claude Code Canonical Hook + 事件:PreToolUse, PostToolUse, UserPromptSubmit, Stop ... + 模型:Agent 工具管线拦截 + 执行:子进程 command + stdin JSON(或 P2 http/prompt) +``` + +OpenClaw 官方文档也明确:**工具拦截、策略门**应走 Typed Plugin Hook,Internal Hook 适合 `/new` 记日志、gateway 启动跑 `BOOT.md` 等 **粗粒度自动化**——与 ms-agent `PreToolUse` 语义不对等。 + +因此「OpenClaw 兼容」若理解为 **跑通全部 handler.ts**,难度高且产品边界模糊;若理解为 **识别 bundle + 加载其中 Claude 兼容部分**,难度低,**应与 P1 Plugin 模块一起做**。 + +### 16.3 OpenClaw bundle:P1 可做什么 + +**识别特征**(`PluginFormat.OPENCLAW`,优先级低于显式 `plugin.json`): + +| 信号 | 路径 | +|------|------| +| npm hook pack | `package.json` → `"openclaw": { "hooks": ["hooks/foo"] }` 或 `"openclaw.hooks"` | +| HOOK 目录 | `hooks//HOOK.md` + `handler.ts` | +| MCP | `openclaw.json` → `mcpServers`(或合并进宿主 config 的 MCP 段) | +| Skills | workspace `skills/` 或包内 `skills/` | + +**P1 加载策略**(`OpenClawBundleAdapter`): + +```python +def adapt_openclaw(root: Path) -> PluginLoadResult: + result = PluginLoadResult() + # 1. skills/ — 同 Claude,SkillCatalog + # 2. hooks/hooks.json — 若存在,PluginHooksLoader(Claude 社区常双发) + # 3. openclaw.json mcpServers — 转 tools/mcp.json 语义进 MCPRuntime + # 4. HOOK.md 目录 — 仅 parse frontmatter → capabilities_status.hooks.openclaw_internal + # handler.ts 标记 unsupported,UI 展示「需 OpenClaw Gateway 或导出 shell 版」 + return result +``` + +**与 OpenClaw 自身行为一致**:OpenClaw 对 Claude `hooks.json` 也是 **detect-only、不执行**;ms-agent 反而 **更兼容**(Claude command hook 可直接跑)。 + +**P2 可选**(非 P1 承诺):对 `handler.ts` 提供 **Node 子进程桥** —— 将 Canonical 事件 **近似** 映射为 OpenClaw event JSON,仅建议 side-effect 类 hook(如 command-logger);**不**用于 PreToolUse 策略门。 + +### 16.4 Hermes:已有什么、Plugin 层补什么 + +Hermes 三套 hook(详见 `hooks-design.md` 附录 B): + +| 类型 | ms-agent 现状 | Plugin 包内 | +|------|---------------|-------------| +| **Shell hooks** | ✅ `HermesShellLoader` + `enabled_sources: hermes` 读 `~/.hermes/config.yaml` | P1:包内 `hooks/hermes.yaml` 或 `hooks/config.yaml` 的 `hooks:` 段 → 同一 Loader | +| **Python plugin hook** | ❌ 不执行 | 安装时检测 `register(ctx)` / `pyproject` hermes 段 → `unsupported` + 迁移文档 | +| **Gateway hook** | ❌ 不执行 | 检测 `HOOK.yaml` → 提示仅 Gateway 可用 | + +**Hermes 与 Plugin 一并兼容的成本很低**,因为: + +1. Shell loader **已落地**(`ms_agent/hooks/loaders/hermes.py`),Plugin 只需多一个发现路径。 +2. Hermes shell hook 脚本与 Claude command hook **共用** `HookExecutor` + `ResponseAdapter`(`decision:block` / `action:block` 已归一化)。 +3. 不需要 Hermes 运行时即可跑 **包内的 shell 脚本**。 + +当前 `HermesShellLoader.load_file()` 只接收 `path` + `project_path`,尚未支持 plugin root/data 路径变量展开,也未接收 `enabled_executors`。因此 P1 的包内 Hermes adapter 需要补齐与 `PluginHooksLoader` 等价的能力:展开 `${MS_AGENT_PLUGIN_ROOT}` / `${CLAUDE_PLUGIN_ROOT}` / `${MS_AGENT_PLUGIN_DATA}`,并给 handler 标注 `source_plugin_*` metadata。 + +**PluginLoader 扩展**: + +```python +# hooks/ 目录多格式探测(按优先级,不互斥) +if (root / "hooks" / "hooks.json").is_file(): + merge(PluginHooksLoader...) # Claude / Codex plugin +if (root / "hooks" / "hermes.yaml").is_file(): + merge(HermesShellLoader.load_file(...)) +elif (root / "hooks" / "config.yaml").is_file(): + merge(HermesShellLoader.load_file(...)) # 仅 parse hooks: 段 +``` + +全局 Hermes 配置(`~/.hermes/config.yaml`)**不经过 Plugin 模块**,仍由 `build_hook_runtime` 在 `enabled_sources` 含 `hermes` 时加载——与 Plugin 正交、可叠加。 + +### 16.5 统一:`PluginFormatDetector` + +安装/扫描时自动识别,写入 `PluginManifest.format` 与 `capabilities_status`: + +```python +class PluginFormat(str, Enum): + MS_AGENT = "ms-agent" # plugin.json + ms_agent 段 + CLAUDE = "claude" # plugin.json(无 ms_agent)或纯 Claude 布局 + OPENCLAW = "openclaw" # package.json openclaw.* 或 HOOK pack + HERMES = "hermes" # 以 Hermes hooks yaml 为主,无 plugin.json + MIXED = "mixed" # 多格式并存(常见:Claude plugin + OpenClaw HOOK pack) +``` + +`list_all()` 响应增加: + +```json +{ + "format": "mixed", + "compatibility": { + "skills": "ready", + "hooks_claude": "ready", + "hooks_hermes_shell": "ready", + "hooks_openclaw_internal": "unsupported", + "hooks_hermes_python": "unsupported", + "tools_mcp": "ready" + }, + "migration_hints": [ + "3 OpenClaw internal hooks (handler.ts) skipped — export shell equivalents or run under OpenClaw Gateway" + ] +} +``` + +### 16.6 与「一并兼容」的产品表述 + +对用户可承诺: + +- 安装 OpenClaw hook pack / Hermes 技能包时,**Skills、MCP、Claude hooks.json、Hermes shell hooks 自动可用**。 +- UI 明确列出 **未加载** 的进程内 hook 及原因,避免静默失败。 +- 同一 `plugins.json` 管理 enable/disable,不区分来源框架。 + +不可承诺(除非远期单独立项 Node/Hermes 嵌入式运行时): + +- OpenClaw `handler.ts` / Typed `api.on()` 零改动运行。 +- Hermes Python `ctx.register_hook()` 零改动运行。 + +### 16.7 实现增量(并入 Phase 1) + +| 文件 | 变更 | +|------|------| +| `plugins/format_detector.py` | 新增:Claude / OpenClaw / Hermes 识别 | +| `plugins/adapters/openclaw.py` | 新增:部分加载 + unsupported 汇总 | +| `plugins/adapters/hermes.py` | 新增:包内 yaml hooks 路径 | +| `plugins/loader.py` | 调用各 adapter,合并 `PluginLoadResult` | +| `hooks/factory.py` | 可选:global hermes 与 plugin hermes 去重说明 | + +验收:fixture 含 `hooks/hooks.json` + `hooks/hermes.yaml` + `hooks/foo/HOOK.md` 的 mixed 包,安装后 Claude + Hermes shell 生效,OpenClaw internal 出现在 `migration_hints`。 + +--- + +## 17. 风险与对策 + +| 风险 | 对策 | +|------|------| +| 恶意 Plugin hook 执行任意命令 | 默认不启用 `plugin` source;Playground 展示明确风险说明;限制 enabled_executors;timeout / fail_closed;安装来源校验。注意 hook command 子进程不经过 SafetyGuard | +| skill_id 与内置 skill 冲突 | 加载顺序 + UI 标记来源;warning 日志 | +| MCP server 命名冲突 | `plugin..` 前缀 | +| GitHub 安装供应链 | 固定 commit / tag;`resolved_sha` 写入 plugins.json | P2 可选 hash 锁定 | +| 多 manifest 同目录 | 安装时 `AmbiguousPluginManifest`;要求 `--format` | 运行时读锁定 `manifest_path` | +| `--link` 与 Claude 共享目录 | 文档警告;默认 copy 隔离 | 产品默认 copy | +| Plugin 体积过大 | 安装前 size 检查;git shallow clone | +| 热重载竞态 | `PluginRuntime._reload_lock`;与 `MCPRuntime._sync_lock` 同级 | +| 密钥写入 plugin 配置 | 导出时脱敏;复用 MCP `Env` 替换 | + +--- + +## 18. 测试策略 + +```python +# tests/plugins/test_loader_hooks.py +def test_plugin_hooks_merge_with_native(): + """plugin PreToolUse 与 native hooks 合并;matcher 生效。""" + +# tests/plugins/test_installer_local.py +def test_install_idempotent(): + """重复 install 同 version 不重复复制。""" + +# tests/plugins/test_runtime_toggle.py +async def test_disable_plugin_removes_skills_from_catalog(): + """toggle enabled=false 后 SkillRuntime.list_all 不可见。""" + +# tests/plugins/test_loader_mcp.py +async def test_plugin_mcp_tools_sync(): + """安装带 tools/mcp.json 的 plugin 后 ToolManager 可见 server---tool。""" +``` + +**E2E 黄金测例**:见 [附录 D — hookify](#附录-d黄金测例--hookify)(官方社区 Plugin,覆盖 manifest / skills / commands / agents / hooks)。 + +--- + +## 19. 社区 Plugin 组件全景(调研) + +> 来源:Claude Code [Plugins reference](https://code.claude.com/docs/en/plugins-reference)、Codex [Build plugins](https://developers.openai.com/codex/plugins/build)、OpenClaw [Plugin CLI](https://documentation.openclaw.ai/cli/plugins)、Hermes 架构文档与 `hooks-design.md` 附录 B。 +> 目的:避免 F9 只覆盖 skill/hook/mcp 而遗漏社区包中高频出现的其他配置项。 + +### 19.1 组件总表 + +| 组件 | 典型路径 / manifest 字段 | Claude | Codex | OpenClaw | Hermes | MS-Agent 策略 | 优先级 | +|------|---------------------------|--------|-------|----------|--------|---------------|--------| +| **Skills** | `skills/*/SKILL.md`、根 `SKILL.md` | ✅ | ✅ | ✅ bundle | ✅ 目录 | → `SkillCatalog` | **P0** | +| **Commands**(legacy) | `commands/*.md` | ✅ | — | ✅ command-skills | — | → Skill 或 `CommandRouter` | **P1** | +| **Agents / Subagents** | `agents/*.md` | ✅ | — | 部分 | — | → `AgentDelegate` / 子 agent 模板(F1.2 扩展) | **P1** | +| **Hooks (shell)** | `hooks/hooks.json` | ✅ | ✅ | 部分 | ✅ yaml | → `HookRegistry` | **P0/P1** | +| **Hooks (prompt/http/agent/mcp_tool)** | hooks.json `type` 字段 | ✅ | 部分 | ✗ | ✗ | P2 Executor(见 hooks-design §17) | P2 | +| **MCP servers** | `.mcp.json`、`mcpServers` | ✅ | ✅ | ✅ | MCP 独立 | → `MCPRuntime` | **P1** | +| **App Connectors** | `.app.json`、`apps` | — | ✅ | — | — | P2 OAuth 后端 + 凭证存储 | P2 | +| **LSP servers** | `.lsp.json`、`lspServers` | ✅ | — | ✅ bundle 默认 | — | P3 或 detect-only(Playground 非 IDE) | P3 | +| **Output styles** | `output-styles/` | ✅ | — | — | — | P3 / ignore(纯 UI) | P3 | +| **Themes** | `themes/`、`experimental.themes` | ✅ | — | — | — | ignore(CLI/TUI 可选) | P3 | +| **Monitors** | `monitors/monitors.json` | ✅ exp | — | — | — | P3 对标 Monitor tool | P3 | +| **bin/** | 可执行文件 | ✅ | — | — | — | P1:注入 `code_executor` PATH 或 document | **P1** | +| **settings.json** | plugin 根 | ✅ | — | ✅ Claude defaults | — | P1:merge 进 project/global settings 子集 | **P1** | +| **scripts/** | 辅助脚本 | 引用 | 引用 | 引用 | 引用 | 不单独加载;随 hook/MCP 路径展开 | — | +| **assets/** | icon/logo/screenshots | — | ✅ `interface.*` | — | — | UI 元数据 only | P1 UI | +| **userConfig** | manifest 字段 | ✅ | — | plugin config | — | P1:安装/启用时表单 → `pluginConfigs` | **P1** | +| **dependencies** | manifest 数组 | ✅ | — | — | — | P1:安装时解析依赖链 | **P1** | +| **channels** | manifest 数组 | ✅ | — | — | — | P3(MCP 消息注入通道) | P3 | +| **defaultEnabled** | manifest bool | ✅ | — | — | — | 读入 `plugins.json` 默认 enabled | P1 | +| **OpenClaw internal hooks** | `HOOK.md`+`handler.ts` | detect | — | ✅ | — | unsupported(§16) | detect | +| **OpenClaw native plugin** | `openclaw.plugin.json`+TS | — | — | ✅ | — | unsupported(进程内 SDK) | detect | +| **Hermes Python plugin** | `register(ctx)` | — | — | — | ✅ | unsupported | detect | +| **Hermes Gateway hook** | `HOOK.yaml`+`handler.py` | — | — | — | ✅ | unsupported | detect | +| **Marketplace** | `marketplace.json` | ✅ | ✅ | ClawHub | — | 安装源,非 plugin 内容(§19.3) | P1 | +| **Rules / CLAUDE.md 片段** | `rules/`、包内 md | 部分 | — | — | — | P2:merge 进 personalization 或 project instruction | P2 | + +### 19.2 原设计已覆盖 vs 遗漏 + +**已覆盖(v0.1 设计层)**:skills、hooks/hooks.json、tools→MCP、commands(P2)、环境变量桥接、`plugins.json` CRUD。代码现状仅 hooks 局部落地;环境变量为 executor 预留,执行期来源 metadata 尚未接通。 + +**本次调研补充的遗漏项**(按 MS-Agent 价值排序): + +#### A. 高价值 — 建议并入 P1 + +1. **`.claude-plugin/` / `.codex-plugin/` manifest 路径** + 社区包几乎不用根目录 `plugin.json`;Detector 必须识别子目录 manifest。 + +2. **`.mcp.json` 文件名**(非 `tools/mcp.json`) + Loader 应同时探测:`.mcp.json`、`tools/mcp.json`、manifest 内联 `mcpServers`。 + +3. **`agents/` 子 agent 定义** + Claude 社区大量 plugin 通过 agents 提供专用 reviewer/planner。 + MS-Agent 映射:Playground F1.2 子 agent 模板 + `AgentDelegate` / `capabilities` 包装;frontmatter 字段 `model`、`tools`、`disallowedTools`、`skills` 写入 resolved agent config。 + +4. **`commands/` 遗留 slash** + 与 `skills/` 统一为 Skill 加载(Claude 2026 已合并语义);flat `.md` 走 `SkillLoader` 单文件模式或 `CommandRouter`。 + +5. **`bin/` PATH 注入** + Claude:启用 plugin 时把 `bin/` 加入 Bash tool 的 PATH。 + MS-Agent:`LocalCodeExecutor` / `WorkspaceContext` 扩展 `plugin_bin_paths`;disable 时移除。 + +6. **`settings.json` 默认配置** + Claude 仅支持 `agent`、`subagentStatusLine` 等键;OpenClaw bundle 还支持 Claude `settings.json` 默认值。 + MS-Agent:merge 到 session/project 的 agent.yaml 补丁(enabled 时 apply,disable 时 revert)。 + +7. **`userConfig` + `${user_config.*}` / `CLAUDE_PLUGIN_OPTION_*`** + 启用 plugin 时 UI 表单收集;写入 `~/.ms_agent/plugins/data//config.json`;展开到 MCP/hook/monitor 命令字符串。 + +8. **`dependencies` 插件依赖** + 安装 `formatter` 时自动安装 `secrets-vault@~2.1.0`;`PluginInstaller` 拓扑排序。 + +9. **根目录单文件 `SKILL.md`** + 无 `skills/` 时整包即一个 skill(marketplace 安装常见)。 + +10. **Codex `interface` / `assets/`** + Playground Plugin 列表展示 displayName、icon、screenshots;纯 UI,不进入 Runtime。 + +#### B. 中价值 — P2 + +11. **`.app.json` App Connectors(Codex)** + Slack/GitHub/Notion OAuth 连接器;需 Playground 后端 OAuth 跳转(`mcp_runtime_management.md` Phase 3 认证项)。 + +12. **Hook 扩展类型**:`prompt`、`http`、`agent`、`mcp_tool` + 已在 `hooks-design.md` §17;Plugin 内 hooks.json 常见 prompt 型策略 hook。 + +13. **`rules/` / 包内 instruction 片段** + 映射到 `PersonalizationInjector` 或 project `.ms-agent/config.yaml` patch。 + +14. **Skill frontmatter 扩展**(Claude 2026 统一 skill/command) + `allowed-tools`、`context: fork`、`agent`、`model`、`paths`、`disable-model-invocation` — 影响 SkillRuntime 与 AgentDelegate 行为。 + +#### C. 低价值 / 非 Playground 核心 — P3 或 detect-only + +15. **`.lsp.json`** — IDE 代码智能;OpenClaw 已支持 bundle 默认,MS-Agent CLI 可 detect + 文档说明。 +16. **`output-styles/`、`themes/`** — 纯终端/UI 呈现。 +17. **`monitors/`** — Claude 后台监视 + 通知;需 Monitor tool 对标。 +18. **`channels`** — MCP 驱动的消息注入通道。 +19. **OpenClaw/Hermes 进程内扩展** — 仅 detect(§16)。 + +### 19.3 Marketplace 与 Plugin 的边界 + +社区分发常通过 **marketplace.json**(非 plugin 内容),MS-Agent 安装器需支持但不应混入 `PluginLoader`: + +| 文件 | 作用 | MS-Agent 模块 | +|------|------|---------------| +| `marketplace.json` / `.agents/plugins/marketplace.json` | 插件目录、source.path、policy | `PluginInstaller` 索引源 | +| `.claude-plugin/marketplace.json` | Claude 官方/团队 marketplace | 同上 | +| Codex `~/.agents/plugins/marketplace.json` | 个人/仓库 curated list | 同上 | +| OpenClaw ClawHub / `plugins/installs.json` | 安装记录 + registry | 参考 `PluginConfigManager` 设计 | + +Marketplace entry 字段:`source`(local/git/url)、`policy.installation`、`policy.authentication`、`category`、`interface.displayName` — 用于 UI,不进入 Agent 运行时。 + +### 19.4 PluginLoadResult 扩展(修订) + +```python +@dataclass(frozen=True) +class PluginHookContribution: + plugin_id: str + registry: HookRegistry + plugin_root: Path + plugin_data_dir: Path + +@dataclass +class PluginLoadResult: + skill_sources: list[SkillSource] + hook_registries: list[PluginHookContribution] + mcp_servers: dict[str, dict] + command_defs: list[CommandDef] + agent_defs: list[AgentDef] # 新增:agents/*.md + settings_patch: dict[str, Any] # 新增:settings.json 片段 + bin_paths: list[Path] # 新增:bin/ + user_config_schema: dict[str, Any] # 新增:manifest userConfig + ui_metadata: dict[str, Any] # 新增:interface/assets + unsupported: list[UnsupportedCapability] # 新增:lsp/themes/monitors/... +``` + +### 19.5 修订后的分阶段交付(补充) + +在 §15 基础上补充分阶段验收,避免把 MCP / agents 执行混入 Phase 1: + +| Phase | 交付项 | 验收 | +|-------|--------|------| +| 0 | Manifest 多路径 | `.claude-plugin` / `.codex-plugin` / 根 `plugin.json`;多 manifest 冲突需显式 `--format` | +| 0 | 根 `SKILL.md` 单 skill | 安装后 catalog 可见 | +| 0 | Plugin hooks 迁移 | `PluginLoader` 产出 hook registry,`build_hook_runtime()` 不再自行扫描并双加载 | +| 1 | `commands/*.md` → skill/command | 至少一种路径可用 | +| 1 | `agents/*.md` 解析 | `list_all` 展示;P1 可先不执行 delegate | +| 1 | `bin/` PATH | shell 工具可调用 plugin 内命令 | +| 1 | `userConfig` 表单 + 变量展开 | `${user_config.key}` 在 hook command 中生效 | +| 1 | `dependencies` 安装顺序 | 依赖 plugin 先于主包 install | +| 1 | `compatibility` 完整报告 | §16 + §19 全部组件状态 | +| 1 | **黄金测例 hookify E2E** | 附录 D 中非 MCP 断言通过 | +| 2 | `.mcp.json` + `tools/mcp.json` 双路径 | `example-plugin` / synthetic fixture 单测,server 出现在 `MCPRuntime.list_servers()` | + +--- + +## 附录 D:黄金测例 — hookify + +> **选定结论**:MS-Agent Plugin 体系的**最终集成测例**采用 Anthropic 官方社区目录中的 [**hookify**](https://github.com/anthropics/claude-plugins-official/tree/main/plugins/hookify)(`hookify@claude-plugins-official`)。 +> 选型时间:2026-06-18;对照 §19 组件全景与真实社区分发路径。 + +### D.1 为何选 hookify(而非 demo 包或其它 official plugin) + +| 候选 | 来源 | 覆盖组件 | 不选原因 | +|------|------|----------|----------| +| `yasun1/claude-code-plugin-demo` → `my-first-plugin` | 社区 demo | 声称 5 类 | **非标准**:无 `hooks/hooks.json`(仅散落 `.sh`);MCP 在 `mcp-server/` 而非 `.mcp.json`;agents 用 `AGENT.md` 非 `*.md` | +| `example-plugin` | official | skills + commands + `.mcp.json` | 过薄;无 hooks/agents;MCP 仅为 HTTP 占位 | +| `feature-dev` | official | agents + commands | 无 hooks、无 MCP | +| `security-guidance` | official | hooks(复杂 Python) | 仅 hooks 单组件;依赖多、CI 重 | +| **`hookify`** | **official community** | **manifest + skills + commands + agents + hooks** | ✅ **选用** | + +**hookify** 优势: + +1. **真实分发路径**:`anthropics/claude-plugins-official` 社区 marketplace,与 Playground「安装社区 Plugin」一致。 +2. **标准 Claude 布局**:`.claude-plugin/plugin.json` + 约定目录,非教学用非标结构。 +3. **`hooks/hooks.json` 含包装层**(`{"hooks": {...}}`),与 `PluginHooksLoader` 路径一致。 +4. **四类 Canonical 事件**:`PreToolUse`、`PostToolUse`、`Stop`、`UserPromptSubmit`。 +5. **`${CLAUDE_PLUGIN_ROOT}`** 出现在 command 字符串,可验收路径展开与环境变量桥接。 +6. **多组件并存**:同包内 skills / commands / agents / hooks,一次 install 测分发链。 +7. **体积适中**:无 LSP/bin/userConfig,CI 可跑;比 `security-guidance` 轻、比 `example-plugin` 全。 + +**已知不覆盖**(由同仓库 **`example-plugin`** 作 MCP 补充冒烟,非黄金主测例): + +- `.mcp.json` → `example-plugin`(HTTP MCP 占位) +- `bin/`、`settings.json`、`userConfig`、LSP → 后续 synthetic fixture + +### D.2 包结构与安装源 + +```plaintext +anthropics/claude-plugins-official/plugins/hookify/ +├── .claude-plugin/ +│ └── plugin.json # name: hookify +├── hooks/ +│ ├── hooks.json # plugin 包装格式 + 4 事件 +│ ├── pretooluse.py +│ ├── posttooluse.py +│ ├── stop.py +│ └── userpromptsubmit.py +├── skills/ +│ └── writing-rules/ +│ └── SKILL.md # skill_id: writing-hookify-rules +├── commands/ +│ ├── hookify.md +│ ├── configure.md +│ ├── help.md +│ └── list.md +├── agents/ +│ └── conversation-analyzer.md +├── core/ # hook 运行时依赖(Python 模块) +├── matchers/ +└── examples/ # 示例 .local.md 规则 +``` + +**安装 URI(测试 / Playground)**: + +```text +github://anthropics/claude-plugins-official@main#plugins/hookify +``` + +或 marketplace 本地路径(开发): + +```text +file:///path/to/claude-plugins-official/plugins/hookify +``` + +`plugins.json` 记录示例: + +```json +{ + "id": "hookify", + "enabled": true, + "managed_by": "ms-agent", + "format": "claude", + "manifest_path": ".claude-plugin/plugin.json", + "source": { + "type": "github", + "uri": "github://anthropics/claude-plugins-official@main#plugins/hookify", + "resolved_sha": "" + }, + "path": "~/.ms_agent/plugins/hookify", + "installed_at": "2026-06-18T12:00:00Z" +} +``` + +### D.3 组件 → MS-Agent 验收映射 + +| hookify 组件 | 预期 MS-Agent 行为 | 验收方式 | +|--------------|-------------------|----------| +| `.claude-plugin/plugin.json` | `PluginManifest.parse` → `plugin_id=hookify` | 单元测试 | +| `skills/writing-rules/SKILL.md` | `SkillCatalog` 含 `writing-hookify-rules` | `SkillRuntime.list_all()` | +| `commands/*.md` (×4) | 注册为 slash 或转 skill;`capabilities_status.commands=ready` | `/hookify` 可识别 | +| `agents/conversation-analyzer.md` | `list_all` 展示;`capabilities_status.agents=ready`;P1 可不执行 delegate | metadata 断言 | +| `hooks/hooks.json` | merge 进 `HookRegistry`(`enabled_sources` 含 `plugin`) | registry 含 4 事件 | +| `${CLAUDE_PLUGIN_ROOT}` | 展开为安装绝对路径 | hook command 不含未展开变量 | +| `MS_AGENT_PLUGIN_DATA` | pretooluse.py 可写规则状态目录 | 环境变量单测 | +| `core/`、`matchers/` | 不单独加载;随 Python hook 引用 | 无 assert | +| — 无 `.mcp.json` | `capabilities_status.mcp=skipped` | `list_all` 报告 | + +### D.4 E2E 场景(最终测例脚本) + +```python +# tests/plugins/test_golden_hookify.py — 目标文件(实现 Phase 1 后启用) + +HOOKIFY_URI = "github://anthropics/claude-plugins-official@main#plugins/hookify" + +async def test_golden_hookify_install_and_manifest(): + manifest = await PluginRuntime.install(HOOKIFY_URI, scope="global") + assert manifest.plugin_id == "hookify" + assert manifest.format in ("claude", "mixed") + assert "hooks" in manifest.capabilities + +async def test_golden_hookify_skills_loaded(): + runtime = await start_session_with_plugins(["hookify"]) + skills = runtime.skill_runtime.list_all() + ids = {s["skill_id"] for s in skills} + assert "writing-hookify-rules" in ids + assert any(s.get("plugin_id") == "hookify" for s in skills) + +async def test_golden_hookify_hooks_registered(): + load_result = PluginLoader.load_all([manifest_for("hookify")], ctx_for_test()) + hr = build_hook_runtime( + config_with_plugin_source_enabled(), + session_id="t1", + plugin_hook_registries=load_result.hook_registries, + ) + assert not hr.registry.is_empty + for event in ("PreToolUse", "PostToolUse", "Stop", "UserPromptSubmit"): + assert event in hr.registry._index + +async def test_golden_hookify_pretooluse_runs(): + """安装后执行一次 read_file;pretooluse.py 应被调用(exit 0,不阻断)。""" + ... + +async def test_golden_hookify_slash_command(): + router = build_command_router(skill_catalog=...) + result = await router.dispatch(parse("/hookify")) + assert result is not None # MESSAGE 或 SUBMIT_PROMPT + +async def test_golden_hookify_toggle_disable(): + await runtime.toggle("hookify", enabled=False) + assert "writing-hookify-rules" not in visible_skill_ids() + assert hook_registry_for_plugin("hookify").is_empty +``` + +### D.5 Fixture 策略 + +| 方式 | 路径 | 用途 | +|------|------|------| +| **CI 推荐** | `git sparse-checkout` 仅 `plugins/hookify` | 网络安装集成测 | +| **离线单测** | `tests/plugins/fixtures/hookify/`(vendor 快照,pin commit SHA) | 无网 / 确定性回归 | +| **MCP 补充** | 同仓库 `plugins/example-plugin`(仅 `.mcp.json` 冒烟) | 不并入黄金主流程 | + +Vendor 命令(维护者): + +```bash +git clone --depth 1 --filter=blob:none --sparse \ + https://github.com/anthropics/claude-plugins-official.git /tmp/cc-plugins +cd /tmp/cc-plugins && git sparse-checkout set plugins/hookify +cp -R plugins/hookify tests/plugins/fixtures/hookify +# 在 fixtures/hookify/VENDOR_SHA 记录 commit SHA +``` + +### D.6 与 §15 交付的关系 + +- **Phase 0 完成标准**:`PluginManifest.parse`、skills 加载、plugin hooks registry 注入通过 hookify 本地 fixture;不要求远程安装和 slash command。 +- **Phase 1 完成标准**:附录 D.4 中 hookify 非 MCP 场景全绿(含 github 安装、commands、agents list、toggle)。 +- **Phase 2 MCP**:另跑 `example-plugin`,不阻塞 hookify 黄金测例。 + +--- + +## 附录 A:plugins.json 示例 + +**~/.ms_agent/plugins.json** + +```json +{ + "plugins": [ + { + "id": "commit-helper", + "enabled": true, + "managed_by": "ms-agent", + "format": "claude", + "manifest_path": ".claude-plugin/plugin.json", + "source": { + "type": "github", + "uri": "github://org/commit-helper@v1.2.0", + "resolved_sha": "abc123def456..." + }, + "path": "/Users/me/.ms_agent/plugins/commit-helper", + "installed_at": "2026-06-18T10:00:00Z" + }, + { + "id": "local-linter", + "enabled": false, + "source": { + "type": "local", + "uri": "/path/to/local-linter" + }, + "path": "/Users/me/.ms_agent/plugins/local-linter", + "installed_at": "2026-06-17T08:00:00Z" + } + ] +} +``` + +**项目级 `/.ms-agent/plugins.json`**:结构相同;同 id 覆盖 global 的 `enabled`。 + +--- + +## 附录 B:plugin.json 字段对照(Claude Code / Codex) + +| 字段 | Claude | Codex | MS-Agent 处理 | +|------|--------|-------|---------------| +| `name` | ✅ 必填 | ✅ | `plugin_id` | +| `version` | 可选 | ✅ | 升级检测 | +| `description` | ✅ | ✅ | UI | +| `author` / `homepage` / `repository` / `license` / `keywords` | ✅ | ✅ | UI 元数据 | +| `displayName` | ✅ | — | UI(`interface.displayName`) | +| `skills` | 路径 | 路径 | → SkillCatalog | +| `commands` | 路径 | — | → Skill / Command | +| `agents` | 路径 | — | → AgentDef(P1) | +| `hooks` | 路径/inline | 路径/inline | → HookRegistry | +| `mcpServers` | 路径/inline | 路径/inline | → MCPRuntime | +| `apps` | — | 路径 | → AppConnector(P2) | +| `lspServers` | 路径/inline | — | detect-only(P3) | +| `outputStyles` | 路径 | — | ignore P3 | +| `experimental.themes` | 路径 | — | ignore | +| `experimental.monitors` | 路径 | — | P3 | +| `userConfig` | ✅ | — | 启用表单 + `${user_config.*}` | +| `dependencies` | ✅ | — | 安装拓扑 | +| `defaultEnabled` | ✅ | — | plugins.json 默认 | +| `channels` | ✅ | — | P3 | +| `interface` | — | ✅ | UI only | +| `ms_agent.*` | — | — | MS-Agent 扩展 | + +Claude 未在 manifest 声明路径时,使用 **约定目录**(§4.1)。Codex 额外约定:无 manifest `hooks` 字段时自动读 `hooks/hooks.json`。 + +--- + +## 附录 C:跨文档约定 + +| 主题 | 约定 | +|------|------| +| 工具名分隔符 | `---`(`permission-design.md` / `hooks-design.md`) | +| MCP server 合并 | `ConfigResolver.resolve_mcp()`(`mcp_runtime_management.md` §5) | +| Skill disabled vs Plugin disabled | Plugin off = 全部子资源 off;Skill off = 仅 prompt 注入 off,`/` 仍可触发 | +| Hook source 开关 | `hooks.enabled_sources` 含 `plugin` 才加载 Plugin hooks | +| 工作空间元数据目录 | `.ms-agent/plugins/`、`plugins.json` 与 `mcp.json` 同级 | +| WebUI 迁移 | 与 MCP §10.1 相同三阶段:并存 → 收敛 | + +--- + +**文档维护**:实现 Phase 0 完成后,在 `hooks-design.md` §15 增加指向本文的链接;`mcp_runtime_management.md` Phase 3 Plugin tools 条目标记为「设计见 plugins-design.md」。 diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py index 58180d3d6..0c6ee3fc3 100644 --- a/ms_agent/agent/base.py +++ b/ms_agent/agent/base.py @@ -6,7 +6,8 @@ from ms_agent.llm import Message from ms_agent.utils import read_history, save_history -from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_RETRY_COUNT +from ms_agent.utils.constants import DEFAULT_RETRY_COUNT +from ms_agent.utils.workspace_context import resolve_workspace_root class Agent(ABC): @@ -42,8 +43,14 @@ def __init__(self, self.trust_remote_code = trust_remote_code self.config.tag = tag self.config.trust_remote_code = trust_remote_code - self.output_dir = getattr(self.config, 'output_dir', - DEFAULT_OUTPUT_DIR) + workspace_root = resolve_workspace_root(self.config) + self.output_dir = str(workspace_root) + try: + from omegaconf import open_dict + with open_dict(self.config): + self.config.output_dir = self.output_dir + except Exception: + pass @abstractmethod async def run( diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index b734e913d..bbf173c04 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -4,6 +4,7 @@ import inspect import json import os.path +from pathlib import Path import sys import threading import uuid @@ -148,6 +149,7 @@ def __init__( self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( kwargs.get('mcp_config', {})) self.mcp_client = kwargs.get('mcp_client', None) + self.mcp_runtime = kwargs.get('mcp_runtime', None) self.config_handler = self.register_config_handler() # Skill system (initialized in prepare_skills) @@ -157,6 +159,7 @@ def __init__( # Skill runtime (initialized in prepare_skills) self._skill_runtime: Optional[SkillRuntime] = None + self._plugin_runtime = None # Slash-command router for interactive input (lazily built) self._command_router = None @@ -227,6 +230,9 @@ async def prepare_skills(self): self._skill_runtime.set_system_content_builder( self._build_system_content ) + if getattr(self, '_plugin_runtime', None) is not None: + self._plugin_runtime.skill_runtime = self._skill_runtime + self._plugin_runtime._sync_skill_runtime(self.config) def _build_system_content(self) -> str: """Build the full system prompt content. @@ -488,7 +494,32 @@ async def on_tool_call(self, messages: List[Message]): await self.loop_callback('on_tool_call', messages) async def after_tool_call(self, messages: List[Message]): - if messages[-1].role == 'assistant' and not messages[-1].tool_calls: + assistant = messages[-1] + would_stop = assistant.role == 'assistant' and not assistant.tool_calls + + hook_runtime = getattr(self, '_hook_runtime', None) + if would_stop and hook_runtime is not None and not hook_runtime.is_empty: + from ms_agent.hooks.context import ( + append_stop_blocking_feedback, + apply_hook_result_to_messages, + ) + + last_text = assistant.content if isinstance(assistant.content, str) else '' + stop = await hook_runtime.run_stop( + reason='no_tool_calls', + last_assistant_message=last_text, + stop_hook_active=getattr(self.runtime, 'stop_hook_active', False), + ) + if stop.action in ('block', 'deny'): + append_stop_blocking_feedback(messages, stop.reason) + self.runtime.should_stop = False + self.runtime.stop_hook_active = True + await self.loop_callback('after_tool_call', messages) + return + apply_hook_result_to_messages( + messages, stop, hook_event='Stop') + + if would_stop: self.runtime.should_stop = True await self.loop_callback('after_tool_call', messages) @@ -527,6 +558,7 @@ async def parallel_tool_call(self, name=tool_call_query['tool_name'], resources=tool_call_result_format.resources, tool_detail=tool_call_result_format.tool_detail, + hook_attachments=tool_call_result_format.hook_attachments, ) if _new_message.tool_call_id is None: @@ -537,16 +569,135 @@ async def parallel_tool_call(self, self.log_output(_new_message.content) return messages + def _build_permission_objects(self): + """Create SafetyGuard and PermissionEnforcer from config if configured.""" + from ms_agent.permission import ( + AutoPermissionHandler, + PermissionConfig, + PermissionEnforcer, + PermissionMemory, + SafetyGuard, + ) + from ms_agent.permission.config import SafetyConfig + + raw = {} + if hasattr(self.config, 'permission'): + raw = dict(self.config.permission) if self.config.permission else {} + + from ms_agent.utils.workspace_context import resolve_workspace_root + + workspace_root = str(resolve_workspace_root(self.config)) + perm_config = PermissionConfig.from_dict(raw, project_root=workspace_root) + + allowed_dirs = [workspace_root] + for directory in perm_config.safety.allowed_directories: + if directory not in allowed_dirs: + allowed_dirs.append(directory) + read_only_dirs = list(perm_config.safety.read_only_directories) + safety_guard = SafetyGuard( + config=perm_config.safety, + allowed_dirs=allowed_dirs, + read_only_dirs=read_only_dirs, + workspace_root=workspace_root, + ) + + handler = AutoPermissionHandler() + memory = PermissionMemory(project_path=workspace_root) + enforcer = PermissionEnforcer(config=perm_config, handler=handler, memory=memory) + + return safety_guard, enforcer, perm_config + async def prepare_tools(self): """Initialize and connect the tool manager.""" + import uuid + + from ms_agent.hooks.bridge import CallbackToHookBridge + from ms_agent.hooks.factory import build_hook_runtime + from ms_agent.plugins.runtime import PluginRuntime + from ms_agent.utils.workspace_context import resolve_workspace_root + self.task_manager = TaskManager() + + safety_guard, permission_enforcer, perm_config = self._build_permission_objects() + session_id = ( + self.runtime.session_id + or getattr(self, 'tag', None) + or str(uuid.uuid4()) + ) + raw_hooks = {} + if hasattr(self.config, 'hooks') and self.config.hooks: + raw_hooks = OmegaConf.to_container(self.config.hooks, resolve=True) or {} + enabled_executors = frozenset( + raw_hooks.get('enabled_executors', ['command']) or ['command']) + self._plugin_runtime = PluginRuntime( + skill_runtime=self._skill_runtime, + mcp_runtime=self.mcp_runtime, + ) + self._plugin_runtime.start_sync( + str(resolve_workspace_root(self.config)), + session_id, + config=self.config, + enabled_executors=enabled_executors, + ) + self._register_plugin_commands() + plugin_mcp_servers = self._plugin_runtime.load_result.mcp_servers + if plugin_mcp_servers: + from ms_agent.plugins.runtime import dedupe_mcp_server_names + plugin_mcp_servers = dedupe_mcp_server_names( + plugin_mcp_servers, + set(self.mcp_config.setdefault('mcpServers', {}).keys()), + ) + self._plugin_runtime.load_result.mcp_servers = plugin_mcp_servers + self.mcp_config['mcpServers'].update(plugin_mcp_servers) + hook_runtime = build_hook_runtime( + self.config, + session_id=session_id, + plugin_hook_registries=self._plugin_runtime.load_result.hook_registries, + ) + mcp_rt = self.mcp_runtime + if mcp_rt is not None and plugin_mcp_servers: + from ms_agent.config.mcp_schema import ResolvedMCPConfig + merged_servers = { + state.name: dict(state.config) + for state in mcp_rt.list_servers() + } + merged_servers.update(plugin_mcp_servers) + await mcp_rt.apply_config( + ResolvedMCPConfig(mcp_servers=merged_servers)) + self.tool_manager = ToolManager( self.config, - self.mcp_config, + self.mcp_config if mcp_rt is None else {}, self.mcp_client, + permission_enforcer=permission_enforcer, + safety_guard=safety_guard, + permission_mode=perm_config.mode, + read_policy=perm_config.safety.read_policy, + hook_runtime=hook_runtime, + permission_config=perm_config, trust_remote_code=self.trust_remote_code, + mcp_callable_check=mcp_rt.is_callable if mcp_rt else None, + mcp_failure_handler=mcp_rt.record_failure if mcp_rt else None, + mcp_unavailable_detail=mcp_rt.unavailable_detail if mcp_rt else None, + mcp_success_handler=mcp_rt.record_success if mcp_rt else None, ) + if mcp_rt is not None: + self.tool_manager._skip_mcp_reindex = True + if self._plugin_runtime.agent_registry.has_agents(): + self.tool_manager.ensure_plugin_agent_tools( + self._plugin_runtime.agent_registry, + ) + if hook_runtime.has_session_handlers: + self.register_callback(CallbackToHookBridge(self.config, hook_runtime)) + self._hook_runtime = hook_runtime + if not self.runtime.session_id: + self.runtime.session_id = hook_runtime.session_id + if mcp_rt is not None and not mcp_rt.is_started: + await mcp_rt.start() await self.tool_manager.connect() + if mcp_rt is not None: + mcp_rt.bind_tool_manager(self.tool_manager) + await mcp_rt.sync_tools() for tool in self.tool_manager.extra_tools: if hasattr(tool, 'set_task_manager'): tool.set_task_manager(self.task_manager) @@ -555,6 +706,8 @@ async def cleanup_tools(self): """Cleanup resources used by the tool manager.""" if self.task_manager is not None: self.task_manager.kill_all() + if self.mcp_runtime is not None: + await self.mcp_runtime.stop() if self.tool_manager is not None: await self.tool_manager.cleanup() @@ -646,8 +799,18 @@ def _get_command_router(self): router = CommandRouter() register_builtin_commands(router) self._command_router = router + self._register_plugin_commands() return self._command_router + def _register_plugin_commands(self) -> None: + if self._command_router is None or self._plugin_runtime is None: + return + from ms_agent.plugins.commands import register_plugin_commands + register_plugin_commands( + self._command_router, + self._plugin_runtime.load_result.command_defs, + ) + def _resolve_interactive(self, messages) -> bool: """Decide whether this run is an interactive session. @@ -709,11 +872,7 @@ def _build_personalization_section(self) -> str: return PersonalizationInjector.build(config) async def do_rag(self, messages: List[Message]): - """Process RAG or knowledge search to enrich the user query with context. - - This method handles both traditional RAG and sirchmunk-based knowledge search. - For knowledge search, it also populates searching_detail and search_result - fields in the message for frontend display and next-turn LLM context. + """Process RAG to enrich the user query with context. Args: messages (List[Message]): The message list to process. @@ -885,6 +1044,35 @@ async def step( """ messages = deepcopy(messages) messages = self._append_task_notifications(messages) + from ms_agent.hooks.context import ( + condense_hook_attachments_for_llm, + extract_latest_user_prompt, + apply_hook_result_to_messages, + ) + messages = condense_hook_attachments_for_llm(messages) + + # UserPromptSubmit for multi-turn user input (InputCallback path) + hook_runtime = getattr(self, '_hook_runtime', None) + if (hook_runtime is not None and not hook_runtime.is_empty + and messages and messages[-1].role == 'user' + and self.runtime.round > 0): + prompt_text = extract_latest_user_prompt(messages) + submit = await hook_runtime.run_user_prompt_submit(prompt_text) + if submit.action in ('deny', 'block'): + if messages and messages[-1].role == 'user': + messages.pop() + messages.append(Message( + role='system', + content=( + f'UserPromptSubmit operation blocked by hook:\n' + f'{submit.reason}\n\nOriginal prompt: {prompt_text}'), + )) + self.runtime.should_stop = True + yield messages + return + apply_hook_result_to_messages( + messages, submit, hook_event='UserPromptSubmit') + if (not self.load_cache) or messages[-1].role != 'assistant': messages = await self.condense_memory(messages) await self.on_generate_response(messages) @@ -1203,9 +1391,38 @@ async def run_loop(self, messages: Union[List[Message], str], if self.runtime.round == 0: messages = await self.create_messages(messages) - await self.do_rag(messages) + + hook_runtime = getattr(self, '_hook_runtime', None) + if hook_runtime is not None: + hook_runtime.session_id = self.runtime.session_id + + # SessionStart before UserPromptSubmit (§9.3) await self.on_task_begin(messages) + # UserPromptSubmit — first user message + if hook_runtime is not None and not hook_runtime.is_empty: + from ms_agent.hooks.context import ( + extract_latest_user_prompt, + apply_hook_result_to_messages, + ) + prompt_text = extract_latest_user_prompt(messages) + submit = await hook_runtime.run_user_prompt_submit(prompt_text) + if submit.action in ('deny', 'block'): + messages.append(Message( + role='system', + content=( + f'UserPromptSubmit operation blocked by hook:\n' + f'{submit.reason}\n\nOriginal prompt: {prompt_text}'), + )) + await self.on_task_end(messages) + yield messages + await self.cleanup_tools() + return + apply_hook_result_to_messages( + messages, submit, hook_event='UserPromptSubmit') + + await self.do_rag(messages) + for message in messages: if message.role != 'system': self.log_output('[' + message.role + ']:') diff --git a/ms_agent/agent/runtime.py b/ms_agent/agent/runtime.py index 55a0dbf9e..ef207f37e 100644 --- a/ms_agent/agent/runtime.py +++ b/ms_agent/agent/runtime.py @@ -16,14 +16,22 @@ class Runtime: round: int = 0 + stop_hook_active: bool = False + + session_id: str = '' + def to_dict(self): return { 'should_stop': self.should_stop, 'tag': self.tag, 'round': self.round, + 'stop_hook_active': self.stop_hook_active, + 'session_id': self.session_id, } def from_dict(self, data: dict): self.should_stop = data['should_stop'] self.tag = data['tag'] self.round = data['round'] + self.stop_hook_active = data.get('stop_hook_active', False) + self.session_id = data.get('session_id', '') diff --git a/ms_agent/cli/cli.py b/ms_agent/cli/cli.py index da709e98d..28e08b7bb 100644 --- a/ms_agent/cli/cli.py +++ b/ms_agent/cli/cli.py @@ -1,6 +1,7 @@ import argparse from ms_agent.cli.app import AppCMD +from ms_agent.cli.plugin import PluginCMD from ms_agent.cli.run import RunCMD from ms_agent.cli.ui import UICMD @@ -20,6 +21,7 @@ def run_cmd(): RunCMD.define_args(subparsers) AppCMD.define_args(subparsers) UICMD.define_args(subparsers) + PluginCMD.define_args(subparsers) # unknown args will be handled in config.py args, _ = parser.parse_known_args() diff --git a/ms_agent/cli/plugin.py b/ms_agent/cli/plugin.py new file mode 100644 index 000000000..54e48c8b2 --- /dev/null +++ b/ms_agent/cli/plugin.py @@ -0,0 +1,229 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import argparse +import asyncio +import json +import os +from pathlib import Path + +from ms_agent.plugins.config_manager import PluginConfigManager +from ms_agent.plugins.installer import PluginInstaller, UnsupportedPluginSource +from ms_agent.plugins.runtime import PluginRuntime +from ms_agent.utils import get_logger + +from .base import CLICommand + +logger = get_logger() + + +def subparser_func(args): + return PluginCMD(args) + + +class PluginCMD(CLICommand): + """Install and manage MS-Agent community plugins.""" + + name = 'plugin' + + def __init__(self, args): + self.args = args + + @staticmethod + def define_args(parsers: argparse.ArgumentParser): + parser: argparse.ArgumentParser = parsers.add_parser( + PluginCMD.name, + help='Install and manage plugins', + ) + subparsers = parser.add_subparsers( + dest='plugin_command', + required=True, + help='Plugin management commands', + ) + + install = subparsers.add_parser( + 'install', + help='Install a plugin from local path, github://, modelscope://, or marketplace alias', + ) + install.add_argument( + 'source', + help=( + 'Plugin source, e.g. ./path, github://org/repo@ref#subdir, ' + 'or hookify@claude-plugins-official' + ), + ) + install.add_argument( + '--scope', + choices=('global', 'project'), + default='global', + help='Install scope (default: global)', + ) + install.add_argument( + '--project-path', + default=None, + help='Project path for project-scoped install', + ) + install.add_argument( + '--link', + action='store_true', + help='Symlink local plugin sources instead of copying', + ) + install.add_argument( + '--force', + action='store_true', + help='Replace an existing managed plugin copy', + ) + install.add_argument( + '--disabled', + action='store_true', + help='Install but keep the plugin disabled', + ) + + list_cmd = subparsers.add_parser( + 'list', + help='List installed plugins', + ) + list_cmd.add_argument( + '--project-path', + default=None, + help='Project path for merged plugin listing', + ) + list_cmd.add_argument( + '--json', + action='store_true', + help='Print machine-readable JSON', + ) + + toggle = subparsers.add_parser( + 'toggle', + help='Enable or disable an installed plugin', + ) + toggle.add_argument('plugin_id') + toggle.add_argument( + '--enable', + action='store_true', + help='Enable the plugin (default action)', + ) + toggle.add_argument( + '--disable', + action='store_true', + help='Disable the plugin', + ) + toggle.add_argument( + '--scope', + choices=('global', 'project'), + default='global', + ) + toggle.add_argument( + '--project-path', + default=None, + ) + + uninstall = subparsers.add_parser( + 'uninstall', + help='Remove a plugin record', + ) + uninstall.add_argument('plugin_id') + uninstall.add_argument( + '--scope', + choices=('global', 'project'), + default='global', + ) + uninstall.add_argument( + '--purge', + action='store_true', + help='Delete managed plugin files', + ) + uninstall.add_argument( + '--project-path', + default=None, + ) + + parser.set_defaults(func=subparser_func) + + def execute(self): + command = self.args.plugin_command + if command == 'install': + self._install() + elif command == 'list': + self._list() + elif command == 'toggle': + asyncio.run(self._toggle()) + elif command == 'uninstall': + asyncio.run(self._uninstall()) + else: + raise SystemExit(f'Unknown plugin command: {command}') + + def _global_root(self) -> Path: + return Path(os.environ.get('MS_AGENT_HOME', '~/.ms_agent')).expanduser() + + def _project_path(self) -> str: + return self.args.project_path or os.getcwd() + + def _install(self) -> None: + global_root = self._global_root() + manager = PluginConfigManager(global_dir=global_root) + installer = PluginInstaller( + config_manager=manager, + global_root=global_root, + project_root=self._project_path(), + ) + try: + manifest = installer.install( + self.args.source, + scope=self.args.scope, + project_path=self._project_path(), + link=self.args.link, + force=self.args.force, + enabled=False if self.args.disabled else None, + ) + except UnsupportedPluginSource as exc: + raise SystemExit(str(exc)) from exc + + print( + f"Installed plugin '{manifest.plugin_id}' " + f"({manifest.format.value}) at {manifest.root}") + print(f"Capabilities: {', '.join(sorted(manifest.capabilities))}") + + def _list(self) -> None: + global_root = self._global_root() + runtime = PluginRuntime(global_root=global_root) + runtime.start_sync(self._project_path(), 'cli') + plugins = runtime.list_all() + if self.args.json: + print(json.dumps({'plugins': plugins}, indent=2)) + return + if not plugins: + print('No plugins installed.') + return + for item in plugins: + status = item.get('status', 'unknown') + enabled = 'enabled' if item.get('enabled') else 'disabled' + caps = ', '.join(item.get('capabilities') or []) + print( + f"- {item['plugin_id']} [{status}, {enabled}] " + f"caps={caps or 'none'}" + ) + + async def _toggle(self) -> None: + if self.args.disable and self.args.enable: + raise SystemExit('Use only one of --enable or --disable') + enabled = not self.args.disable + runtime = PluginRuntime(global_root=self._global_root()) + await runtime.toggle( + self.args.plugin_id, + enabled, + scope=self.args.scope, + project_path=self._project_path(), + ) + state = 'enabled' if enabled else 'disabled' + print(f"Plugin '{self.args.plugin_id}' {state}.") + + async def _uninstall(self) -> None: + runtime = PluginRuntime(global_root=self._global_root()) + await runtime.uninstall( + self.args.plugin_id, + scope=self.args.scope, + purge=self.args.purge, + ) + print(f"Plugin '{self.args.plugin_id}' uninstalled.") diff --git a/ms_agent/config/__init__.py b/ms_agent/config/__init__.py index fbab11f3d..1176d65c3 100644 --- a/ms_agent/config/__init__.py +++ b/ms_agent/config/__init__.py @@ -1,4 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .config import Config from .env import Env +from .mcp_manager import MCPConfigManager +from .mcp_schema import ResolvedMCPConfig, normalize_mcp_server_entry from .resolver import ConfigResolver diff --git a/ms_agent/config/mcp_manager.py b/ms_agent/config/mcp_manager.py new file mode 100644 index 000000000..9c8aabec9 --- /dev/null +++ b/ms_agent/config/mcp_manager.py @@ -0,0 +1,262 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Persistent CRUD for global and project MCP server definitions.""" +from __future__ import annotations + +import copy +import json +from datetime import datetime, timezone +from pathlib import Path +from threading import Lock +from typing import Any, Dict, Literal, Optional + +from ms_agent.config.env import Env +from ms_agent.config.mcp_schema import normalize_mcp_server_entry + +MCPScope = Literal['global', 'project', 'merged'] + + +class MCPConfigManager: + """Global / project two-level MCP configuration persistence.""" + + def __init__( + self, + global_root: Path | str, + project_root: Path | str | None = None, + ): + self.global_root = Path(global_root).expanduser() + self.project_root = ( + Path(project_root).expanduser() if project_root else None + ) + self._lock = Lock() + + # ── paths ────────────────────────────────────────────────────────── + + @property + def global_settings_path(self) -> Path: + return self.global_root / 'settings.json' + + @property + def global_mcp_path(self) -> Path: + return self.global_root / 'mcp.json' + + @property + def project_mcp_path(self) -> Path: + if self.project_root is None: + raise ValueError('project_root is required for project scope') + return self.project_root / '.ms-agent' / 'mcp.json' + + def _ensure_dir(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + + # ── IO ───────────────────────────────────────────────────────────── + + def _read_json(self, path: Path) -> Dict[str, Any]: + if not path.is_file(): + return {} + with open(path, encoding='utf-8') as f: + return json.load(f) + + def _write_json(self, path: Path, data: Dict[str, Any]) -> None: + self._ensure_dir(path) + with open(path, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + def _load_scope_raw(self, scope: Literal['global', 'project']) -> Dict[str, Dict[str, Any]]: + if scope == 'global': + servers: Dict[str, Dict[str, Any]] = {} + settings = self._read_json(self.global_settings_path) + if isinstance(settings.get('mcp_servers'), dict): + servers.update(settings['mcp_servers']) + mcp_file = self._read_json(self.global_mcp_path) + file_servers = mcp_file.get('mcpServers', mcp_file) + if isinstance(file_servers, dict): + for name, entry in file_servers.items(): + servers.setdefault(name, entry) + return servers + + assert self.project_root is not None + mcp_file = self._read_json(self.project_mcp_path) + file_servers = mcp_file.get('mcpServers', mcp_file) + return dict(file_servers) if isinstance(file_servers, dict) else {} + + def _save_scope_raw( + self, + scope: Literal['global', 'project'], + servers: Dict[str, Dict[str, Any]], + ) -> None: + if scope == 'global': + # Keep settings.json mcp_servers in sync for WebUI compatibility. + settings = self._read_json(self.global_settings_path) + if not settings: + settings = {} + settings['mcp_servers'] = copy.deepcopy(servers) + self._write_json(self.global_settings_path, settings) + self._write_json(self.global_mcp_path, {'mcpServers': servers}) + return + + self._write_json(self.project_mcp_path, {'mcpServers': servers}) + + def _normalize_scope( + self, + servers: Dict[str, Dict[str, Any]], + *, + source: Literal['global', 'project'], + ) -> Dict[str, Dict[str, Any]]: + result: Dict[str, Dict[str, Any]] = {} + for name, entry in servers.items(): + normalized = normalize_mcp_server_entry(entry, source=source) + if normalized is not None: + result[name] = normalized + return result + + # ── CRUD ─────────────────────────────────────────────────────────── + + def list(self, scope: MCPScope = 'merged') -> Dict[str, Dict[str, Any]]: + with self._lock: + if scope == 'global': + return self._normalize_scope( + self._load_scope_raw('global'), source='global') + if scope == 'project': + return self._normalize_scope( + self._load_scope_raw('project'), source='project') + global_servers = self._normalize_scope( + self._load_scope_raw('global'), source='global') + project_servers = self._normalize_scope( + self._load_scope_raw('project'), source='project') + from ms_agent.config.mcp_schema import merge_mcp_layers + return merge_mcp_layers(global_servers, project_servers) + + def get(self, name: str, scope: MCPScope = 'merged') -> Optional[Dict[str, Any]]: + servers = self.list(scope) + entry = servers.get(name) + return copy.deepcopy(entry) if entry else None + + def add( + self, + name: str, + server: Dict[str, Any], + scope: Literal['global', 'project'] = 'project', + ) -> None: + with self._lock: + raw = self._load_scope_raw(scope) + entry = copy.deepcopy(server) + entry.setdefault('enabled', True) + entry.setdefault( + 'meta', + { + 'added_at': datetime.now(timezone.utc).isoformat(), + }, + ) + raw[name] = entry + self._save_scope_raw(scope, raw) + + def update( + self, + name: str, + patch: Dict[str, Any], + scope: Literal['global', 'project'] = 'project', + ) -> None: + with self._lock: + raw = self._load_scope_raw(scope) + if name not in raw: + raise KeyError(f'MCP server not found in {scope} scope: {name}') + merged = copy.deepcopy(raw[name]) + merged.update(copy.deepcopy(patch)) + raw[name] = merged + self._save_scope_raw(scope, raw) + + def remove(self, name: str, scope: Literal['global', 'project'] = 'project') -> None: + """Remove or mask a server. + + Project scope masks a global server (``enabled: false``) without + deleting the global definition. Global scope deletes the entry. + """ + with self._lock: + raw = self._load_scope_raw(scope) + if scope == 'project': + raw[name] = {'enabled': False, '_removed': True} + elif name in raw: + del raw[name] + else: + raise KeyError(f'MCP server not found in global scope: {name}') + self._save_scope_raw(scope, raw) + + def set_enabled( + self, + name: str, + enabled: bool, + scope: Literal['global', 'project'] = 'project', + ) -> None: + with self._lock: + raw = self._load_scope_raw(scope) + if name not in raw: + if scope == 'project': + raw[name] = {'enabled': enabled} + else: + raise KeyError(f'MCP server not found in {scope} scope: {name}') + else: + raw[name]['enabled'] = enabled + raw[name].pop('_removed', None) + self._save_scope_raw(scope, raw) + + # ── import / export ────────────────────────────────────────────────── + + def import_cursor_format(self, path: Path | str, merge: bool = True) -> int: + path = Path(path).expanduser() + data = self._read_json(path) + incoming = data.get('mcpServers', data) + if not isinstance(incoming, dict): + return 0 + with self._lock: + raw = self._load_scope_raw('global') if merge else {} + count = 0 + for name, entry in incoming.items(): + if not isinstance(entry, dict): + continue + raw[name] = copy.deepcopy(entry) + raw[name].setdefault('enabled', True) + count += 1 + self._save_scope_raw('global', raw) + return count + + def export_mcp_json( + self, + path: Path | str, + scope: MCPScope = 'merged', + *, + redact_secrets: bool = True, + ) -> None: + servers = self.list(scope) + if redact_secrets: + servers = self._redact_servers(servers) + self._write_json(Path(path).expanduser(), {'mcpServers': servers}) + + @staticmethod + def _redact_servers(servers: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + redacted: Dict[str, Dict[str, Any]] = {} + secret_keys = {'api_key', 'token', 'secret', 'password', 'authorization'} + for name, entry in servers.items(): + item = copy.deepcopy(entry) + env = item.get('env') + if isinstance(env, dict): + item['env'] = { + k: '***' if any(s in k.lower() for s in secret_keys) else v + for k, v in env.items() + } + headers = item.get('headers') + if isinstance(headers, dict): + item['headers'] = { + k: '***' if any(s in k.lower() for s in secret_keys) else v + for k, v in headers.items() + } + redacted[name] = item + return redacted + + def resolve_env(self, server: Dict[str, Any]) -> Dict[str, str]: + """Fill empty env values from ``Env.load_env()`` (same as MCPClient).""" + envs = Env.load_env() + env_dict = copy.deepcopy(server.get('env') or {}) + return { + key: value if value else envs.get(key, '') + for key, value in env_dict.items() + } diff --git a/ms_agent/config/mcp_schema.py b/ms_agent/config/mcp_schema.py new file mode 100644 index 000000000..3fa9dfefb --- /dev/null +++ b/ms_agent/config/mcp_schema.py @@ -0,0 +1,187 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Normalized MCP configuration schema and merge helpers.""" +from __future__ import annotations + +import copy +from dataclasses import dataclass, field +from typing import Any, Dict, Literal, Optional + +from omegaconf import DictConfig, ListConfig, OmegaConf + +MCPSource = Literal['global', 'project', 'agent_yaml', 'plugin', 'session'] + +# Connection-related fields kept when normalizing agent.yaml / JSON entries. +MCP_CONNECTION_FIELDS = frozenset({ + 'enabled', + 'transport', + 'type', + 'command', + 'args', + 'url', + 'env', + 'headers', + 'timeout', + 'include', + 'exclude', + 'source', + 'meta', + 'session_kwargs', + 'httpx_client_factory', + 'encoding', + 'encoding_error_handler', + 'sse_read_timeout', +}) + +# YAML / agent metadata stripped during normalization. +MCP_STRIP_FIELDS = frozenset({ + 'mcp', + 'implementation', + 'trust_remote_code', + '_removed', +}) + + +@dataclass +class ResolvedMCPConfig: + """Normalized multi-layer MCP configuration.""" + + mcp_servers: Dict[str, Dict[str, Any]] = field(default_factory=dict) + + def to_mcp_json(self) -> Dict[str, Any]: + return {'mcpServers': copy.deepcopy(self.mcp_servers)} + + def enabled_servers(self) -> Dict[str, Dict[str, Any]]: + return { + name: cfg + for name, cfg in self.mcp_servers.items() + if cfg.get('enabled', True) + } + + +def _coerce_entry_dict(entry: Any) -> Optional[Dict[str, Any]]: + if isinstance(entry, (DictConfig, ListConfig)): + container = OmegaConf.to_container(entry, resolve=True) + return container if isinstance(container, dict) else None + if isinstance(entry, dict): + return entry + return None + + +def normalize_mcp_server_entry( + entry: Dict[str, Any], + *, + source: MCPSource = 'global', + default_enabled: bool = True, +) -> Optional[Dict[str, Any]]: + """Normalize a single MCP server entry for merge / runtime consumption. + + Returns ``None`` when the entry should not appear in ``mcpServers`` (e.g. + ``mcp: false`` in agent.yaml). + """ + if not entry: + return None + coerced = _coerce_entry_dict(entry) + if coerced is None: + return None + entry = coerced + if entry.get('mcp') is False: + return None + if entry.get('_removed'): + return {'enabled': False, 'source': source} + + normalized: Dict[str, Any] = {} + for key, value in entry.items(): + if key in MCP_STRIP_FIELDS: + continue + if key in MCP_CONNECTION_FIELDS: + normalized[key] = copy.deepcopy(value) + + if 'enabled' not in normalized: + normalized['enabled'] = default_enabled + normalized['source'] = normalized.get('source', source) + return normalized + + +def normalize_mcp_servers_layer( + servers: Optional[Dict[str, Any]], + *, + source: MCPSource, +) -> Dict[str, Dict[str, Any]]: + if not servers: + return {} + result: Dict[str, Dict[str, Any]] = {} + for name, entry in servers.items(): + coerced = _coerce_entry_dict(entry) + if coerced is None: + continue + normalized = normalize_mcp_server_entry(coerced, source=source) + if normalized is not None: + result[name] = normalized + return result + + +def merge_mcp_server_entry( + base: Dict[str, Any], + override: Dict[str, Any], +) -> Dict[str, Any]: + """Merge two server entries; ``override`` wins on explicit fields.""" + merged = copy.deepcopy(base) + for key, value in override.items(): + if key in MCP_STRIP_FIELDS: + continue + if key == 'meta' and isinstance(value, dict): + merged_meta = dict(merged.get('meta') or {}) + merged_meta.update(value) + merged['meta'] = merged_meta + else: + merged[key] = copy.deepcopy(value) + + # enabled: only override when explicitly set in the patch layer + if 'enabled' not in override and 'enabled' in base: + merged['enabled'] = base['enabled'] + return merged + + +def merge_mcp_layers(*layers: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + """Union merge by server name; later layers override earlier ones.""" + merged: Dict[str, Dict[str, Any]] = {} + for layer in layers: + for name, entry in layer.items(): + if name in merged: + merged[name] = merge_mcp_server_entry(merged[name], entry) + else: + merged[name] = copy.deepcopy(entry) + return merged + + +def collect_builtin_tool_names( + agent_config: DictConfig | ListConfig | None, +) -> set[str]: + """Names declared as built-in tools (``mcp: false``) in agent.yaml. + + These entries must not appear in merged ``mcpServers`` even when a lower + layer (e.g. global settings) defines a same-named MCP server. The built-in + implementation is provided via ``ToolManager.extra_tools`` instead. + """ + if agent_config is None or not hasattr(agent_config, 'tools'): + return set() + tools = agent_config.tools + container = OmegaConf.to_container(tools, resolve=True) + if not isinstance(container, dict): + return set() + return { + name + for name, entry in container.items() + if isinstance(entry, dict) and entry.get('mcp') is False + } + + +def connection_params_for_client(server: Dict[str, Any]) -> Dict[str, Any]: + """Extract connect kwargs from a normalized server entry.""" + params: Dict[str, Any] = {} + for key in MCP_CONNECTION_FIELDS: + if key in ('enabled', 'source', 'meta'): + continue + if key in server: + params[key] = copy.deepcopy(server[key]) + return params diff --git a/ms_agent/config/resolver.py b/ms_agent/config/resolver.py index ae813838c..edcc32391 100644 --- a/ms_agent/config/resolver.py +++ b/ms_agent/config/resolver.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. """ConfigResolver — multi-layer config merging. Merges config from five layers (later wins): @@ -11,21 +12,32 @@ - Union by name, project-level overrides global on conflict - Each entry carries an `enabled` flag +Also provides MCP-specific resolution via MCPConfigManager (Playground F7). + This class does NOT replace Config.from_task(). CLI mode continues to use Config.from_task() directly. ConfigResolver is for server/UI scenarios where layered config is needed. """ from __future__ import annotations +import copy import json import os from copy import deepcopy from pathlib import Path from typing import Any, Dict, List, Optional, Union -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, ListConfig, OmegaConf +from ms_agent.config.mcp_manager import MCPConfigManager +from ms_agent.config.mcp_schema import ( + ResolvedMCPConfig, + collect_builtin_tool_names, + merge_mcp_layers, + normalize_mcp_servers_layer, +) from ms_agent.utils import get_logger +from ms_agent.plugins.config_manager import PluginConfigManager logger = get_logger() @@ -42,10 +54,29 @@ class ConfigResolver: - """Multi-layer config resolver for server/UI scenarios.""" + """Multi-layer config resolver for server/UI and Playground scenarios.""" + + def __init__( + self, + global_dir: Union[str, Path] = '~/.ms_agent', + project_root: Union[str, Path, None] = None, + agent_config: DictConfig | ListConfig | None = None, + mcp_manager: MCPConfigManager | None = None, + ) -> None: + self._global_dir = Path(global_dir).expanduser() + self.project_root = ( + Path(project_root).expanduser() if project_root else None + ) + self.agent_config = agent_config + self.mcp_manager = mcp_manager or MCPConfigManager( + self._global_dir, self.project_root) + self.plugin_manager = PluginConfigManager( + self._global_dir, self.project_root) - def __init__(self, global_dir: str = '~/.ms_agent') -> None: - self._global_dir = Path(os.path.expanduser(global_dir)) + @property + def global_root(self) -> Path: + """Alias for the global config directory (Playground MCP APIs).""" + return self._global_dir def resolve( self, @@ -73,13 +104,20 @@ def resolve( if global_settings: layers.append(global_settings) - if agent_config is not None: - if isinstance(agent_config, str): - agent_config = OmegaConf.load(agent_config) - layers.append(agent_config) + effective_agent_config = ( + agent_config if agent_config is not None else self.agent_config + ) + if effective_agent_config is not None: + if isinstance(effective_agent_config, str): + effective_agent_config = OmegaConf.load(effective_agent_config) + layers.append(effective_agent_config) - if project_path: - project_patch = self._load_project_patch(project_path) + effective_project_path = project_path + if effective_project_path is None and self.project_root is not None: + effective_project_path = str(self.project_root) + + if effective_project_path: + project_patch = self._load_project_patch(effective_project_path) if project_patch: layers.append(project_patch) @@ -88,14 +126,94 @@ def resolve( merged = self._merge_layers(layers) - merged = self._merge_mcp(merged, project_path) - merged = self._merge_skills(merged, project_path) + merged = self._merge_mcp(merged, effective_project_path) + merged = self._merge_skills(merged, effective_project_path) + merged = self._merge_plugins(merged, effective_project_path) from ms_agent.config.config import Config merged = Config.fill_missing_fields(merged) return merged + def resolve_mcp( + self, + session_id: str | None = None, + session_override: Dict[str, Dict[str, Any]] | None = None, + ) -> ResolvedMCPConfig: + """Merge framework → global → agent.yaml → project → session MCP layers.""" + del session_id # reserved for Phase 3 session.json + + from ms_agent.config.config import Config + + global_layer = normalize_mcp_servers_layer( + self.mcp_manager.list('global'), + source='global', + ) + agent_yaml_layer: Dict[str, Dict[str, Any]] = {} + if self.agent_config is not None: + raw = Config.convert_mcp_servers_to_json(self.agent_config) + agent_yaml_layer = normalize_mcp_servers_layer( + raw.get('mcpServers'), + source='agent_yaml', + ) + + project_layer = normalize_mcp_servers_layer( + self.mcp_manager.list('project'), + source='project', + ) + + session_layer = normalize_mcp_servers_layer( + session_override, + source='session', + ) + + merged = merge_mcp_layers( + {}, + global_layer, + agent_yaml_layer, + project_layer, + session_layer, + ) + for name in collect_builtin_tool_names(self.agent_config): + merged.pop(name, None) + return ResolvedMCPConfig(mcp_servers=merged) + + def resolve_mcp_all_layers( + self, + session_override: Dict[str, Dict[str, Any]] | None = None, + ) -> Dict[str, Dict[str, Any]]: + """Return merged servers including disabled entries (for UI listing).""" + from ms_agent.config.config import Config + + global_layer = normalize_mcp_servers_layer( + self.mcp_manager.list('global'), source='global') + agent_yaml_layer: Dict[str, Dict[str, Any]] = {} + if self.agent_config is not None: + raw = Config.convert_mcp_servers_to_json(self.agent_config) + agent_yaml_layer = normalize_mcp_servers_layer( + raw.get('mcpServers'), source='agent_yaml') + project_layer = normalize_mcp_servers_layer( + self.mcp_manager.list('project'), source='project') + session_layer = normalize_mcp_servers_layer( + session_override, source='session') + merged = merge_mcp_layers( + global_layer, + agent_yaml_layer, + project_layer, + session_layer, + ) + for name in collect_builtin_tool_names(self.agent_config): + merged.pop(name, None) + return merged + + def with_agent_config( + self, + agent_config: DictConfig | ListConfig | None, + ) -> 'ConfigResolver': + clone = copy.copy(self) + clone.agent_config = agent_config + return clone + # -- layer loading -- def _load_framework_defaults(self) -> DictConfig: @@ -156,6 +274,34 @@ def _merge_mcp( OmegaConf.update(config, '_merged_mcp', merged, merge=True) return config + def _merge_plugins( + self, config: DictConfig, project_path: Optional[str] + ) -> DictConfig: + manager = ( + PluginConfigManager(self._global_dir, project_path) + if project_path + else self.plugin_manager + ) + records = manager.load_merged(project_path) + if not records: + return config + + payload = {'plugins': [record.to_dict() | {'scope': record.scope} + for record in records]} + OmegaConf.update(config, '_merged_plugins', payload, merge=True) + + enabled_paths = [ + record.path for record in records + if record.enabled and record.path + ] + existing = [] + if hasattr(config, 'plugins') and config.plugins: + existing = [str(item) for item in config.plugins] + merged_paths = existing + [p for p in enabled_paths if p not in existing] + if merged_paths: + OmegaConf.update(config, 'plugins', merged_paths, merge=True) + return config + def _merge_skills( self, config: DictConfig, project_path: Optional[str] ) -> DictConfig: diff --git a/ms_agent/hooks/__init__.py b/ms_agent/hooks/__init__.py new file mode 100644 index 000000000..55c58a761 --- /dev/null +++ b/ms_agent/hooks/__init__.py @@ -0,0 +1,29 @@ +"""Hooks system — shell-based lifecycle hooks with multi-platform config support.""" + +from ms_agent.hooks.bridge import CallbackToHookBridge +from ms_agent.hooks.context import ( + HookAttachment, + apply_hook_result_to_messages, + append_stop_blocking_feedback, + condense_hook_attachments_for_llm, + extract_latest_user_prompt, +) +from ms_agent.hooks.events import HookResult +from ms_agent.hooks.factory import build_hook_runtime +from ms_agent.hooks.registry import HookHandlerConfig, HookRegistry, MatcherGroup +from ms_agent.hooks.runtime import HookRuntime + +__all__ = [ + 'CallbackToHookBridge', + 'HookAttachment', + 'HookHandlerConfig', + 'HookRegistry', + 'HookResult', + 'HookRuntime', + 'MatcherGroup', + 'apply_hook_result_to_messages', + 'append_stop_blocking_feedback', + 'build_hook_runtime', + 'condense_hook_attachments_for_llm', + 'extract_latest_user_prompt', +] diff --git a/ms_agent/hooks/bridge.py b/ms_agent/hooks/bridge.py new file mode 100644 index 000000000..70e2b63ca --- /dev/null +++ b/ms_agent/hooks/bridge.py @@ -0,0 +1,24 @@ +"""CallbackToHookBridge — SessionStart only.""" + +from __future__ import annotations + +from omegaconf import DictConfig +from typing import List + +from ms_agent.agent.runtime import Runtime +from ms_agent.callbacks.base import Callback +from ms_agent.hooks.runtime import HookRuntime +from ms_agent.llm.utils import Message + + +class CallbackToHookBridge(Callback): + def __init__(self, config: DictConfig, hook_runtime: HookRuntime) -> None: + super().__init__(config) + self._hooks = hook_runtime + + async def on_task_begin( + self, + runtime: Runtime, + messages: List[Message], + ) -> None: + await self._hooks.run_session_start(runtime, messages) diff --git a/ms_agent/hooks/context.py b/ms_agent/hooks/context.py new file mode 100644 index 000000000..5195f51d3 --- /dev/null +++ b/ms_agent/hooks/context.py @@ -0,0 +1,107 @@ +"""Hook attachment types and message integration.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal, Union + +from ms_agent.hooks.events import HookResult +from ms_agent.llm.utils import Message + + +@dataclass(frozen=True) +class HookAttachment: + type: Literal[ + 'hook_additional_context', + 'hook_blocking_feedback', + 'hook_stopped_continuation', + ] + hook_event: str + tool_call_id: str | None + content: Union[str, list[str]] + + +def _append_hook_attachment( + messages: list[Message], + attachment: HookAttachment, +) -> None: + if not messages: + return + last = messages[-1] + if not hasattr(last, 'hook_attachments') or last.hook_attachments is None: + last.hook_attachments = [] + last.hook_attachments.append(attachment) + + +def append_stop_blocking_feedback( + messages: list[Message], + reason: str, +) -> None: + """Attach Stop hook block feedback to the assistant turn (§8.5 / §9.4).""" + if not messages: + return + assistant = messages[-1] + attachment = HookAttachment( + type='hook_blocking_feedback', + hook_event='Stop', + tool_call_id=None, + content=reason or '', + ) + if not hasattr(assistant, 'hook_attachments') or assistant.hook_attachments is None: + assistant.hook_attachments = [] + assistant.hook_attachments.append(attachment) + + +def apply_hook_result_to_messages( + messages: list[Message], + result: HookResult, + *, + hook_event: str, + tool_call_id: str | None = None, +) -> bool: + """Return False when caller should abort (UserPromptSubmit deny).""" + if result.action in ('deny', 'block') and hook_event == 'UserPromptSubmit': + return False + if result.additional_context: + _append_hook_attachment( + messages, + HookAttachment( + type='hook_additional_context', + hook_event=hook_event, + tool_call_id=tool_call_id, + content=result.additional_context, + ), + ) + return True + + +def _attachment_to_meta_message(att: HookAttachment) -> Message: + content = att.content if isinstance(att.content, str) else '\n'.join( + att.content) + if att.type == 'hook_blocking_feedback': + return Message(role='user', content=f'Stop hook feedback:\n{content}') + if att.type == 'hook_stopped_continuation': + return Message(role='user', content=f'[hook:{att.hook_event}]\n{content}') + prefix = f'[hook:{att.hook_event}]' + return Message(role='user', content=f'{prefix}\n{content}') + + +def condense_hook_attachments_for_llm(messages: list[Message]) -> list[Message]: + """Convert hook_attachments into meta user messages for the LLM.""" + out: list[Message] = [] + for msg in messages: + out.append(msg) + attachments = getattr(msg, 'hook_attachments', None) or [] + for att in attachments: + out.append(_attachment_to_meta_message(att)) + if attachments: + msg.hook_attachments = [] + return out + + +def extract_latest_user_prompt(messages: list[Message]) -> str: + for msg in reversed(messages): + if msg.role == 'user': + return msg.content if isinstance(msg.content, str) else str( + msg.content) + return '' diff --git a/ms_agent/hooks/events.py b/ms_agent/hooks/events.py new file mode 100644 index 000000000..5e4d1d445 --- /dev/null +++ b/ms_agent/hooks/events.py @@ -0,0 +1,64 @@ +"""Canonical hook events and unified result envelope.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class SessionStartEvent: + session_id: str + project_path: str = '' + event: str = field(default='SessionStart', init=False) + + +@dataclass(frozen=True) +class PreToolUseEvent: + session_id: str + tool_name: str + tool_args: dict[str, Any] = field(default_factory=dict) + event: str = field(default='PreToolUse', init=False) + + +@dataclass(frozen=True) +class PostToolUseEvent: + session_id: str + tool_name: str + tool_args: dict[str, Any] = field(default_factory=dict) + tool_result: str = '' + event: str = field(default='PostToolUse', init=False) + + +@dataclass(frozen=True) +class UserPromptSubmitEvent: + session_id: str + prompt: str + event: str = field(default='UserPromptSubmit', init=False) + + +@dataclass(frozen=True) +class StopEvent: + session_id: str + reason: str = '' + last_assistant_message: str = '' + stop_hook_active: bool = False + event: str = field(default='Stop', init=False) + + +@dataclass(frozen=True) +class PermissionRequestEvent: + session_id: str + tool_name: str + tool_args: dict[str, Any] = field(default_factory=dict) + event: str = field(default='PermissionRequest', init=False) + + +@dataclass(frozen=True) +class HookResult: + action: str # allow | deny | ask | block | pass | error + reason: str = '' + additional_context: str = '' + updated_args: dict[str, Any] | None = None + exit_code: int = 0 + stderr: str = '' diff --git a/ms_agent/hooks/executor.py b/ms_agent/hooks/executor.py new file mode 100644 index 000000000..5de4071d7 --- /dev/null +++ b/ms_agent/hooks/executor.py @@ -0,0 +1,141 @@ +"""Hook executor dispatcher.""" + +from __future__ import annotations + +import time +from collections.abc import Awaitable, Callable +from typing import Any + +from ms_agent.hooks.events import HookResult +from ms_agent.hooks.executors.command import CommandHookExecutor, HookExecutionContext +from ms_agent.hooks.registry import HookHandlerConfig +from ms_agent.hooks.response_adapter import ResponseAdapter + +OnHandlerComplete = Callable[ + [HookHandlerConfig, HookResult, float], + Awaitable[None], +] + + +class HookExecutor: + """Route hook handlers to backends by type.""" + + def __init__( + self, + working_dir: str | None = None, + *, + command: CommandHookExecutor | None = None, + response_adapter: ResponseAdapter | None = None, + enabled_executors: frozenset[str] = frozenset({'command'}), + fail_closed: bool = False, + ) -> None: + adapter = response_adapter or ResponseAdapter() + self._backends: dict[str, Any] = {} + if 'command' in enabled_executors: + self._backends['command'] = command or CommandHookExecutor( + working_dir=working_dir, + response_adapter=adapter, + fail_closed=fail_closed, + ) + self._ctx: HookExecutionContext | None = None + + def set_context(self, ctx: HookExecutionContext) -> None: + self._ctx = ctx + + async def execute( + self, + handler: HookHandlerConfig, + event_data: dict[str, Any], + ctx: HookExecutionContext | None = None, + ) -> HookResult: + backend = self._backends.get(handler.type) + if backend is None: + return HookResult( + action='error', + reason=f"Hook type '{handler.type}' not enabled", + ) + return await backend.execute(handler, event_data, ctx or self._ctx) + + async def execute_all( + self, + handlers: list[HookHandlerConfig], + event_data: dict[str, Any], + *, + blockable: bool = False, + ctx: HookExecutionContext | None = None, + on_handler_complete: OnHandlerComplete | None = None, + ) -> HookResult: + merged_context_parts: list[str] = [] + final_updated_args: dict[str, Any] | None = None + aggregated_action: str | None = None + exec_ctx = ctx or self._ctx + + def _merge_action(current: str | None, new: str) -> str | None: + if new in ('deny', 'block'): + return 'deny' + if new == 'ask' and current != 'deny': + return 'ask' + if new == 'allow' and current is None: + return 'allow' + return current + + for handler in handlers: + started = time.perf_counter() + result = await self.execute( + handler, + event_data, + _context_for_handler(exec_ctx, handler), + ) + duration_ms = (time.perf_counter() - started) * 1000.0 + if on_handler_complete is not None: + await on_handler_complete(handler, result, duration_ms) + + if result.additional_context: + merged_context_parts.append(result.additional_context) + + if blockable and result.action in ('deny', 'block'): + return HookResult( + action=result.action if result.action == 'block' else 'deny', + reason=result.reason, + additional_context='\n'.join(merged_context_parts), + updated_args=result.updated_args or final_updated_args, + exit_code=result.exit_code, + stderr=result.stderr, + ) + + aggregated_action = _merge_action(aggregated_action, result.action) + + if result.updated_args is not None: + final_updated_args = result.updated_args + event_data = {**event_data, 'tool_args': result.updated_args} + if 'tool_input' in event_data: + event_data['tool_input'] = result.updated_args + + if aggregated_action is None: + aggregated_action = 'pass' + + return HookResult( + action=aggregated_action, + additional_context='\n'.join(merged_context_parts), + updated_args=final_updated_args, + ) + + +def _context_for_handler( + ctx: HookExecutionContext | None, + handler: HookHandlerConfig, +) -> HookExecutionContext | None: + if ctx is None: + return None + if not handler.source_plugin_root and not handler.source_plugin_data_dir: + return ctx + return HookExecutionContext( + session_id=ctx.session_id, + project_path=ctx.project_path, + plugin_root=handler.source_plugin_root or ctx.plugin_root, + plugin_data_dir=handler.source_plugin_data_dir or ctx.plugin_data_dir, + llm=ctx.llm, + messages=ctx.messages, + abort_signal=ctx.abort_signal, + tool_manager=ctx.tool_manager, + ) diff --git a/ms_agent/hooks/executors/__init__.py b/ms_agent/hooks/executors/__init__.py new file mode 100644 index 000000000..265c5b64f --- /dev/null +++ b/ms_agent/hooks/executors/__init__.py @@ -0,0 +1,13 @@ +"""Hook executor backends.""" + +from ms_agent.hooks.executors.command import ( + CommandHookExecutor, + HookExecutionContext, + build_hook_env, +) + +__all__ = [ + 'CommandHookExecutor', + 'HookExecutionContext', + 'build_hook_env', +] diff --git a/ms_agent/hooks/executors/command.py b/ms_agent/hooks/executors/command.py new file mode 100644 index 000000000..76bcb3ac5 --- /dev/null +++ b/ms_agent/hooks/executors/command.py @@ -0,0 +1,152 @@ +"""Command hook executor — subprocess stdin/stdout protocol.""" + +from __future__ import annotations + +import asyncio +import json +import os +import shlex +from dataclasses import dataclass +from typing import Any + +from ms_agent.hooks.events import HookResult +from ms_agent.hooks.registry import HookHandlerConfig +from ms_agent.hooks.response_adapter import ResponseAdapter +from ms_agent.utils import get_logger + +logger = get_logger() + + +@dataclass +class HookExecutionContext: + session_id: str + project_path: str + plugin_root: str | None = None + plugin_data_dir: str | None = None + llm: Any | None = None + messages: list | None = None + abort_signal: asyncio.Event | None = None + tool_manager: Any | None = None + + +def plugin_compat_payload( + event_data: dict[str, Any], + ctx: HookExecutionContext | None, +) -> dict[str, Any]: + """Adapt MS-Agent hook payloads for Claude-format plugin scripts.""" + if ctx is None or not ctx.plugin_root: + return event_data + payload = dict(event_data) + claude_tool = payload.get('tool_name_claude') + if claude_tool: + payload['tool_name'] = claude_tool + payload.setdefault( + 'hook_event_name', + payload.get('event') or payload.get('hook_event_name', ''), + ) + if payload.get('event') == 'UserPromptSubmit': + payload.setdefault('user_prompt', payload.get('prompt', '')) + return payload + + +def build_hook_env(ctx: HookExecutionContext) -> dict[str, str]: + env = dict(os.environ) + env['MS_AGENT_PROJECT_DIR'] = ctx.project_path + env['CLAUDE_PROJECT_DIR'] = ctx.project_path + if ctx.plugin_root: + env['MS_AGENT_PLUGIN_ROOT'] = ctx.plugin_root + env['CLAUDE_PLUGIN_ROOT'] = ctx.plugin_root + if ctx.plugin_data_dir: + env['MS_AGENT_PLUGIN_DATA'] = ctx.plugin_data_dir + env['CLAUDE_PLUGIN_DATA'] = ctx.plugin_data_dir + if ctx.session_id: + env['MS_AGENT_SESSION_ID'] = ctx.session_id + return env + + +class CommandHookExecutor: + def __init__( + self, + working_dir: str | None = None, + response_adapter: ResponseAdapter | None = None, + fail_closed: bool = False, + ) -> None: + self._working_dir = working_dir + self._response_adapter = response_adapter or ResponseAdapter() + self._fail_closed = fail_closed + + async def execute( + self, + handler: HookHandlerConfig, + event_data: dict[str, Any], + ctx: HookExecutionContext, + ) -> HookResult: + payload = plugin_compat_payload(event_data, ctx) + stdin_data = json.dumps(payload, ensure_ascii=False).encode('utf-8') + proc = None + try: + proc = await asyncio.create_subprocess_exec( + *shlex.split(handler.command or ''), + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=self._working_dir, + env=build_hook_env(ctx), + ) + stdout, stderr = await asyncio.wait_for( + proc.communicate(input=stdin_data), + timeout=handler.timeout, + ) + except asyncio.TimeoutError: + if proc is not None: + try: + proc.kill() + await proc.wait() + except ProcessLookupError: + pass + reason = f'Hook timed out after {handler.timeout}s' + if handler.fail_closed or self._fail_closed: + return HookResult(action='deny', reason=reason, exit_code=-1) + return HookResult(action='error', reason=reason, exit_code=-1) + except FileNotFoundError: + reason = f'Hook command not found: {handler.command}' + if handler.fail_closed or self._fail_closed: + return HookResult(action='deny', reason=reason, exit_code=-1) + return HookResult(action='error', reason=reason, exit_code=-1) + except Exception as e: + reason = str(e) + if handler.fail_closed or self._fail_closed: + return HookResult(action='deny', reason=reason, exit_code=-1) + return HookResult(action='error', reason=reason, exit_code=-1) + + exit_code = proc.returncode or 0 + stderr_text = stderr.decode('utf-8', errors='replace').strip() + stdout_text = stdout.decode('utf-8', errors='replace').strip() + + if exit_code == 2: + return HookResult( + action='deny', + reason=stderr_text or 'Blocked by hook', + exit_code=exit_code, + stderr=stderr_text, + ) + + if exit_code != 0: + logger.warning( + "Hook '%s' exited %d: %s", handler.command, exit_code, stderr_text) + if handler.fail_closed or self._fail_closed: + return HookResult( + action='deny', + reason=stderr_text or f'Hook exited {exit_code}', + exit_code=exit_code, + stderr=stderr_text, + ) + return HookResult( + action='error', + reason=stderr_text, + exit_code=exit_code, + stderr=stderr_text, + ) + + return self._response_adapter.parse( + stdout_text, exit_code, stderr_text, event_data.get('event')) diff --git a/ms_agent/hooks/factory.py b/ms_agent/hooks/factory.py new file mode 100644 index 000000000..0327bf202 --- /dev/null +++ b/ms_agent/hooks/factory.py @@ -0,0 +1,248 @@ +"""Build HookRuntime from agent config and multi-source loaders.""" + +from __future__ import annotations + +import uuid +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, OmegaConf + +from ms_agent.hooks.executor import HookExecutor +from ms_agent.hooks.loaders.claude import ClaudeSettingsLoader +from ms_agent.hooks.loaders.cursor import CursorHooksLoader +from ms_agent.hooks.loaders.hermes import HermesShellLoader +from ms_agent.hooks.loaders.native import NativeJsonLoader, NativeYamlLoader +from ms_agent.hooks.loaders.plugin import PluginHooksLoader +from ms_agent.hooks.registry import HookRegistry +from ms_agent.hooks.runtime import HookRuntime +from ms_agent.hooks.tool_name_mapper import ToolNameMapper +from ms_agent.utils import get_logger + +logger = get_logger() + +_EMPTY_REGISTRY = HookRegistry(_index={}) + + +def _empty_runtime(project_path: str = '', session_id: str = '') -> HookRuntime: + return HookRuntime( + registry=_EMPTY_REGISTRY, + executor=HookExecutor(working_dir=project_path or None), + session_id=session_id or str(uuid.uuid4()), + project_path=project_path, + tool_name_mapper=ToolNameMapper(), + ) + + +def _parse_hooks_meta(raw: dict[str, Any]) -> tuple[frozenset[str], frozenset[str], bool, str]: + enabled_sources = frozenset( + raw.get('enabled_sources', ['native']) or ['native']) + enabled_executors = frozenset( + raw.get('enabled_executors', ['command']) or ['command']) + fail_closed = bool(raw.get('fail_closed', False)) + default_model = str(raw.get('default_model', 'qwen-plus')) + return enabled_sources, enabled_executors, fail_closed, default_model + + +def build_hook_runtime( + config: DictConfig | Any, + *, + session_id: str | None = None, + plugin_hook_registries: list[Any] | None = None, +) -> HookRuntime: + """Construct HookRuntime; returns empty runtime when hooks are not configured.""" + from ms_agent.utils.workspace_context import resolve_workspace_root + + project_path = str(resolve_workspace_root(config)) + global_ms_agent_dir = str(Path.home() / '.ms_agent') + sid = session_id or str(uuid.uuid4()) + + raw_hooks: dict[str, Any] = {} + if hasattr(config, 'hooks') and config.hooks: + raw_hooks = OmegaConf.to_container(config.hooks, resolve=True) or {} + + enabled_sources, enabled_executors, fail_closed, default_model = _parse_hooks_meta( + raw_hooks) + + registry = HookRegistry(_index={}) + loaders: list[tuple[str, HookRegistry]] = [] + + # Priority order (low -> high), aligned with §5.3 + if 'native' in enabled_sources: + global_native = Path(global_ms_agent_dir) / 'hooks.yaml' + if global_native.is_file(): + loaders.append(( + 'global_native', + NativeYamlLoader.load_file( + global_native, + enabled_executors=enabled_executors, + ), + )) + + if 'claude' in enabled_sources: + claude_global = Path.home() / '.claude' / 'settings.json' + if claude_global.is_file(): + loaders.append(( + 'claude_global', + ClaudeSettingsLoader.load_file( + claude_global, + project_path, + enabled_executors=enabled_executors, + ), + )) + claude_project = Path(project_path) / '.claude' / 'settings.json' + if claude_project.is_file(): + loaders.append(( + 'claude_project', + ClaudeSettingsLoader.load_file( + claude_project, + project_path, + enabled_executors=enabled_executors, + ), + )) + + if 'cursor' in enabled_sources: + cursor_global = Path.home() / '.cursor' / 'hooks.json' + if cursor_global.is_file(): + loaders.append(( + 'cursor_global', + CursorHooksLoader.load_file( + cursor_global, + project_path, + enabled_executors=enabled_executors, + ), + )) + cursor_project = Path(project_path) / '.cursor' / 'hooks.json' + if cursor_project.is_file(): + loaders.append(( + 'cursor_project', + CursorHooksLoader.load_file( + cursor_project, + project_path, + enabled_executors=enabled_executors, + ), + )) + + # agent.yaml hooks section (without meta keys) — native source, §5.3 priority 6 + if 'native' in enabled_sources: + event_hooks = { + k: v for k, v in raw_hooks.items() + if k not in ( + 'enabled_sources', 'enabled_executors', 'default_model', + 'fail_closed', 'allowed_http_hook_urls', + 'http_hook_allowed_env_vars', + ) + } + if event_hooks: + loaders.append(( + 'agent_yaml', + HookRegistry.from_dict( + event_hooks, + enabled_executors=enabled_executors, + source='agent.yaml', + ), + )) + + if 'native' in enabled_sources: + ms_agent_hooks_json = Path(project_path) / '.ms-agent' / 'hooks.json' + if ms_agent_hooks_json.is_file(): + loaders.append(( + 'ms_agent_json', + NativeJsonLoader.load_file( + ms_agent_hooks_json, + enabled_executors=enabled_executors, + ), + )) + + if 'plugin' in enabled_sources: + if plugin_hook_registries is not None: + for contrib in plugin_hook_registries: + if not contrib.registry.is_empty: + loaders.append((f'plugin:{contrib.plugin_id}', contrib.registry)) + else: + plugin_roots = _discover_plugin_roots(config, project_path) + seen_plugin_ids: set[str] = set() + for root in plugin_roots: + plugin_id = Path(root).name + if plugin_id in seen_plugin_ids: + continue + seen_plugin_ids.add(plugin_id) + plugin_data_dir = Path.home() / '.ms_agent' / 'plugins' / 'data' / plugin_id + reg = PluginHooksLoader.load_plugin( + root, + project_path=project_path, + plugin_data_dir=plugin_data_dir, + enabled_executors=enabled_executors, + ) + if not reg.is_empty: + reg = reg.with_plugin_source( + plugin_id=Path(root).name, + plugin_root=str(root), + plugin_data_dir=str(plugin_data_dir), + ) + loaders.append((f'plugin:{root}', reg)) + + if 'hermes' in enabled_sources: + hermes_cfg = Path.home() / '.hermes' / 'config.yaml' + if hermes_cfg.is_file(): + loaders.append(( + 'hermes', + HermesShellLoader.load_file(hermes_cfg, project_path), + )) + + for _name, reg in loaders: + registry = registry.merge(reg) + + if registry.is_empty: + return _empty_runtime(project_path, sid) + + working_dir = getattr(config, 'local_dir', None) or project_path + executor = HookExecutor( + working_dir=working_dir, + enabled_executors=enabled_executors, + fail_closed=fail_closed, + ) + + return HookRuntime( + registry=registry, + executor=executor, + session_id=sid, + project_path=project_path, + tool_name_mapper=ToolNameMapper(enabled_sources=enabled_sources), + default_model=default_model, + ) + + +def _discover_plugin_roots(config: Any, project_path: str) -> list[str]: + from ms_agent.plugins.registry import PluginRegistry + + registry = PluginRegistry() + managed_paths = registry.managed_plugin_paths(project_path) + managed_ids = registry.managed_plugin_ids(project_path) + roots: list[str] = [] + seen: set[str] = set() + plugins_dir = Path(project_path) / '.ms-agent' / 'plugins' + if plugins_dir.is_dir(): + for child in plugins_dir.iterdir(): + if not child.is_dir(): + continue + resolved = str(child.resolve()) + if resolved in seen: + continue + seen.add(resolved) + roots.append(str(child)) + if hasattr(config, 'plugins'): + for p in (config.plugins or []): + path = Path(str(p)) + if not path.is_absolute(): + path = Path(project_path) / path + if not path.is_dir(): + continue + resolved = str(path.resolve()) + if resolved in seen or resolved in managed_paths: + continue + if path.name in managed_ids: + continue + seen.add(resolved) + roots.append(str(path)) + return roots diff --git a/ms_agent/hooks/loaders/__init__.py b/ms_agent/hooks/loaders/__init__.py new file mode 100644 index 000000000..c518930c1 --- /dev/null +++ b/ms_agent/hooks/loaders/__init__.py @@ -0,0 +1,16 @@ +"""Hook loaders package.""" + +from ms_agent.hooks.loaders.claude import ClaudeSettingsLoader +from ms_agent.hooks.loaders.cursor import CursorHooksLoader +from ms_agent.hooks.loaders.hermes import HermesShellLoader +from ms_agent.hooks.loaders.native import NativeJsonLoader, NativeYamlLoader +from ms_agent.hooks.loaders.plugin import PluginHooksLoader + +__all__ = [ + 'ClaudeSettingsLoader', + 'CursorHooksLoader', + 'HermesShellLoader', + 'NativeJsonLoader', + 'NativeYamlLoader', + 'PluginHooksLoader', +] diff --git a/ms_agent/hooks/loaders/claude.py b/ms_agent/hooks/loaders/claude.py new file mode 100644 index 000000000..916ca8369 --- /dev/null +++ b/ms_agent/hooks/loaders/claude.py @@ -0,0 +1,164 @@ +"""Claude Code settings.json hook loader.""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +from ms_agent.hooks.registry import HookRegistry, _parse_hook_handler, MatcherGroup +from ms_agent.hooks.tool_name_mapper import ToolNameMapper +from ms_agent.utils import get_logger + +logger = get_logger() + +_CLAUDE_EVENT_MAP = { + 'SessionStart': 'SessionStart', + 'UserPromptSubmit': 'UserPromptSubmit', + 'PreToolUse': 'PreToolUse', + 'PostToolUse': 'PostToolUse', + 'Stop': 'Stop', + 'SubagentStop': 'SubagentStop', + 'PermissionRequest': 'PermissionRequest', +} + + +class ClaudeSettingsLoader: + @staticmethod + def load_file( + path: Path | str, + project_path: str, + *, + plugin_root: str | None = None, + plugin_data_dir: str | None = None, + user_config: dict[str, Any] | None = None, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + with open(path, encoding='utf-8') as f: + data = json.load(f) + hooks = data.get('hooks', {}) + return ClaudeSettingsLoader.parse_hooks( + hooks, + project_path, + plugin_root=plugin_root, + plugin_data_dir=plugin_data_dir, + user_config=user_config, + enabled_executors=enabled_executors, + ) + + @staticmethod + def parse_hooks_file( + path: Path | str, + *, + plugin_root: str | None = None, + plugin_data_dir: str | None = None, + user_config: dict[str, Any] | None = None, + project_path: str = '', + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + with open(path, encoding='utf-8') as f: + data = json.load(f) + hooks = data.get('hooks', data) + return ClaudeSettingsLoader.parse_hooks( + hooks, + project_path, + plugin_root=plugin_root, + plugin_data_dir=plugin_data_dir, + user_config=user_config, + enabled_executors=enabled_executors, + ) + + @staticmethod + def parse_hooks( + hooks: dict[str, Any], + project_path: str, + *, + plugin_root: str | None = None, + plugin_data_dir: str | None = None, + user_config: dict[str, Any] | None = None, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + if not hooks: + return HookRegistry(_index={}) + + mapper = ToolNameMapper(enabled_sources=frozenset({'claude'})) + index: dict[str, tuple[MatcherGroup, ...]] = {} + + for event_name, groups_raw in hooks.items(): + canonical = _CLAUDE_EVENT_MAP.get(event_name) + if not canonical or canonical not in HookRegistry.VALID_EVENTS: + logger.warning('Skipping unknown Claude hook event: %s', event_name) + continue + + groups = [] + for g in (groups_raw or []): + matcher = g.get('matcher') + if matcher and canonical in HookRegistry.TOOL_EVENTS: + matcher = mapper.external_matcher_to_native(matcher, 'claude') + matcher = _expand_path_vars( + matcher, project_path, plugin_root, plugin_data_dir, + user_config) + + hooks_raw = g.get('hooks', []) + handlers = [] + for h in hooks_raw: + h = _expand_command_vars( + h, project_path, plugin_root, plugin_data_dir, + user_config) + t = h.get('type', 'command') or 'command' + if t not in enabled_executors: + logger.warning( + 'Claude hook type %s not in enabled_executors %s, skipping', + t, + sorted(enabled_executors), + ) + continue + parsed = _parse_hook_handler(h) + if parsed: + handlers.append(parsed) + if handlers: + groups.append(MatcherGroup( + matcher=matcher if canonical in HookRegistry.TOOL_EVENTS else None, + hooks=tuple(handlers), + )) + if groups: + index[canonical] = tuple(groups) + + return HookRegistry(_index=index) + + +def _expand_path_vars( + value: str, + project_path: str, + plugin_root: str | None, + plugin_data_dir: str | None = None, + user_config: dict[str, Any] | None = None, +) -> str: + value = value.replace('${CLAUDE_PROJECT_DIR}', project_path) + value = value.replace('${MS_AGENT_PROJECT_DIR}', project_path) + if plugin_root: + value = value.replace('${CLAUDE_PLUGIN_ROOT}', plugin_root) + value = value.replace('${MS_AGENT_PLUGIN_ROOT}', plugin_root) + if plugin_data_dir: + value = value.replace('${CLAUDE_PLUGIN_DATA}', plugin_data_dir) + value = value.replace('${MS_AGENT_PLUGIN_DATA}', plugin_data_dir) + for key, item in (user_config or {}).items(): + value = value.replace(f'${{user_config.{key}}}', str(item)) + value = value.replace(f'${{CLAUDE_PLUGIN_OPTION_{key.upper()}}}', str(item)) + return value + + +def _expand_command_vars( + h: dict[str, Any], + project_path: str, + plugin_root: str | None, + plugin_data_dir: str | None = None, + user_config: dict[str, Any] | None = None, +) -> dict[str, Any]: + out = dict(h) + cmd = out.get('command') + if isinstance(cmd, str): + out['command'] = _expand_path_vars( + cmd, project_path, plugin_root, plugin_data_dir, user_config) + return out diff --git a/ms_agent/hooks/loaders/cursor.py b/ms_agent/hooks/loaders/cursor.py new file mode 100644 index 000000000..1af1baa7a --- /dev/null +++ b/ms_agent/hooks/loaders/cursor.py @@ -0,0 +1,95 @@ +"""Cursor hooks.json loader.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from ms_agent.hooks.registry import HookRegistry, _parse_hook_handler, MatcherGroup +from ms_agent.hooks.tool_name_mapper import ToolNameMapper +from ms_agent.utils import get_logger + +logger = get_logger() + +_CURSOR_EVENT_MAP = { + 'sessionStart': 'SessionStart', + 'beforeSubmitPrompt': 'UserPromptSubmit', + 'preToolUse': 'PreToolUse', + 'postToolUse': 'PostToolUse', + 'stop': 'Stop', + 'subagentStop': 'SubagentStop', + 'beforeShellExecution': 'PreToolUse', + 'afterFileEdit': 'PostToolUse', +} + + +class CursorHooksLoader: + @staticmethod + def load_file( + path: Path | str, + project_path: str, + *, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + with open(path, encoding='utf-8') as f: + data = json.load(f) + hooks = data.get('hooks', data) + return CursorHooksLoader.parse_hooks( + hooks, project_path, enabled_executors=enabled_executors) + + @staticmethod + def parse_hooks( + hooks: dict[str, Any], + project_path: str, + *, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + if not hooks: + return HookRegistry(_index={}) + + mapper = ToolNameMapper(enabled_sources=frozenset({'cursor'})) + index: dict[str, tuple[MatcherGroup, ...]] = {} + + for event_name, entries in hooks.items(): + canonical = _CURSOR_EVENT_MAP.get(event_name) + if not canonical or canonical not in HookRegistry.VALID_EVENTS: + logger.warning('Skipping unknown Cursor hook event: %s', event_name) + continue + + groups = [] + for entry in (entries or []): + matcher = entry.get('matcher') + if event_name == 'beforeShellExecution': + matcher = matcher or f'*{ToolNameMapper.TOOL_SPLITER}shell_executor' + elif event_name == 'afterFileEdit': + matcher = matcher or f'*{ToolNameMapper.TOOL_SPLITER}write_file' + elif matcher and canonical in HookRegistry.TOOL_EVENTS: + matcher = mapper.external_matcher_to_native(matcher, 'cursor') + + t = entry.get('type', 'command') or 'command' + if t not in enabled_executors: + logger.warning( + 'Cursor hook type %s not in enabled_executors %s, skipping', + t, + sorted(enabled_executors), + ) + continue + + h = { + 'type': t, + 'command': entry.get('command'), + 'timeout': entry.get('timeout', 30), + 'failClosed': entry.get('failClosed', False), + } + parsed = _parse_hook_handler(h) + if parsed: + groups.append(MatcherGroup( + matcher=matcher if canonical in HookRegistry.TOOL_EVENTS else None, + hooks=(parsed,), + )) + if groups: + existing = index.get(canonical, ()) + index[canonical] = existing + tuple(groups) + + return HookRegistry(_index=index) diff --git a/ms_agent/hooks/loaders/hermes.py b/ms_agent/hooks/loaders/hermes.py new file mode 100644 index 000000000..9d1b15783 --- /dev/null +++ b/ms_agent/hooks/loaders/hermes.py @@ -0,0 +1,132 @@ +"""Hermes shell hook loader.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import yaml + +from ms_agent.hooks.registry import HookRegistry, _parse_hook_handler, MatcherGroup +from ms_agent.hooks.tool_name_mapper import ToolNameMapper +from ms_agent.utils import get_logger + +logger = get_logger() + +_HERMES_EVENT_MAP = { + 'on_session_start': 'SessionStart', + 'pre_llm_call': 'UserPromptSubmit', + 'pre_tool_call': 'PreToolUse', + 'post_tool_call': 'PostToolUse', + 'on_session_end': 'Stop', + 'subagent_stop': 'SubagentStop', + 'pre_approval_request': 'PermissionRequest', +} + + +class HermesShellLoader: + @staticmethod + def load_file( + path: Path | str, + project_path: str, + *, + plugin_root: str | None = None, + plugin_data_dir: str | None = None, + user_config: dict[str, Any] | None = None, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + with open(path, encoding='utf-8') as f: + data = yaml.safe_load(f) or {} + hooks = data.get('hooks', {}) + return HermesShellLoader.parse_hooks( + hooks, + project_path, + plugin_root=plugin_root, + plugin_data_dir=plugin_data_dir, + user_config=user_config, + enabled_executors=enabled_executors, + ) + + @staticmethod + def parse_hooks( + hooks: dict[str, Any], + project_path: str, + *, + plugin_root: str | None = None, + plugin_data_dir: str | None = None, + user_config: dict[str, Any] | None = None, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + if not hooks: + return HookRegistry(_index={}) + + mapper = ToolNameMapper(enabled_sources=frozenset({'hermes'})) + index: dict[str, tuple[MatcherGroup, ...]] = {} + + for event_name, entries in hooks.items(): + canonical = _HERMES_EVENT_MAP.get(event_name) + if not canonical or canonical not in HookRegistry.VALID_EVENTS: + logger.warning('Skipping unknown Hermes hook event: %s', event_name) + continue + + groups = [] + for entry in (entries or []): + if isinstance(entry, str): + entry = {'command': entry} + if not isinstance(entry, dict): + continue + + matcher = entry.get('matcher') or entry.get('tool') + if matcher and canonical in HookRegistry.TOOL_EVENTS: + matcher = mapper.external_matcher_to_native(matcher, 'hermes') + matcher = _expand_vars( + matcher, project_path, plugin_root, plugin_data_dir, + user_config) + + cmd = entry.get('command') or entry.get('script') + h = { + 'type': 'command', + 'command': _expand_vars( + str(cmd), project_path, plugin_root, plugin_data_dir, + user_config) if cmd else cmd, + 'timeout': entry.get('timeout', 30), + 'fail_closed': entry.get('fail_closed', False), + } + if h['type'] not in enabled_executors: + logger.warning( + 'Hermes hook type %s not in enabled_executors %s, skipping', + h['type'], + sorted(enabled_executors), + ) + continue + parsed = _parse_hook_handler(h) + if parsed: + groups.append(MatcherGroup( + matcher=matcher if canonical in HookRegistry.TOOL_EVENTS else None, + hooks=(parsed,), + )) + if groups: + index[canonical] = tuple(groups) + + return HookRegistry(_index=index) + + +def _expand_vars( + value: str, + project_path: str, + plugin_root: str | None, + plugin_data_dir: str | None, + user_config: dict[str, Any] | None, +) -> str: + value = value.replace('${MS_AGENT_PROJECT_DIR}', project_path) + value = value.replace('${CLAUDE_PROJECT_DIR}', project_path) + if plugin_root: + value = value.replace('${MS_AGENT_PLUGIN_ROOT}', plugin_root) + value = value.replace('${CLAUDE_PLUGIN_ROOT}', plugin_root) + if plugin_data_dir: + value = value.replace('${MS_AGENT_PLUGIN_DATA}', plugin_data_dir) + value = value.replace('${CLAUDE_PLUGIN_DATA}', plugin_data_dir) + for key, item in (user_config or {}).items(): + value = value.replace(f'${{user_config.{key}}}', str(item)) + value = value.replace(f'${{CLAUDE_PLUGIN_OPTION_{key.upper()}}}', str(item)) + return value diff --git a/ms_agent/hooks/loaders/native.py b/ms_agent/hooks/loaders/native.py new file mode 100644 index 000000000..4f059be4c --- /dev/null +++ b/ms_agent/hooks/loaders/native.py @@ -0,0 +1,49 @@ +"""Native YAML/JSON hook loaders.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import yaml + +from ms_agent.hooks.registry import HookRegistry + + +class NativeYamlLoader: + @staticmethod + def load_file( + path: Path | str, + *, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + with open(path, encoding='utf-8') as f: + data = yaml.safe_load(f) or {} + hooks = data.get('hooks', data) + if not isinstance(hooks, dict): + return HookRegistry(_index={}) + return HookRegistry.from_dict( + hooks, + enabled_executors=enabled_executors, + source=str(path), + ) + + +class NativeJsonLoader: + @staticmethod + def load_file( + path: Path | str, + *, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + with open(path, encoding='utf-8') as f: + data = json.load(f) + hooks = data.get('hooks', data) + if not isinstance(hooks, dict): + return HookRegistry(_index={}) + return HookRegistry.from_dict( + hooks, + enabled_executors=enabled_executors, + source=str(path), + ) diff --git a/ms_agent/hooks/loaders/plugin.py b/ms_agent/hooks/loaders/plugin.py new file mode 100644 index 000000000..80fd753bb --- /dev/null +++ b/ms_agent/hooks/loaders/plugin.py @@ -0,0 +1,32 @@ +"""Plugin hooks/hooks.json loader (F9).""" + +from __future__ import annotations + +from pathlib import Path + +from ms_agent.hooks.loaders.claude import ClaudeSettingsLoader +from ms_agent.hooks.registry import HookRegistry + + +class PluginHooksLoader: + @staticmethod + def load_plugin( + plugin_root: str | Path, + *, + project_path: str, + plugin_data_dir: str | Path | None = None, + user_config: dict | None = None, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> HookRegistry: + root = Path(plugin_root) + hooks_path = root / 'hooks' / 'hooks.json' + if not hooks_path.is_file(): + return HookRegistry(_index={}) + return ClaudeSettingsLoader.parse_hooks_file( + hooks_path, + plugin_root=str(root), + plugin_data_dir=str(plugin_data_dir) if plugin_data_dir else None, + user_config=user_config, + project_path=project_path, + enabled_executors=enabled_executors, + ) diff --git a/ms_agent/hooks/permission_resolve.py b/ms_agent/hooks/permission_resolve.py new file mode 100644 index 000000000..b3c775887 --- /dev/null +++ b/ms_agent/hooks/permission_resolve.py @@ -0,0 +1,110 @@ +"""Merge PreToolUse hook decisions with PermissionEnforcer.""" + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from ms_agent.hooks.events import HookResult +from ms_agent.permission.enforcer import PermissionDecision, PermissionEnforcer +from ms_agent.permission.config import PermissionConfig +from ms_agent.permission.matcher import PermissionMatcher + +if TYPE_CHECKING: + from ms_agent.hooks.runtime import HookRuntime + + +async def check_rule_based_permissions( + tool_name: str, + tool_args: dict[str, Any], + config: PermissionConfig, + matcher: PermissionMatcher | None = None, +) -> PermissionDecision | None: + """Rule layer only: blacklist deny, explicit ask rules. No handler popup.""" + m = matcher or PermissionMatcher() + for pattern in config.blacklist: + if m.match_with_content(pattern, tool_name, tool_args): + return PermissionDecision( + action='deny', + reason=f'Denied by blacklist rule: {pattern}', + ) + for pattern in config.ask_rules: + if m.match_with_content(pattern, tool_name, tool_args): + return PermissionDecision( + action='ask', + reason=f'Ask rule matched: {pattern}', + ) + return None + + +async def _run_permission_request_hook( + hook_runtime: HookRuntime | None, + tool_name: str, + tool_args: dict[str, Any], + permission_config: PermissionConfig | None, +) -> HookResult | None: + if hook_runtime is None or hook_runtime.is_empty: + return None + if permission_config is None or permission_config.mode != 'interactive': + return None + if not hook_runtime.registry.get_handlers('PermissionRequest', tool_name): + return None + return await hook_runtime.run_permission_request(tool_name, tool_args) + + +async def resolve_hook_permission_decision( + hook_result: HookResult | None, + tool_name: str, + tool_args: dict[str, Any], + *, + permission_enforcer: PermissionEnforcer | None, + permission_config: PermissionConfig | None, + hook_runtime: HookRuntime | None = None, +) -> PermissionDecision | str: + if hook_result and hook_result.action == 'deny': + return f'Blocked by hook: {hook_result.reason}' + + args = ( + hook_result.updated_args + if hook_result and hook_result.updated_args + else tool_args + ) + + if hook_result and hook_result.action == 'allow': + if permission_config: + rule = await check_rule_based_permissions( + tool_name, args, permission_config) + if rule and rule.action == 'deny': + return rule + if rule and rule.action == 'ask': + if permission_enforcer: + return await permission_enforcer.check( + tool_name, args, force_decision=rule) + return PermissionDecision( + action='allow', + reason=hook_result.reason or 'Allowed by PreToolUse hook', + ) + + if hook_result and hook_result.action == 'ask': + if permission_enforcer: + return await permission_enforcer.check( + tool_name, + args, + force_decision=PermissionDecision( + action='ask', reason=hook_result.reason), + ) + + pr = await _run_permission_request_hook( + hook_runtime, tool_name, args, permission_config) + if pr and pr.action == 'deny': + return f'Blocked by hook: {pr.reason}' + if pr and pr.action == 'ask' and permission_enforcer: + return await permission_enforcer.check( + tool_name, + args, + force_decision=PermissionDecision( + action='ask', reason=pr.reason), + ) + + if permission_enforcer: + return await permission_enforcer.check(tool_name, args) + return PermissionDecision(action='allow', reason='No permission enforcer') diff --git a/ms_agent/hooks/registry.py b/ms_agent/hooks/registry.py new file mode 100644 index 000000000..4b3efad96 --- /dev/null +++ b/ms_agent/hooks/registry.py @@ -0,0 +1,202 @@ +"""HookRegistry — canonical event index and config parsing.""" + +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from typing import Any, ClassVar + +from ms_agent.utils import get_logger +from ms_agent.utils.pattern_matcher import match_pattern + +logger = get_logger() + + +@dataclass(frozen=True) +class HookHandlerConfig: + type: str = 'command' + timeout: float = 30.0 + fail_closed: bool = False + command: str | None = None + url: str | None = None + headers: dict[str, str] = field(default_factory=dict) + allowed_env_vars: tuple[str, ...] = () + prompt: str | None = None + model: str | None = None + max_turns: int = 20 + source_plugin_id: str | None = None + source_plugin_root: str | None = None + source_plugin_data_dir: str | None = None + + +@dataclass(frozen=True) +class MatcherGroup: + matcher: str | None + hooks: tuple[HookHandlerConfig, ...] + + +def _filter_handlers_by_executor( + hooks_raw: list[dict[str, Any]], + enabled_executors: frozenset[str], + *, + source: str = 'config', +) -> tuple[HookHandlerConfig, ...]: + handlers: list[HookHandlerConfig] = [] + for h in hooks_raw: + t = h.get('type', 'command') or 'command' + if t not in enabled_executors: + logger.warning( + 'Hook type %s not in enabled_executors %s, skipping (%s)', + t, + sorted(enabled_executors), + source, + ) + continue + parsed = _parse_hook_handler(h) + if parsed is not None: + handlers.append(parsed) + return tuple(handlers) + + +def _parse_hook_handler(h: dict[str, Any]) -> HookHandlerConfig | None: + t = h.get('type', 'command') + timeout = float(h.get('timeout', 30.0)) + fail_closed = bool(h.get('failClosed', h.get('fail_closed', False))) + if t == 'command': + if not h.get('command'): + return None + return HookHandlerConfig( + type='command', + command=h['command'], + timeout=timeout, + fail_closed=fail_closed, + ) + if t == 'http': + if not h.get('url'): + return None + return HookHandlerConfig( + type='http', + url=h['url'], + headers=dict(h.get('headers') or {}), + allowed_env_vars=tuple( + h.get('allowedEnvVars', h.get('allowed_env_vars', []))), + timeout=timeout, + fail_closed=fail_closed, + ) + if t in ('prompt', 'agent'): + if not h.get('prompt'): + return None + return HookHandlerConfig( + type=t, + prompt=h['prompt'], + model=h.get('model'), + max_turns=int(h.get('maxTurns', h.get('max_turns', 20))), + timeout=timeout, + fail_closed=fail_closed, + ) + logger.warning('Unknown hook handler type: %s', t) + return None + + +@dataclass(frozen=True) +class HookRegistry: + _index: dict[str, tuple[MatcherGroup, ...]] + + VALID_EVENTS: ClassVar[frozenset[str]] = frozenset({ + 'SessionStart', 'PreToolUse', 'PostToolUse', + 'UserPromptSubmit', 'Stop', 'PermissionRequest', + 'SubagentStop', + }) + + TOOL_EVENTS: ClassVar[frozenset[str]] = frozenset({ + 'PreToolUse', 'PostToolUse', 'PermissionRequest', + }) + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + *, + enabled_executors: frozenset[str] = frozenset({'command'}), + source: str = 'config', + ) -> HookRegistry: + if not d: + return cls(_index={}) + + index: dict[str, tuple[MatcherGroup, ...]] = {} + for event_type, groups_raw in d.items(): + if event_type in ( + 'enabled_sources', 'enabled_executors', 'default_model', + 'fail_closed', 'allowed_http_hook_urls', + 'http_hook_allowed_env_vars'): + continue + if event_type not in cls.VALID_EVENTS: + logger.warning('Unknown hook event type: %s', event_type) + continue + groups = [] + for g in (groups_raw or []): + matcher = g.get('matcher') if event_type in cls.TOOL_EVENTS else None + hooks_raw = g.get('hooks', []) + handlers = _filter_handlers_by_executor( + hooks_raw, + enabled_executors, + source=source, + ) + if handlers: + groups.append(MatcherGroup(matcher=matcher, hooks=handlers)) + if groups: + index[event_type] = tuple(groups) + return cls(_index=index) + + def merge(self, other: HookRegistry) -> HookRegistry: + merged: dict[str, tuple[MatcherGroup, ...]] = {} + all_events = set(self._index) | set(other._index) + for event in all_events: + self_groups = self._index.get(event, ()) + other_groups = other._index.get(event, ()) + merged[event] = self_groups + other_groups + return HookRegistry(_index=merged) + + def with_plugin_source( + self, + *, + plugin_id: str, + plugin_root: str, + plugin_data_dir: str, + ) -> HookRegistry: + index: dict[str, tuple[MatcherGroup, ...]] = {} + for event, groups in self._index.items(): + updated_groups = [] + for group in groups: + updated_handlers = tuple( + replace( + handler, + source_plugin_id=plugin_id, + source_plugin_root=plugin_root, + source_plugin_data_dir=plugin_data_dir, + ) + for handler in group.hooks + ) + updated_groups.append( + MatcherGroup(matcher=group.matcher, hooks=updated_handlers)) + index[event] = tuple(updated_groups) + return HookRegistry(_index=index) + + def get_handlers( + self, + event_type: str, + tool_name: str | None = None, + ) -> list[HookHandlerConfig]: + groups = self._index.get(event_type, []) + result: list[HookHandlerConfig] = [] + for group in groups: + if event_type not in self.TOOL_EVENTS: + result.extend(group.hooks) + elif group.matcher is None: + result.extend(group.hooks) + elif tool_name is not None and match_pattern(group.matcher, tool_name): + result.extend(group.hooks) + return result + + @property + def is_empty(self) -> bool: + return not self._index diff --git a/ms_agent/hooks/response_adapter.py b/ms_agent/hooks/response_adapter.py new file mode 100644 index 000000000..2427e0703 --- /dev/null +++ b/ms_agent/hooks/response_adapter.py @@ -0,0 +1,154 @@ +"""Normalize stdout/HTTP hook responses from multiple ecosystems.""" + +from __future__ import annotations + +import json +from typing import Any + +from ms_agent.hooks.events import HookResult + + +class ResponseAdapter: + """Parse hook stdout JSON into a canonical HookResult.""" + + def parse( + self, + stdout_text: str, + exit_code: int = 0, + stderr_text: str = '', + event: str | None = None, + ) -> HookResult: + if not stdout_text: + return HookResult(action='pass', exit_code=exit_code, stderr=stderr_text) + + try: + data = json.loads(stdout_text) + except json.JSONDecodeError: + return HookResult( + action='error', + reason=f'Invalid JSON in hook stdout: {stdout_text[:200]}', + exit_code=exit_code, + stderr=stderr_text, + ) + + if not isinstance(data, dict): + return HookResult(action='pass', exit_code=exit_code) + + return self._normalize_dict(data, event, exit_code, stderr_text) + + def _normalize_dict( + self, + data: dict[str, Any], + event: str | None, + exit_code: int, + stderr_text: str, + ) -> HookResult: + updated_args = ( + data.get('updatedArgs') + or data.get('updated_input') + or data.get('updatedInput') + ) + if updated_args is not None and not isinstance(updated_args, dict): + updated_args = None + + additional_context = ( + data.get('additionalContext') + or data.get('additional_context') + or data.get('agent_message') + or data.get('context') + or '' + ) + + # Claude hookSpecificOutput + hso = data.get('hookSpecificOutput') + if isinstance(hso, dict): + perm = hso.get('permissionDecision') + if perm: + action = self._map_permission(perm) + return HookResult( + action=action, + reason=data.get('reason', '') or hso.get('reason', ''), + additional_context=additional_context, + updated_args=updated_args or hso.get('updatedInput'), + exit_code=exit_code, + stderr=stderr_text, + ) + if hso.get('updatedInput'): + updated_args = hso['updatedInput'] + + # Direct decision fields + decision = data.get('decision') or data.get('permission') + action_field = data.get('action') + if decision: + action = self._map_decision(str(decision)) + return HookResult( + action=action, + reason=data.get('reason', '') or data.get('user_message', ''), + additional_context=additional_context, + updated_args=updated_args, + exit_code=exit_code, + stderr=stderr_text, + ) + if action_field in ('block', 'deny'): + return HookResult( + action='deny', + reason=data.get('reason', '') or data.get('message', ''), + additional_context=additional_context, + updated_args=updated_args, + exit_code=exit_code, + stderr=stderr_text, + ) + + # Only updated_args without permission decision -> passthrough + if updated_args is not None: + return HookResult( + action='pass', + additional_context=additional_context, + updated_args=updated_args, + exit_code=exit_code, + stderr=stderr_text, + ) + + if additional_context: + return HookResult( + action='pass', + additional_context=additional_context, + exit_code=exit_code, + stderr=stderr_text, + ) + + # Stop event: continue=false means allow stop (pass) + if event == 'Stop' and data.get('continue') is False: + return HookResult(action='pass', exit_code=exit_code) + + # Cursor stop followup_message -> block + if event == 'Stop' and data.get('followup_message'): + return HookResult( + action='block', + reason=str(data['followup_message']), + exit_code=exit_code, + ) + + return HookResult(action='pass', exit_code=exit_code, stderr=stderr_text) + + @staticmethod + def _map_decision(decision: str) -> str: + d = decision.lower() + if d in ('deny', 'block', 'reject'): + return 'deny' + if d in ('allow', 'approve', 'permit'): + return 'allow' + if d == 'ask': + return 'ask' + return 'pass' + + @staticmethod + def _map_permission(perm: str) -> str: + p = perm.lower() + if p == 'deny': + return 'deny' + if p == 'allow': + return 'allow' + if p == 'ask': + return 'ask' + return 'pass' diff --git a/ms_agent/hooks/runtime.py b/ms_agent/hooks/runtime.py new file mode 100644 index 000000000..80863eeb8 --- /dev/null +++ b/ms_agent/hooks/runtime.py @@ -0,0 +1,248 @@ +"""HookRuntime facade — registry + executor + payload building.""" + +from __future__ import annotations + +import time +from dataclasses import asdict, dataclass +from typing import Any, Awaitable, Callable + +from ms_agent.hooks.context import HookAttachment +from ms_agent.hooks.events import ( + HookResult, + PermissionRequestEvent, + PostToolUseEvent, + PreToolUseEvent, + SessionStartEvent, + StopEvent, + UserPromptSubmitEvent, +) +from ms_agent.hooks.executor import HookExecutor +from ms_agent.hooks.executors.command import HookExecutionContext +from ms_agent.hooks.registry import HookHandlerConfig, HookRegistry +from ms_agent.hooks.tool_name_mapper import ToolNameMapper + +HookEventCallback = Callable[[dict[str, Any]], Awaitable[None]] + + +def _handler_name(handler: HookHandlerConfig) -> str: + if handler.command: + return handler.command + if handler.url: + return handler.url + return handler.type + + +@dataclass +class HookRuntime: + registry: HookRegistry + executor: HookExecutor + session_id: str + project_path: str + tool_name_mapper: ToolNameMapper + on_hook_event: HookEventCallback | None = None + default_model: str = 'qwen-plus' + + BLOCKABLE_EVENTS = frozenset({ + 'PreToolUse', 'UserPromptSubmit', 'Stop', 'PermissionRequest', + }) + + @property + def is_empty(self) -> bool: + return self.registry.is_empty + + @property + def has_session_handlers(self) -> bool: + return bool(self.registry.get_handlers('SessionStart')) + + def _ctx(self) -> HookExecutionContext: + return HookExecutionContext( + session_id=self.session_id, + project_path=self.project_path, + ) + + def _build_payload(self, event_obj: Any) -> dict[str, Any]: + payload = asdict(event_obj) + payload['project_path'] = self.project_path + payload['cwd'] = self.project_path + payload.setdefault('extra', {}) + if 'tool_args' in payload: + payload.setdefault('tool_input', payload['tool_args']) + return self.tool_name_mapper.enrich_payload( + payload, payload.get('tool_name')) + + async def _notify_hook_event( + self, + *, + hook_event: str, + handler: HookHandlerConfig, + result: HookResult, + duration_ms: float, + ) -> None: + if self.on_hook_event is None: + return + await self.on_hook_event({ + 'hook_event': hook_event, + 'hook_name': _handler_name(handler), + 'action': result.action, + 'reason': result.reason, + 'duration_ms': duration_ms, + }) + + async def _run_event( + self, + event_type: str, + event_obj: Any, + tool_name: str | None = None, + ) -> HookResult: + handlers = self.registry.get_handlers(event_type, tool_name) + if not handlers: + return HookResult(action='pass') + + payload = self._build_payload(event_obj) + blockable = event_type in self.BLOCKABLE_EVENTS + + async def _on_handler_complete( + handler: HookHandlerConfig, + result: HookResult, + duration_ms: float, + ) -> None: + await self._notify_hook_event( + hook_event=event_type, + handler=handler, + result=result, + duration_ms=duration_ms, + ) + + result = await self.executor.execute_all( + handlers, + payload, + blockable=blockable, + ctx=self._ctx(), + on_handler_complete=_on_handler_complete, + ) + + # Stop: map deny to block for continuation semantics (§9.4) + if event_type == 'Stop' and result.action == 'deny': + result = HookResult( + action='block', + reason=result.reason, + additional_context=result.additional_context, + updated_args=result.updated_args, + exit_code=result.exit_code, + stderr=result.stderr, + ) + + return result + + @staticmethod + def _attachments_for_context( + hook_event: str, + result: HookResult, + *, + tool_call_id: str | None = None, + ) -> list[HookAttachment]: + if not result.additional_context: + return [] + return [ + HookAttachment( + type='hook_additional_context', + hook_event=hook_event, + tool_call_id=tool_call_id, + content=result.additional_context, + ), + ] + + async def run_session_start(self, runtime: Any, messages: list) -> HookResult: + return await self._run_event( + 'SessionStart', + SessionStartEvent( + session_id=self.session_id, + project_path=self.project_path, + ), + ) + + async def run_pre_tool_use( + self, + tool_name: str, + tool_args: dict[str, Any], + *, + session_id: str | None = None, + project_path: str | None = None, + ) -> tuple[HookResult, list[HookAttachment]]: + result = await self._run_event( + 'PreToolUse', + PreToolUseEvent( + session_id=session_id or self.session_id, + tool_name=tool_name, + tool_args=tool_args, + ), + tool_name=tool_name, + ) + attachments = self._attachments_for_context( + 'PreToolUse', + result, + tool_call_id=None, + ) + return result, attachments + + async def run_post_tool_use( + self, + tool_name: str, + tool_args: dict[str, Any], + tool_result: str, + *, + tool_call_id: str | None = None, + ) -> tuple[HookResult, list[HookAttachment]]: + result = await self._run_event( + 'PostToolUse', + PostToolUseEvent( + session_id=self.session_id, + tool_name=tool_name, + tool_args=tool_args, + tool_result=tool_result, + ), + tool_name=tool_name, + ) + attachments = self._attachments_for_context( + 'PostToolUse', + result, + tool_call_id=tool_call_id, + ) + return result, attachments + + async def run_user_prompt_submit(self, prompt: str) -> HookResult: + return await self._run_event( + 'UserPromptSubmit', + UserPromptSubmitEvent(session_id=self.session_id, prompt=prompt), + ) + + async def run_permission_request( + self, + tool_name: str, + tool_args: dict[str, Any], + ) -> HookResult: + return await self._run_event( + 'PermissionRequest', + PermissionRequestEvent( + session_id=self.session_id, + tool_name=tool_name, + tool_args=tool_args, + ), + tool_name=tool_name, + ) + + async def run_stop( + self, + reason: str = '', + last_assistant_message: str = '', + stop_hook_active: bool = False, + ) -> HookResult: + return await self._run_event( + 'Stop', + StopEvent( + session_id=self.session_id, + reason=reason, + last_assistant_message=last_assistant_message, + stop_hook_active=stop_hook_active, + ), + ) diff --git a/ms_agent/hooks/tool_name_mapper.py b/ms_agent/hooks/tool_name_mapper.py new file mode 100644 index 000000000..c0daa1f37 --- /dev/null +++ b/ms_agent/hooks/tool_name_mapper.py @@ -0,0 +1,88 @@ +"""Map ms-agent tool names to external ecosystem aliases.""" + +from __future__ import annotations + +# ms-agent suffix -> external names +_TOOL_SUFFIX_MAP: dict[str, dict[str, str]] = { + 'shell_executor': { + 'claude': 'Bash', + 'cursor': 'Shell', + 'hermes': 'terminal', + }, + 'read_file': { + 'claude': 'Read', + 'cursor': 'Read', + 'hermes': 'read_file', + }, + 'write_file': { + 'claude': 'Write', + 'cursor': 'Write', + 'hermes': 'write_file', + }, + 'edit_file': { + 'claude': 'Edit', + 'cursor': 'Write', + 'hermes': 'patch', + }, +} + + +class ToolNameMapper: + """Bidirectional tool name mapping for hook payloads and matchers.""" + + TOOL_SPLITER = '---' + + def __init__(self, enabled_sources: frozenset[str] = frozenset({'native'})): + self._enabled_sources = enabled_sources + + def to_external(self, tool_name: str, platform: str) -> str | None: + if self.TOOL_SPLITER not in tool_name: + return None + suffix = tool_name.split(self.TOOL_SPLITER, 1)[1] + mapping = _TOOL_SUFFIX_MAP.get(suffix, {}) + return mapping.get(platform) + + def enrich_payload( + self, + payload: dict, + tool_name: str | None = None, + ) -> dict: + """Add external tool name aliases to stdin payload.""" + tn = tool_name or payload.get('tool_name', '') + if not tn: + return payload + enriched = dict(payload) + if 'claude' in self._enabled_sources or 'native' in self._enabled_sources: + ext = self.to_external(tn, 'claude') + if ext: + enriched['tool_name_claude'] = ext + if 'cursor' in self._enabled_sources: + ext = self.to_external(tn, 'cursor') + if ext: + enriched['tool_name_cursor'] = ext + if 'hermes' in self._enabled_sources: + ext = self.to_external(tn, 'hermes') + if ext: + enriched['tool_name_hermes'] = ext + args = enriched.get('tool_args') + if args is not None: + enriched.setdefault('tool_input', args) + enriched.setdefault('hook_event_name', enriched.get('event', '')) + return enriched + + def external_matcher_to_native(self, matcher: str, platform: str) -> str: + """Convert external tool matcher to ms-agent format where possible.""" + if self.TOOL_SPLITER in matcher: + return matcher + reverse: dict[str, str] = {} + for suffix, platforms in _TOOL_SUFFIX_MAP.items(): + name = platforms.get(platform) + if name: + reverse[name] = f'*{self.TOOL_SPLITER}{suffix}' + for ext_name, native_pattern in reverse.items(): + if matcher == ext_name: + return native_pattern + # Shell/Bash/terminal wildcard + if matcher in ('Bash', 'Shell', 'terminal'): + return f'*{self.TOOL_SPLITER}shell_executor' + return matcher diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 9b639e55a..e118a8d35 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -93,6 +93,9 @@ class Message: # role=tool: extra payload for UIs / SSE only; omitted from LLM API via to_dict_clean(). tool_detail: Optional[str] = None + # Hook attachments for UI / LLM condensation; omitted from to_dict_clean(). + hook_attachments: List[Any] = field(default_factory=list) + def to_dict(self): return asdict(self) @@ -120,6 +123,7 @@ def to_dict_clean(self): 'prompt_tokens', 'api_calls', 'tool_detail', + 'hook_attachments', 'searching_detail', 'search_result', '_responses_output_items', @@ -143,6 +147,7 @@ class ToolResult: resources: List[str] = field(default_factory=list) extra: dict = field(default_factory=dict) tool_detail: Optional[str] = None + hook_attachments: List[Any] = field(default_factory=list) @staticmethod def from_raw(raw): @@ -157,9 +162,13 @@ def from_raw(raw): text=str(model_text), resources=raw.get('resources', []), tool_detail=None if td is None else str(td), + hook_attachments=raw.get('hook_attachments', []), extra={ k: v for k, v in raw.items() - if k not in ['text', 'resources', 'result', 'tool_detail'] + if k not in [ + 'text', 'resources', 'result', 'tool_detail', + 'hook_attachments', + ] }) raise TypeError('tool_call_result must be str or dict') diff --git a/ms_agent/mcp/__init__.py b/ms_agent/mcp/__init__.py new file mode 100644 index 000000000..570139258 --- /dev/null +++ b/ms_agent/mcp/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""MCP runtime state machine and ToolManager synchronization.""" +from .runtime import ( + DEGRADED_FAILURE_THRESHOLD, + MCPFailureRecord, + MCPRuntime, + MCPServerState, + classify_mcp_failure, + classify_failure_message, + is_connection_error, +) + +__all__ = [ + 'DEGRADED_FAILURE_THRESHOLD', + 'MCPFailureRecord', + 'MCPRuntime', + 'MCPServerState', + 'classify_mcp_failure', + 'classify_failure_message', + 'is_connection_error', +] diff --git a/ms_agent/mcp/runtime.py b/ms_agent/mcp/runtime.py new file mode 100644 index 000000000..3577a2c0a --- /dev/null +++ b/ms_agent/mcp/runtime.py @@ -0,0 +1,508 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""MCP runtime state machine and ToolManager synchronization.""" +from __future__ import annotations + +import asyncio +import copy +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Deque, Dict, List, Literal, Optional, TYPE_CHECKING + +from ms_agent.config.mcp_schema import ( + ResolvedMCPConfig, + connection_params_for_client, +) +from ms_agent.tools.mcp_client import MCPClient +from ms_agent.utils import enhance_error, get_logger + +if TYPE_CHECKING: + from ms_agent.tools.tool_manager import ToolManager + +logger = get_logger() + +MCPServerStatus = Literal[ + 'registered', + 'connecting', + 'connected', + 'degraded', + 'error', + 'disabled', +] + +FAILURE_HISTORY_LIMIT = 20 +# Transient failures (timeout / 5xx) must reach this count before degraded. +DEGRADED_FAILURE_THRESHOLD = 3 + +MCPFailureKind = Literal['none', 'transient', 'hard'] + + +@dataclass +class MCPFailureRecord: + """Single failure snapshot (in-memory, for UI / diagnostics).""" + + at: str + phase: Literal['connect', 'call_tool', 'list_tools'] + message: str + tool_name: str | None = None + + +@dataclass +class MCPServerState: + name: str + config: dict + enabled: bool + status: MCPServerStatus + last_error: str | None = None + last_success_at: str | None = None + last_failure_at: str | None = None + consecutive_failures: int = 0 + failure_history: Deque[MCPFailureRecord] = field( + default_factory=lambda: deque(maxlen=FAILURE_HISTORY_LIMIT)) + tool_count: int = 0 + cached_tools: list[dict] = field(default_factory=list) + connected_at: str | None = None + + +def _utc_now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def classify_mcp_failure(exc: BaseException) -> MCPFailureKind: + """Classify transport failures for degraded policy. + + - ``hard``: session/process gone — degrade immediately. + - ``transient``: timeout / upstream 5xx — may be jitter; degrade only + after ``DEGRADED_FAILURE_THRESHOLD`` consecutive failures. + - ``none``: business / argument errors — do not change ``status``. + """ + if isinstance(exc, asyncio.TimeoutError): + return 'transient' + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + return 'hard' + if isinstance(exc, ConnectionError): + return 'hard' + msg = str(exc).lower() + hard_markers = ( + 'connection closed', + 'session closed', + 'broken pipe', + 'disconnected', + 'connection refused', + ) + if any(m in msg for m in hard_markers): + return 'hard' + transient_markers = ( + 'timeout', + 'timed out', + '502', + '503', + ) + if any(m in msg for m in transient_markers): + return 'transient' + return 'none' + + +def classify_failure_message(message: str) -> MCPFailureKind: + """Message-only fallback when the original exception is unavailable.""" + return classify_mcp_failure(Exception(message)) + + +def is_connection_error(exc: BaseException) -> bool: + """Whether a tool exception should be reported to MCP failure tracking.""" + return classify_mcp_failure(exc) in ('transient', 'hard') + + +class MCPRuntime: + """Configuration-driven MCP lifecycle and ToolManager sync.""" + + def __init__( + self, + *, + mcp_client: MCPClient | None = None, + config: ResolvedMCPConfig | None = None, + owns_client: bool | None = None, + connect_policy: Literal['skip', 'fail_fast'] = 'skip', + ): + self._config = config + self._connect_policy = connect_policy + self._states: Dict[str, MCPServerState] = {} + self._tool_manager: ToolManager | None = None + self._sync_lock = asyncio.Lock() + self._started = False + + if mcp_client is not None: + self._client = mcp_client + self._owns_client = ( + owns_client if owns_client is not None else False) + else: + mcp_json = config.to_mcp_json() if config else None + self._client = MCPClient(mcp_config=mcp_json) + self._owns_client = True + + if config is not None: + self._register_from_config(config) + + @property + def client(self) -> MCPClient: + return self._client + + @property + def is_started(self) -> bool: + return self._started + + # ── lifecycle ────────────────────────────────────────────────────── + + def _register_from_config(self, config: ResolvedMCPConfig) -> None: + for name, server_cfg in config.mcp_servers.items(): + enabled = bool(server_cfg.get('enabled', True)) + status: MCPServerStatus = 'disabled' if not enabled else 'registered' + self._states[name] = MCPServerState( + name=name, + config=copy.deepcopy(server_cfg), + enabled=enabled, + status=status, + ) + + async def start(self) -> None: + """Connect all enabled servers (idempotent).""" + async with self._sync_lock: + self._started = True + failures: list[BaseException] = [] + for name, state in self._states.items(): + if not state.enabled: + state.status = 'disabled' + continue + if self._client.is_connected(name): + if state.status != 'connected': + state.status = 'connected' + continue + try: + await self._connect_server(name, state) + except Exception as exc: + if self._connect_policy == 'fail_fast': + raise + failures.append(exc) + if failures and self._connect_policy == 'fail_fast': + raise failures[0] + + async def stop(self) -> None: + """Disconnect all servers when this runtime owns the client.""" + async with self._sync_lock: + self._started = False + if self._owns_client: + await self._client.cleanup() + for state in self._states.values(): + if state.enabled: + state.status = 'registered' + else: + state.status = 'disabled' + state.cached_tools.clear() + state.tool_count = 0 + state.connected_at = None + + async def _connect_server(self, name: str, state: MCPServerState) -> None: + state.status = 'connecting' + state.last_error = None + try: + await self._client.connect_single_server( + name, connection_params_for_client(state.config)) + state.status = 'connected' + state.connected_at = _utc_now() + state.last_success_at = state.connected_at + state.consecutive_failures = 0 + await self._refresh_cached_tools(name, state) + except Exception as exc: + new_exc = enhance_error(exc, f'Connect `{name}` failed, details:') + await self._record_connect_failure(name, str(new_exc)) + if self._connect_policy == 'fail_fast': + raise new_exc from exc + + async def _refresh_cached_tools( + self, + name: str, + state: MCPServerState, + ) -> None: + try: + tools = await self._client.get_tools_for_server(name) + state.cached_tools = [dict(t) for t in tools] + state.tool_count = len(state.cached_tools) + state.last_success_at = _utc_now() + except Exception as exc: + await self.record_failure( + name, 'list_tools', str(exc), exc=exc) + + # ── enable / disable ─────────────────────────────────────────────── + + async def enable_server(self, name: str) -> MCPServerState: + async with self._sync_lock: + return await self._enable_server_unlocked(name) + + async def _enable_server_unlocked(self, name: str) -> MCPServerState: + state = self._require_state(name) + state.enabled = True + state.config['enabled'] = True + if not self._client.is_connected(name): + await self._connect_server(name, state) + elif state.status in ('registered', 'disabled', 'error'): + state.status = 'connected' + await self._sync_tools_unlocked() + return state + + async def disable_server(self, name: str) -> MCPServerState: + async with self._sync_lock: + return await self._disable_server_unlocked(name) + + async def _disable_server_unlocked(self, name: str) -> MCPServerState: + state = self._require_state(name) + state.enabled = False + state.config['enabled'] = False + state.status = 'disabled' + state.cached_tools.clear() + state.tool_count = 0 + await self._sync_tools_unlocked() + return state + + async def reload_server(self, name: str) -> MCPServerState: + async with self._sync_lock: + await self._disable_server_unlocked(name) + state = self._require_state(name) + state.enabled = True + state.config['enabled'] = True + if self._client.is_connected(name): + await self._client.disconnect_server(name) + await self._connect_server(name, state) + await self._sync_tools_unlocked() + return state + + async def reconnect_server(self, name: str) -> MCPServerState: + async with self._sync_lock: + state = self._require_state(name) + if not state.enabled: + raise ValueError(f'Server {name} is disabled') + if self._client.is_connected(name): + await self._client.disconnect_server(name) + state.status = 'registered' + state.cached_tools.clear() + state.tool_count = 0 + await self._connect_server(name, state) + await self._sync_tools_unlocked() + return state + + # ── config hot update ────────────────────────────────────────────── + + async def apply_config(self, config: ResolvedMCPConfig) -> list[MCPServerState]: + async with self._sync_lock: + self._config = config + old_names = set(self._states) + new_names = set(config.mcp_servers) + removed = old_names - new_names + added = new_names - old_names + changed = { + n for n in old_names & new_names + if self._states[n].config != config.mcp_servers[n] + } + + for name in removed: + if self._client.is_connected(name): + await self._client.disconnect_server(name) + if name in self._states: + state = self._states[name] + state.enabled = False + state.status = 'disabled' + state.cached_tools.clear() + state.tool_count = 0 + self._states.pop(name, None) + + for name in added: + entry = copy.deepcopy(config.mcp_servers[name]) + enabled = bool(entry.get('enabled', True)) + self._states[name] = MCPServerState( + name=name, + config=entry, + enabled=enabled, + status='disabled' if not enabled else 'registered', + ) + + for name in changed: + state = self._states[name] + old_enabled = state.enabled + state.config = copy.deepcopy(config.mcp_servers[name]) + state.enabled = bool(state.config.get('enabled', True)) + if not state.enabled: + state.status = 'disabled' + state.cached_tools.clear() + state.tool_count = 0 + elif old_enabled != state.enabled or name in changed: + if self._client.is_connected(name): + await self._client.disconnect_server(name) + state.status = 'registered' + + for name, state in self._states.items(): + if not state.enabled: + continue + if not self._client.is_connected(name): + try: + await self._connect_server(name, state) + except Exception: + if self._connect_policy == 'fail_fast': + raise + elif name in changed: + await self._refresh_cached_tools(name, state) + + await self._sync_tools_unlocked() + return list(self._states.values()) + + # ── query ────────────────────────────────────────────────────────── + + def list_servers(self) -> list[MCPServerState]: + return list(self._states.values()) + + def get_server(self, name: str) -> MCPServerState | None: + return self._states.get(name) + + def is_callable(self, server_name: str) -> bool: + state = self._states.get(server_name) + return state is not None and state.status == 'connected' + + def unavailable_detail(self, server_name: str) -> dict: + state = self._states.get(server_name) + if state is None: + return { + 'success': False, + 'error': 'mcp_unavailable', + 'server_name': server_name, + 'message': f'Unknown MCP server: {server_name}', + } + return { + 'success': False, + 'error': 'mcp_unavailable', + 'server_name': server_name, + 'status': state.status, + 'message': state.last_error or ( + f'MCP server {server_name} is not callable (status={state.status})' + ), + } + + # ── failure tracking ─────────────────────────────────────────────── + + async def record_failure( + self, + name: str, + phase: str, + message: str, + *, + tool_name: str | None = None, + exc: BaseException | None = None, + ) -> None: + async with self._sync_lock: + degraded = self._apply_failure_state( + name, phase, message, tool_name=tool_name, exc=exc) + if degraded: + await self._sync_tools_unlocked() + + def _apply_failure_state( + self, + name: str, + phase: str, + message: str, + *, + tool_name: str | None = None, + exc: BaseException | None = None, + ) -> bool: + """Update failure counters; return True if status became degraded.""" + state = self._states.get(name) + if state is None: + return False + failure_kind = ( + classify_mcp_failure(exc) if exc is not None + else classify_failure_message(message)) + if failure_kind == 'none': + return False + now = _utc_now() + record = MCPFailureRecord( + at=now, + phase=phase, # type: ignore[arg-type] + message=message, + tool_name=tool_name, + ) + state.failure_history.append(record) + state.last_error = message + state.last_failure_at = now + state.consecutive_failures += 1 + should_degrade = ( + failure_kind == 'hard' + or state.consecutive_failures >= DEGRADED_FAILURE_THRESHOLD + ) + if should_degrade and state.status == 'connected': + state.status = 'degraded' + return True + return False + + async def record_success(self, name: str) -> None: + """Reset failure counters after a successful MCP RPC.""" + async with self._sync_lock: + state = self._states.get(name) + if state is None: + return + state.consecutive_failures = 0 + state.last_success_at = _utc_now() + + async def _record_connect_failure(self, name: str, message: str) -> None: + state = self._states.get(name) + if state is None: + return + state.status = 'error' + state.last_error = message + state.last_failure_at = _utc_now() + state.consecutive_failures += 1 + state.failure_history.append( + MCPFailureRecord( + at=state.last_failure_at, + phase='connect', + message=message, + )) + + # ── ToolManager integration ──────────────────────────────────────── + + def bind_tool_manager(self, tool_manager: 'ToolManager') -> None: + self._tool_manager = tool_manager + + async def sync_tools(self) -> None: + async with self._sync_lock: + await self._sync_tools_unlocked() + + async def _sync_tools_unlocked(self) -> None: + if self._tool_manager is None: + return + indexable: set[str] = set() + callable_servers: set[str] = set() + for name, state in self._states.items(): + if not state.enabled: + continue + if state.status == 'connected': + indexable.add(name) + callable_servers.add(name) + failures = await self._tool_manager.sync_mcp_tools( + visible_servers=set(self._states.keys()), + indexable_servers=indexable, + callable_servers=callable_servers, + cached_tools_by_server=None, + ) + needs_resync = False + for server_name, exc in failures: + if self._apply_failure_state( + server_name, + 'list_tools', + str(exc), + exc=exc, + ): + needs_resync = True + if needs_resync: + await self._sync_tools_unlocked() + + def _require_state(self, name: str) -> MCPServerState: + state = self._states.get(name) + if state is None: + raise KeyError(f'Unknown MCP server: {name}') + return state diff --git a/ms_agent/permission/__init__.py b/ms_agent/permission/__init__.py new file mode 100644 index 000000000..24e74cc6a --- /dev/null +++ b/ms_agent/permission/__init__.py @@ -0,0 +1,35 @@ +"""Permission module — dual-layer permission control for tool calls. + +Outer layer (PermissionEnforcer): user-intent based, configurable, overridable. +Inner layer (SafetyGuard): safety baseline, non-bypassable. +""" + +from .ask_resolver import resolve_ask +from .config import PermissionConfig, SafetyConfig +from .enforcer import PermissionDecision, PermissionEnforcer +from .handler import ( + AutoPermissionHandler, + CLIPermissionHandler, + PermissionAction, + PermissionHandler, + PermissionResponse, + WebPermissionHandler, +) +from .memory import PermissionMemory +from .safety import SafetyGuard + +__all__ = [ + 'resolve_ask', + 'PermissionConfig', + 'SafetyConfig', + 'PermissionDecision', + 'PermissionEnforcer', + 'AutoPermissionHandler', + 'CLIPermissionHandler', + 'PermissionAction', + 'PermissionHandler', + 'PermissionResponse', + 'WebPermissionHandler', + 'PermissionMemory', + 'SafetyGuard', +] diff --git a/ms_agent/permission/ask_resolver.py b/ms_agent/permission/ask_resolver.py new file mode 100644 index 000000000..c1d54f5d8 --- /dev/null +++ b/ms_agent/permission/ask_resolver.py @@ -0,0 +1,63 @@ +"""Resolve SafetyGuard ``ask`` decisions based on permission mode. + +auto mode: per-category allow/deny (no interactive prompts) +strict mode: all ask → deny +interactive: ask unchanged (delegated to handler) +""" + +from __future__ import annotations + +from typing import Literal + +from .shell_validator import SafetyDecision + +_AUTO_CATEGORY_POLICY: dict[str, Literal['allow', 'deny']] = { + 'process_input_sub': 'allow', + 'process_output_sub': 'deny', + 'parse_failure': 'deny', + 'cd_write_compound': 'deny', + 'command_validator': 'deny', + 'shell_expansion': 'deny', + 'read_outside_dirs': 'deny', +} + + +def resolve_ask( + decision: SafetyDecision, + mode: str, + read_policy: str = 'loose', +) -> SafetyDecision: + """Resolve a SafetyGuard ``ask`` into ``allow`` or ``deny`` (or keep ``ask``). + + Only processes decisions with ``action='ask'``; others pass through unchanged. + """ + if decision.action != 'ask': + return decision + + if mode == 'strict': + return SafetyDecision( + action='deny', + reason=f'Denied in strict mode: {decision.reason}', + category=decision.category, + ) + + if mode == 'interactive': + return decision + + # auto mode — resolve by category + category = decision.category + + if category == 'read_outside_dirs': + action: Literal['allow', 'deny'] = 'allow' if read_policy == 'loose' else 'deny' + return SafetyDecision( + action=action, + reason=decision.reason, + category=category, + ) + + resolved_action = _AUTO_CATEGORY_POLICY.get(category, 'deny') + return SafetyDecision( + action=resolved_action, + reason=decision.reason, + category=category, + ) diff --git a/ms_agent/permission/config.py b/ms_agent/permission/config.py new file mode 100644 index 000000000..5394fc55d --- /dev/null +++ b/ms_agent/permission/config.py @@ -0,0 +1,141 @@ +"""Configuration parsing for the permission module. + +Reads the ``permission`` section from agent YAML and produces frozen +dataclasses consumed by SafetyGuard and PermissionEnforcer. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any, Literal + +# Default safety rules baked into SafetyConfig when none are configured. +_DEFAULT_SAFETY_PATTERNS: tuple[str, ...] = ( + 'code_executor---shell_executor:rm -rf /*', + 'code_executor---shell_executor:mkfs *', + 'code_executor---shell_executor:dd if=*', +) + +_DEFAULT_SENSITIVE_PATHS: tuple[str, ...] = ( + '/etc/*', + '/sys/*', + '/boot/*', + '/dev/*', + '/proc/*', + '~/.ssh/*', + '~/.gnupg/*', + '~/.bashrc', + '~/.zshrc', + '~/.profile', + '.git/config', + '.git/hooks/*', + '**/.git/**', +) + +_DEFAULT_DANGEROUS_REMOVAL: tuple[str, ...] = ( + '*', + '/*', + '/', + '~', +) + + +@dataclass(frozen=True) +class SafetyConfig: + """Inner-layer safety configuration (non-bypassable).""" + patterns: tuple[str, ...] = _DEFAULT_SAFETY_PATTERNS + sensitive_paths: tuple[str, ...] = _DEFAULT_SENSITIVE_PATHS + dangerous_removal_paths: tuple[str, ...] = _DEFAULT_DANGEROUS_REMOVAL + read_policy: Literal['loose', 'strict'] = 'loose' + max_command_chars: int = 8192 + allowed_directories: tuple[str, ...] = () + read_only_directories: tuple[str, ...] = () + + @classmethod + def from_dict(cls, d: dict[str, Any], project_root: str | None = None) -> SafetyConfig: + patterns = tuple(d.get('patterns', _DEFAULT_SAFETY_PATTERNS)) + sensitive = tuple(d.get('sensitive_paths', _DEFAULT_SENSITIVE_PATHS)) + dangerous = tuple(d.get('dangerous_removal_paths', _DEFAULT_DANGEROUS_REMOVAL)) + + path_validation = d.get('path_validation', {}) + read_policy = path_validation.get('read_policy', 'loose') + max_chars = path_validation.get('max_command_chars', 8192) + + def _expand_dirs(raw: list[str]) -> tuple[str, ...]: + out: list[str] = [] + for entry in raw: + if entry == '${PROJECT_ROOT}' and project_root: + out.append(project_root) + else: + out.append(os.path.expandvars(entry)) + return tuple(out) + + allowed = _expand_dirs(list(d.get('allowed_directories', []))) + read_only = _expand_dirs(list(d.get('read_only_directories', []))) + + return cls( + patterns=patterns, + sensitive_paths=sensitive, + dangerous_removal_paths=dangerous, + read_policy=read_policy, + max_command_chars=max_chars, + allowed_directories=allowed, + read_only_directories=read_only, + ) + + +_DEFAULT_BLACKLIST: tuple[str, ...] = ( + 'code_executor---shell_executor:curl *', + 'code_executor---shell_executor:wget *', + 'code_executor---shell_executor:ssh *', + 'code_executor---shell_executor:scp *', + 'code_executor---shell_executor:rsync *', + 'code_executor---shell_executor:nc *', + 'code_executor---shell_executor:netcat *', +) + + +@dataclass(frozen=True) +class PermissionConfig: + """Top-level permission configuration from agent YAML.""" + mode: Literal['auto', 'strict', 'interactive'] = 'auto' + whitelist: tuple[str, ...] = () + blacklist: tuple[str, ...] = _DEFAULT_BLACKLIST + ask_rules: tuple[str, ...] = () + safety: SafetyConfig = SafetyConfig() + + @classmethod + def from_dict(cls, d: dict[str, Any], project_root: str | None = None) -> PermissionConfig: + if not d: + return cls() + + raw_mode = d.get('mode', 'auto') + _MODE_ALIASES = {'restricted': 'interactive'} + mode = _MODE_ALIASES.get(raw_mode, raw_mode) + whitelist = tuple(d.get('whitelist', ())) + ask_rules = tuple(d.get('ask_rules', ())) + user_blacklist = tuple(d.get('blacklist', ())) + blacklist = _DEFAULT_BLACKLIST + tuple( + p for p in user_blacklist if p not in _DEFAULT_BLACKLIST + ) + + safety_raw = d.get('safety_rules', {}) + # Merge directory configs from top level into safety config + for _dir_key in ('allowed_directories', 'read_only_directories'): + if _dir_key in d and _dir_key not in safety_raw: + safety_raw = dict(safety_raw) + safety_raw[_dir_key] = d[_dir_key] + if 'path_validation' in d and 'path_validation' not in safety_raw: + safety_raw = dict(safety_raw) + safety_raw['path_validation'] = d['path_validation'] + + safety = SafetyConfig.from_dict(safety_raw, project_root=project_root) + + return cls( + mode=mode, + whitelist=whitelist, + blacklist=blacklist, + ask_rules=ask_rules, + safety=safety, + ) diff --git a/ms_agent/permission/enforcer.py b/ms_agent/permission/enforcer.py new file mode 100644 index 000000000..864fbaf60 --- /dev/null +++ b/ms_agent/permission/enforcer.py @@ -0,0 +1,138 @@ +"""PermissionEnforcer: outer-layer user-intent permission control. + +Checks blacklist/whitelist, session/persistent memory, and falls back to +the PermissionHandler for interactive user confirmation. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +from .config import PermissionConfig +from .handler import ( + AutoPermissionHandler, + PermissionAction, + PermissionHandler, + PermissionResponse, +) +from .matcher import PermissionMatcher +from .memory import PermissionMemory +from .suggestions import generate_suggestions + + +@dataclass(frozen=True) +class PermissionDecision: + action: Literal['allow', 'deny', 'ask'] + reason: str + updated_args: dict[str, Any] | None = None + + +class PermissionEnforcer: + """Outer-layer permission enforcement based on user intent and configuration.""" + + def __init__( + self, + config: PermissionConfig, + handler: PermissionHandler | None = None, + memory: PermissionMemory | None = None, + ) -> None: + self._config = config + self._handler = handler or AutoPermissionHandler() + self._memory = memory or PermissionMemory() + self._matcher = PermissionMatcher() + + async def check( + self, + tool_name: str, + tool_args: dict[str, Any], + *, + force_decision: PermissionDecision | None = None, + ) -> PermissionDecision: + if force_decision and force_decision.action == 'ask': + suggestions = generate_suggestions(tool_name, tool_args) + response = await self._handler.ask( + tool_name=tool_name, + tool_args=tool_args, + context=force_decision.reason or '', + suggestions=suggestions, + ) + return self._process_response(response, tool_name, tool_args) + + # 1. Auto / strict mode → allow everything (safety handled by SafetyGuard + ask_resolver) + if self._config.mode in ('auto', 'strict'): + return PermissionDecision(action='allow', reason=f'{self._config.mode.capitalize()} mode') + + # 2. Blacklist → deny (not overridable) + for pattern in self._config.blacklist: + if self._matcher.match_with_content(pattern, tool_name, tool_args): + return PermissionDecision( + action='deny', + reason=f'Denied by blacklist rule: {pattern}', + ) + + # 3. Whitelist → allow + for pattern in self._config.whitelist: + if self._matcher.match_with_content(pattern, tool_name, tool_args): + return PermissionDecision( + action='allow', + reason=f'Allowed by whitelist rule: {pattern}', + ) + + # 4. Memory (session + persistent) → allow + if self._memory.matches(tool_name, tool_args): + return PermissionDecision( + action='allow', + reason='Allowed by remembered permission', + ) + + # 5. Ask user via handler + suggestions = generate_suggestions(tool_name, tool_args) + response = await self._handler.ask( + tool_name=tool_name, + tool_args=tool_args, + context='', + suggestions=suggestions, + ) + + return self._process_response(response, tool_name, tool_args) + + def _process_response( + self, + response: PermissionResponse, + tool_name: str, + tool_args: dict[str, Any], + ) -> PermissionDecision: + if response.action == PermissionAction.ALLOW_ONCE: + return PermissionDecision(action='allow', reason='User allowed once') + + if response.action == PermissionAction.ALLOW_SESSION: + pattern = response.pattern or tool_name + self._memory.add_session(pattern) + return PermissionDecision( + action='allow', + reason=f'User allowed for session (pattern: {pattern})', + ) + + if response.action == PermissionAction.ALLOW_ALWAYS: + pattern = response.pattern or tool_name + self._memory.add(pattern, scope='project', source='user') + return PermissionDecision( + action='allow', + reason=f'User allowed always (pattern: {pattern})', + ) + + if response.action == PermissionAction.MODIFY: + return PermissionDecision( + action='allow', + reason='User modified args', + updated_args=response.updated_args, + ) + + if response.action == PermissionAction.DENY: + return PermissionDecision( + action='deny', + reason=response.feedback or 'User denied', + ) + + return PermissionDecision(action='deny', reason='Unknown action') diff --git a/ms_agent/permission/handler.py b/ms_agent/permission/handler.py new file mode 100644 index 000000000..622c12543 --- /dev/null +++ b/ms_agent/permission/handler.py @@ -0,0 +1,180 @@ +"""PermissionHandler protocol and implementations. + +Three implementations: + - AutoPermissionHandler: always allow (fallback). + - CLIPermissionHandler: interactive terminal menu. + - WebPermissionHandler: Future-based async with event emitter. +""" + +from __future__ import annotations + +import asyncio +import json +import sys +from dataclasses import dataclass +from enum import Enum +from typing import Any, Protocol +from uuid import uuid4 + + +class PermissionAction(str, Enum): + ALLOW_ONCE = 'allow_once' + ALLOW_SESSION = 'allow_session' + ALLOW_ALWAYS = 'allow_always' + DENY = 'deny' + MODIFY = 'modify' + + +@dataclass(frozen=True) +class PermissionResponse: + action: PermissionAction + updated_args: dict[str, Any] | None = None + pattern: str | None = None + feedback: str | None = None + + +class PermissionHandler(Protocol): + async def ask( + self, + tool_name: str, + tool_args: dict[str, Any], + context: str, + suggestions: list[str] | None = None, + ) -> PermissionResponse: ... + + +class AutoPermissionHandler: + """Always allows — used as fallback or in auto mode.""" + + async def ask( + self, + tool_name: str, + tool_args: dict[str, Any], + context: str, + suggestions: list[str] | None = None, + ) -> PermissionResponse: + return PermissionResponse(action=PermissionAction.ALLOW_ONCE) + + +class CLIPermissionHandler: + """Interactive CLI permission prompt.""" + + async def ask( + self, + tool_name: str, + tool_args: dict[str, Any], + context: str, + suggestions: list[str] | None = None, + ) -> PermissionResponse: + args_display = json.dumps(tool_args, ensure_ascii=False, indent=2) + if len(args_display) > 500: + args_display = args_display[:500] + '...' + + suggestion = suggestions[0] if suggestions else tool_name + + print(f'\n{"="*60}', file=sys.stderr) + print(f' Permission Required', file=sys.stderr) + print(f'{"="*60}', file=sys.stderr) + print(f' Tool: {tool_name}', file=sys.stderr) + print(f' Args: {args_display}', file=sys.stderr) + if context: + print(f' Context: {context}', file=sys.stderr) + print(f'{"─"*60}', file=sys.stderr) + print(f' [y] Allow this once', file=sys.stderr) + print(f' [s] Allow for this session', file=sys.stderr) + print(f' [a] Always allow (pattern: {suggestion})', file=sys.stderr) + print(f' [e] Edit args then execute', file=sys.stderr) + print(f' [n] Deny', file=sys.stderr) + print(f'{"="*60}', file=sys.stderr) + + loop = asyncio.get_running_loop() + choice = await loop.run_in_executor(None, lambda: input('Choice [y/s/a/e/n]: ').strip().lower()) + + if choice == 's': + return PermissionResponse( + action=PermissionAction.ALLOW_SESSION, + pattern=suggestion, + ) + elif choice == 'a': + edited = await loop.run_in_executor( + None, + lambda: input(f'Pattern [{suggestion}]: ').strip(), + ) + final_pattern = edited if edited else suggestion + return PermissionResponse( + action=PermissionAction.ALLOW_ALWAYS, + pattern=final_pattern, + ) + elif choice == 'e': + edited_raw = await loop.run_in_executor( + None, + lambda: input('New args (JSON): ').strip(), + ) + try: + new_args = json.loads(edited_raw) + except json.JSONDecodeError: + print('Invalid JSON, denying.', file=sys.stderr) + return PermissionResponse(action=PermissionAction.DENY) + return PermissionResponse( + action=PermissionAction.MODIFY, + updated_args=new_args, + ) + elif choice == 'n': + return PermissionResponse(action=PermissionAction.DENY) + else: + return PermissionResponse(action=PermissionAction.ALLOW_ONCE) + + +class EventEmitter(Protocol): + """Protocol for pushing events to the frontend.""" + def emit(self, event: dict[str, Any]) -> None: ... + + +class WebPermissionHandler: + """Async handler that suspends on a Future until the frontend responds.""" + + def __init__( + self, + event_emitter: EventEmitter, + timeout: float = 120.0, + ) -> None: + self._pending: dict[str, asyncio.Future[PermissionResponse]] = {} + self._event_emitter = event_emitter + self._timeout = timeout + + async def ask( + self, + tool_name: str, + tool_args: dict[str, Any], + context: str, + suggestions: list[str] | None = None, + ) -> PermissionResponse: + request_id = uuid4().hex + loop = asyncio.get_running_loop() + future: asyncio.Future[PermissionResponse] = loop.create_future() + self._pending[request_id] = future + + self._event_emitter.emit({ + 'type': 'permission_request', + 'request_id': request_id, + 'tool_name': tool_name, + 'tool_args': tool_args, + 'context': context, + 'suggestions': suggestions or [], + 'options': [a.value for a in PermissionAction], + }) + + try: + return await asyncio.wait_for(future, timeout=self._timeout) + except asyncio.TimeoutError: + return PermissionResponse( + action=PermissionAction.DENY, + feedback='Permission request timed out', + ) + finally: + self._pending.pop(request_id, None) + + def resolve(self, request_id: str, response: PermissionResponse) -> None: + future = self._pending.get(request_id) + if future and not future.done(): + future.set_result(response) diff --git a/ms_agent/permission/matcher.py b/ms_agent/permission/matcher.py new file mode 100644 index 000000000..6b40a3eeb --- /dev/null +++ b/ms_agent/permission/matcher.py @@ -0,0 +1,81 @@ +"""Shared wildcard matching for permission rules. + +Rule format: ``server---tool`` or ``server---tool:content_pattern`` +Supports ``*`` / ``?`` wildcards via fnmatch, ``|`` to separate alternatives. +""" + +from __future__ import annotations + +from typing import Any + +from ms_agent.utils.pattern_matcher import match_pattern + + +TOOL_SPLITER = '---' +CONTENT_SEP = ':' + + +def _extract_content(tool_name: str, tool_args: dict[str, Any]) -> str | None: + """Extract the primary content string from tool args for content-pattern matching.""" + val = None + if tool_name.endswith(f'{TOOL_SPLITER}shell_executor'): + val = tool_args.get('command') + elif tool_name.endswith(f'{TOOL_SPLITER}write_file'): + val = tool_args.get('path') + elif tool_name.endswith(f'{TOOL_SPLITER}read_file'): + val = tool_args.get('path') + elif tool_name.endswith(f'{TOOL_SPLITER}edit_file'): + val = tool_args.get('path') + elif tool_name.endswith(f'{TOOL_SPLITER}grep'): + val = tool_args.get('pattern') + elif tool_name.endswith(f'{TOOL_SPLITER}glob'): + val = tool_args.get('pattern') + else: + for key in ('path', 'command', 'query', 'url', 'pattern'): + if key in tool_args: + val = tool_args[key] + break + return str(val) if val is not None else None + + +class PermissionMatcher: + """Wildcard matcher for permission rules, shared by both SafetyGuard and PermissionEnforcer.""" + + def match(self, pattern: str, tool_call: str) -> bool: + """Match a tool call string against a pattern using fnmatch. + + Supports ``|`` separated alternatives: ``read_file|write_file``. + """ + return match_pattern(pattern, tool_call) + + def match_with_content( + self, + pattern: str, + tool_name: str, + tool_args: dict[str, Any], + ) -> bool: + """Match with optional content pattern after ``:``. + + Examples:: + + "file_system---read_file" → matches tool name only + "code_executor---shell_executor:pip *" → matches tool name + command content + "file_system---*" → wildcard on tool name + """ + if CONTENT_SEP in pattern: + tool_pattern, content_pattern = pattern.split(CONTENT_SEP, 1) + else: + tool_pattern = pattern + content_pattern = None + + if not self.match(tool_pattern, tool_name): + return False + + if content_pattern is None: + return True + + content = _extract_content(tool_name, tool_args) + if content is None: + return False + + return self.match(content_pattern, content) diff --git a/ms_agent/permission/memory.py b/ms_agent/permission/memory.py new file mode 100644 index 000000000..f78aca073 --- /dev/null +++ b/ms_agent/permission/memory.py @@ -0,0 +1,152 @@ +"""PermissionMemory: persist user ``allow_always`` decisions across sessions. + +Two storage scopes: + - Project: ``.ms_agent/permission_memory.json`` + - Global: ``~/.ms_agent/permission_memory.json`` + +Session-level memory (``allow_session``) lives only in-process. +""" + +from __future__ import annotations + +import json +import os +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Literal, Sequence + +from .matcher import PermissionMatcher + + +@dataclass(frozen=True) +class MemoryEntry: + pattern: str + scope: Literal['project', 'global'] + source: Literal['user', 'plugin', 'hook'] = 'user' + created_at: str = '' + + +class PermissionMemory: + """Manages persistent and session-level permission rules.""" + + def __init__( + self, + project_path: str | Path | None = None, + global_path: str | Path | None = None, + ) -> None: + self._matcher = PermissionMatcher() + + self._project_file: Path | None = None + if project_path is not None: + self._project_file = Path(project_path) / '.ms_agent' / 'permission_memory.json' + + if global_path is not None: + self._global_file = Path(global_path) + else: + self._global_file = Path.home() / '.ms_agent' / 'permission_memory.json' + + self._project_entries: list[MemoryEntry] = [] + self._global_entries: list[MemoryEntry] = [] + self._session_patterns: list[str] = [] + + self._load() + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def add( + self, + pattern: str, + scope: Literal['project', 'global'] = 'project', + source: Literal['user', 'plugin', 'hook'] = 'user', + ) -> None: + entries = self._project_entries if scope == 'project' else self._global_entries + if any(e.pattern == pattern for e in entries): + return + entry = MemoryEntry( + pattern=pattern, + scope=scope, + source=source, + created_at=datetime.now(timezone.utc).isoformat(), + ) + entries.append(entry) + self._save(scope) + + def add_session(self, pattern: str) -> None: + if pattern not in self._session_patterns: + self._session_patterns.append(pattern) + + def matches(self, tool_name: str, tool_args: dict[str, Any]) -> bool: + for pattern in self._session_patterns: + if self._matcher.match_with_content(pattern, tool_name, tool_args): + return True + for entry in self._project_entries: + if self._matcher.match_with_content(entry.pattern, tool_name, tool_args): + return True + for entry in self._global_entries: + if self._matcher.match_with_content(entry.pattern, tool_name, tool_args): + return True + return False + + def revoke(self, pattern: str) -> int: + """Remove all entries matching the given pattern. Returns count removed.""" + count = 0 + before = len(self._project_entries) + self._project_entries = [e for e in self._project_entries if e.pattern != pattern] + count += before - len(self._project_entries) + + before = len(self._global_entries) + self._global_entries = [e for e in self._global_entries if e.pattern != pattern] + count += before - len(self._global_entries) + + self._session_patterns = [p for p in self._session_patterns if p != pattern] + + if count > 0: + self._save('project') + self._save('global') + return count + + def list_all(self) -> list[MemoryEntry]: + return list(self._project_entries) + list(self._global_entries) + + # ------------------------------------------------------------------ + # Persistence + # ------------------------------------------------------------------ + + def _load(self) -> None: + self._project_entries = self._load_file(self._project_file, 'project') + self._global_entries = self._load_file(self._global_file, 'global') + + @staticmethod + def _load_file(path: Path | None, scope: str) -> list[MemoryEntry]: + if path is None or not path.exists(): + return [] + try: + data = json.loads(path.read_text(encoding='utf-8')) + return [ + MemoryEntry( + pattern=e['pattern'], + scope=e.get('scope', scope), + source=e.get('source', 'user'), + created_at=e.get('created_at', ''), + ) + for e in data + ] + except (json.JSONDecodeError, KeyError, TypeError): + return [] + + def _save(self, scope: Literal['project', 'global']) -> None: + if scope == 'project': + self._save_file(self._project_file, self._project_entries) + else: + self._save_file(self._global_file, self._global_entries) + + @staticmethod + def _save_file(path: Path | None, entries: list[MemoryEntry]) -> None: + if path is None: + return + path.parent.mkdir(parents=True, exist_ok=True) + data = [asdict(e) for e in entries] + path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding='utf-8') diff --git a/ms_agent/permission/path_extractors.py b/ms_agent/permission/path_extractors.py new file mode 100644 index 000000000..1bb37a135 --- /dev/null +++ b/ms_agent/permission/path_extractors.py @@ -0,0 +1,378 @@ +"""PATH_EXTRACTORS registry: per-command path extraction for shell commands. + +Five extraction strategies: + A) filter_out_flags — 27 commands + B) parse_pattern_command — grep, rg + C) special arg skip — sed, jq + D) search-start collection — find + E) subcommand dispatch — git + Special — cd, ls, tr +""" + +from __future__ import annotations + +import os +import re +from dataclasses import dataclass +from typing import Callable, Literal + +CommandExtractor = Callable[[list[str]], list[str]] +CommandValidator = Callable[[list[str]], str | None] + + +@dataclass(frozen=True) +class ExtractorEntry: + extractor: CommandExtractor + op_type: Literal['read', 'write', 'create'] + description: str + command_validator: CommandValidator | None = None + + +# --------------------------------------------------------------------------- +# Strategy A: filter_out_flags +# --------------------------------------------------------------------------- + +def filter_out_flags(args: list[str]) -> list[str]: + """Keep non-flag arguments, respecting ``--`` separator.""" + result: list[str] = [] + after_double_dash = False + for arg in args: + if after_double_dash: + result.append(arg) + elif arg == '--': + after_double_dash = True + elif not arg.startswith('-'): + result.append(arg) + return result + + +# --------------------------------------------------------------------------- +# Strategy B: parse_pattern_command (grep / rg) +# --------------------------------------------------------------------------- + +def parse_pattern_command( + args: list[str], + flags_with_args: set[str], + defaults: list[str] | None = None, +) -> list[str]: + """Extract file paths from pattern-based commands (grep/rg). + + First non-flag arg is the search pattern (skipped); rest are file paths. + If ``-e``/``-f`` explicitly provides the pattern, all non-flag args are paths. + """ + paths: list[str] = [] + pattern_found = False + after_double_dash = False + + i = 0 + while i < len(args): + arg = args[i] + if after_double_dash: + paths.append(arg) + i += 1 + continue + if arg == '--': + after_double_dash = True + i += 1 + continue + if arg.startswith('-'): + flag = arg.split('=')[0] + if flag in ('-e', '--regexp', '-f', '--file'): + pattern_found = True + if flag in flags_with_args and '=' not in arg: + i += 1 # skip flag value + i += 1 + continue + if not pattern_found: + pattern_found = True + i += 1 + continue # skip the pattern itself + paths.append(arg) + i += 1 + return paths if paths else (defaults or []) + + +# --------------------------------------------------------------------------- +# Strategy C: special arg skip (sed / jq) +# --------------------------------------------------------------------------- + +def extract_sed(args: list[str]) -> list[str]: + """Extract file paths from sed, skipping expression arguments.""" + paths: list[str] = [] + skip_next = False + script_found = False + after_dd = False + + for i, arg in enumerate(args): + if skip_next: + skip_next = False + continue + if not after_dd and arg == '--': + after_dd = True + continue + if not after_dd and arg.startswith('-'): + if arg in ('-f', '--file'): + if i + 1 < len(args): + paths.append(args[i + 1]) + skip_next = True + script_found = True + elif arg in ('-e', '--expression'): + skip_next = True + script_found = True + elif 'e' in arg[1:] or 'f' in arg[1:]: + script_found = True + continue + if not script_found: + script_found = True + continue # skip inline expression + paths.append(arg) + return paths + + +_JQ_FLAGS_WITH_ARGS = frozenset({ + '-f', '--from-file', '--arg', '--argjson', '--slurpfile', + '--rawfile', '-L', '--indent', '--jsonargs', '--args', +}) + + +def extract_jq(args: list[str]) -> list[str]: + """Extract file paths from jq, skipping filter expression.""" + paths: list[str] = [] + filter_found = False + after_double_dash = False + + i = 0 + while i < len(args): + arg = args[i] + if after_double_dash: + paths.append(arg) + i += 1 + continue + if arg == '--': + after_double_dash = True + i += 1 + continue + if arg.startswith('-'): + flag = arg.split('=')[0] + if flag in _JQ_FLAGS_WITH_ARGS and '=' not in arg: + i += 1 + i += 1 + continue + if not filter_found: + filter_found = True + i += 1 + continue # skip the filter + paths.append(arg) + i += 1 + return paths + + +# --------------------------------------------------------------------------- +# Strategy D: search-start collection (find) +# --------------------------------------------------------------------------- + +_FIND_PATH_FLAGS = frozenset({ + '-newer', '-anewer', '-cnewer', '-mnewer', '-samefile', + '-path', '-wholename', '-ilname', '-lname', '-ipath', '-iwholename', +}) +_FIND_NEWER_PATTERN = re.compile(r'^-newer[acmBt][acmtB]$') + + +def extract_find(args: list[str]) -> list[str]: + """Extract search starting points and path-valued flags from find.""" + paths: list[str] = [] + found_non_global_flag = False + after_double_dash = False + + i = 0 + while i < len(args): + arg = args[i] + if after_double_dash: + paths.append(arg) + i += 1 + continue + if arg == '--': + after_double_dash = True + i += 1 + continue + if arg.startswith('-'): + if arg in ('-H', '-L', '-P'): + i += 1 + continue + found_non_global_flag = True + if arg in _FIND_PATH_FLAGS or _FIND_NEWER_PATTERN.match(arg): + if i + 1 < len(args): + paths.append(args[i + 1]) + i += 1 + i += 1 + continue + if not found_non_global_flag: + paths.append(arg) + i += 1 + return paths if paths else ['.'] + + +# --------------------------------------------------------------------------- +# Strategy E: subcommand dispatch (git) +# --------------------------------------------------------------------------- + +def extract_git(args: list[str]) -> list[str]: + """Extract paths only for ``git diff --no-index``.""" + if args and args[0] == 'diff' and '--no-index' in args: + return filter_out_flags(args[1:])[:2] + return [] + + +# --------------------------------------------------------------------------- +# Special commands: cd, ls, tr +# --------------------------------------------------------------------------- + +def extract_cd(args: list[str]) -> list[str]: + if not args: + return [os.path.expanduser('~')] + return [' '.join(args)] + + +def extract_ls(args: list[str]) -> list[str]: + paths = filter_out_flags(args) + return paths if paths else ['.'] + + +def extract_tr(args: list[str]) -> list[str]: + has_delete = any( + a == '-d' or a == '--delete' or (a.startswith('-') and 'd' in a[1:]) + for a in args + ) + non_flags = filter_out_flags(args) + skip_count = 1 if has_delete else 2 + return non_flags[skip_count:] + + +# --------------------------------------------------------------------------- +# grep / rg specific flag sets +# --------------------------------------------------------------------------- + +_GREP_FLAGS_WITH_ARGS = frozenset({ + '-e', '--regexp', '-f', '--file', + '--exclude', '--include', '--exclude-dir', '--include-dir', + '-m', '--max-count', + '-A', '--after-context', '-B', '--before-context', '-C', '--context', + '--label', '--color', +}) + +_RG_FLAGS_WITH_ARGS = frozenset({ + '-e', '--regexp', '-f', '--file', + '-t', '--type', '-T', '--type-not', + '-g', '--glob', '-m', '--max-count', '--max-depth', + '-r', '--replace', + '-A', '--after-context', '-B', '--before-context', '-C', '--context', + '--color', '--colors', '--encoding', '-E', + '--iglob', '--type-add', '--type-clear', +}) + + +def _extract_grep(args: list[str]) -> list[str]: + has_recursive = any(a in ('-r', '-R', '--recursive') for a in args) + paths = parse_pattern_command(args, _GREP_FLAGS_WITH_ARGS) + if not paths and has_recursive: + return ['.'] + return paths + + +def _extract_rg(args: list[str]) -> list[str]: + return parse_pattern_command(args, _RG_FLAGS_WITH_ARGS, defaults=['.']) + + +# --------------------------------------------------------------------------- +# Command validators (mv / cp) +# --------------------------------------------------------------------------- + +def _validate_mv_cp(args: list[str]) -> str | None: + """Reject mv/cp calls with flags (--target-directory bypass risk).""" + for arg in args: + if arg == '--': + break + if arg.startswith('-'): + return f'mv/cp with flags requires confirmation (possible --target-directory bypass)' + return None + + +# --------------------------------------------------------------------------- +# Registry builder +# --------------------------------------------------------------------------- + +def _make_filter_entry( + op_type: Literal['read', 'write', 'create'], + description: str, + *, + validator: CommandValidator | None = None, +) -> ExtractorEntry: + return ExtractorEntry( + extractor=filter_out_flags, + op_type=op_type, + description=description, + command_validator=validator, + ) + + +def build_extractor_registry() -> dict[str, ExtractorEntry]: + """Build the full 34-command extractor registry.""" + registry: dict[str, ExtractorEntry] = {} + + # Special commands + registry['cd'] = ExtractorEntry(extract_cd, 'read', 'change directories to') + registry['ls'] = ExtractorEntry(extract_ls, 'read', 'list files in') + registry['tr'] = ExtractorEntry(extract_tr, 'read', 'transform text from files in') + + # Strategy D + registry['find'] = ExtractorEntry(extract_find, 'read', 'search files in') + + # Strategy B + registry['grep'] = ExtractorEntry(_extract_grep, 'read', 'search for patterns in files from') + registry['rg'] = ExtractorEntry(_extract_rg, 'read', 'search for patterns in files from') + + # Strategy C + registry['sed'] = ExtractorEntry(extract_sed, 'write', 'edit files in') + registry['jq'] = ExtractorEntry(extract_jq, 'read', 'process JSON from files in') + + # Strategy E + registry['git'] = ExtractorEntry(extract_git, 'read', 'access files with git from') + + # Strategy A: create + for cmd in ('mkdir', 'touch'): + registry[cmd] = _make_filter_entry('create', f'create {"directories" if cmd == "mkdir" else "or modify files"} in') + + # Strategy A: write (with special validators for mv/cp) + registry['rm'] = _make_filter_entry('write', 'remove files from') + registry['rmdir'] = _make_filter_entry('write', 'remove directories from') + registry['mv'] = _make_filter_entry('write', 'move files to/from', validator=_validate_mv_cp) + registry['cp'] = _make_filter_entry('write', 'copy files to/from', validator=_validate_mv_cp) + + # Strategy A: read (21 commands) + _read_commands = { + 'cat': 'concatenate files from', + 'head': 'read the beginning of files from', + 'tail': 'read the end of files from', + 'sort': 'sort contents of files from', + 'uniq': 'filter duplicate lines from files in', + 'wc': 'count lines/words/bytes in files from', + 'cut': 'extract columns from files in', + 'paste': 'merge files from', + 'column': 'format files from', + 'file': 'examine file types in', + 'stat': 'read file stats from', + 'diff': 'compare files from', + 'awk': 'process text from files in', + 'strings': 'extract strings from files in', + 'hexdump': 'display hex dump of files from', + 'od': 'display octal dump of files from', + 'base64': 'encode/decode files from', + 'nl': 'number lines in files from', + 'sha256sum': 'compute SHA-256 checksums for files in', + 'sha1sum': 'compute SHA-1 checksums for files in', + 'md5sum': 'compute MD5 checksums for files in', + } + for cmd, desc in _read_commands.items(): + registry[cmd] = _make_filter_entry('read', desc) + + return registry diff --git a/ms_agent/permission/path_validator.py b/ms_agent/permission/path_validator.py new file mode 100644 index 000000000..f2fc4a7b5 --- /dev/null +++ b/ms_agent/permission/path_validator.py @@ -0,0 +1,178 @@ +"""Single-path validation: quote stripping, tilde expansion, shell-expansion +rejection, glob handling, directory-scope checks, and dangerous-path detection.""" + +from __future__ import annotations + +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Literal, Sequence + +GLOB_CHARS = set('*?[]{}') +_CONSECUTIVE_SLASHES = re.compile(r'[/\\]+') +_WINDOWS_DRIVE_ROOT = re.compile(r'^[A-Za-z]:/?$') +_WINDOWS_DRIVE_CHILD = re.compile(r'^[A-Za-z]:/[^/]+$') +_ROOT_CHILD = re.compile(r'^/[^/]+$') + + +@dataclass(frozen=True) +class PathValidationResult: + allowed: bool + resolved_path: str + action: Literal['allow', 'deny', 'ask'] + reason: str + category: str = '' + + +def _strip_quotes(path: str) -> str: + if len(path) >= 2: + if (path[0] == path[-1]) and path[0] in ('"', "'"): + return path[1:-1] + return path + + +def _expand_tilde(path: str, home_dir: str) -> tuple[str, str | None]: + """Expand ``~`` and ``~/...``. Reject ``~user``, ``~+``, ``~-``.""" + if not path.startswith('~'): + return path, None + if path == '~': + return home_dir, None + if path.startswith('~/') or path.startswith('~\\'): + return home_dir + path[1:], None + return path, f'Unsupported tilde expansion: {path}' + + +def _has_shell_expansion(path: str) -> str | None: + if '$' in path: + return f'Path contains shell variable expansion: {path}' + if '%' in path: + return f'Path contains Windows variable expansion: {path}' + if path.startswith('='): + return f'Path starts with = (Zsh expansion): {path}' + return None + + +def _has_glob(path: str) -> bool: + return bool(GLOB_CHARS & set(path)) + + +def get_glob_base_directory(pattern: str) -> str: + """Extract the directory prefix before the first glob character.""" + first_glob = len(pattern) + for i, c in enumerate(pattern): + if c in GLOB_CHARS: + first_glob = i + break + base = pattern[:first_glob] + last_sep = base.rfind('/') + if last_sep < 0: + return '.' + return base[:last_sep] or '/' + + +def _is_under_allowed(resolved: Path, allowed_dirs: Sequence[str]) -> bool: + for d in allowed_dirs: + try: + resolved.relative_to(Path(d).resolve()) + return True + except ValueError: + continue + return False + + +def validate_path( + path: str, + cwd: str, + allowed_dirs: Sequence[str], + op_type: Literal['read', 'write', 'create'], + *, + read_only_dirs: Sequence[str] = (), + home_dir: str | None = None, +) -> PathValidationResult: + """Validate a single filesystem path for a given operation type. + + Returns a ``PathValidationResult`` with ``allowed=True`` if the path passes + all checks, or ``allowed=False`` with a reason explaining the rejection. + """ + if home_dir is None: + home_dir = os.path.expanduser('~') + + path = _strip_quotes(path) + + path, tilde_err = _expand_tilde(path, home_dir) + if tilde_err: + return PathValidationResult( + allowed=False, resolved_path=path, action='deny', reason=tilde_err, + ) + + expansion_err = _has_shell_expansion(path) + if expansion_err: + return PathValidationResult( + allowed=False, resolved_path=path, action='ask', reason=expansion_err, + category='shell_expansion', + ) + + if _has_glob(path): + if op_type in ('write', 'create'): + return PathValidationResult( + allowed=False, resolved_path=path, action='deny', + reason=f'Glob patterns not allowed in {op_type} operations: {path}', + ) + path = get_glob_base_directory(path) + + if os.path.isabs(path): + resolved = Path(path).resolve() + else: + resolved = (Path(cwd) / path).resolve() + + resolved_str = str(resolved) + + if not _is_under_allowed(resolved, allowed_dirs): + if op_type == 'read': + if _is_under_allowed(resolved, read_only_dirs): + return PathValidationResult( + allowed=True, resolved_path=resolved_str, action='allow', + reason='Path allowed via read-only directory', + ) + return PathValidationResult( + allowed=False, resolved_path=resolved_str, action='ask', + reason=f'Read path outside allowed directories: {resolved_str}', + category='read_outside_dirs', + ) + return PathValidationResult( + allowed=False, resolved_path=resolved_str, action='deny', + reason=f'{op_type.capitalize()} path outside allowed directories: {resolved_str}', + ) + + return PathValidationResult( + allowed=True, resolved_path=resolved_str, action='allow', + reason='Path validation passed', + ) + + +def is_dangerous_removal_path(path: str) -> bool: + """Check if a path is too dangerous for rm/rmdir, even within allowed dirs.""" + normalized = _CONSECUTIVE_SLASHES.sub('/', path) + if normalized.endswith('/') and len(normalized) > 1: + normalized = normalized.rstrip('/') + + if normalized == '*': + return True + if normalized.endswith('/*') or normalized.endswith('\\*'): + return True + if normalized == '/': + return True + + home = os.path.expanduser('~').replace('\\', '/') + if normalized == home: + return True + + if _ROOT_CHILD.match(normalized): + return True + if _WINDOWS_DRIVE_ROOT.match(normalized): + return True + if _WINDOWS_DRIVE_CHILD.match(normalized): + return True + + return False diff --git a/ms_agent/permission/safety.py b/ms_agent/permission/safety.py new file mode 100644 index 000000000..81a06206a --- /dev/null +++ b/ms_agent/permission/safety.py @@ -0,0 +1,91 @@ +"""SafetyGuard: inner-layer safety baseline that cannot be bypassed by users. + +Checks safety rules, file path validation, and shell command path-level +analysis before any tool call is allowed to execute. +""" + +from __future__ import annotations + +import fnmatch +import os +from dataclasses import dataclass +from typing import Any, Literal, Sequence + +from .config import SafetyConfig +from .matcher import PermissionMatcher +from .path_validator import validate_path +from .shell_validator import PathSafetyConfig, SafetyDecision, ShellPathValidator + + +class SafetyGuard: + """Inner-layer safety enforcement — not overridable by user configuration.""" + + def __init__( + self, + config: SafetyConfig, + allowed_dirs: Sequence[str], + read_only_dirs: Sequence[str] = (), + workspace_root: str | None = None, + ) -> None: + self._config = config + self._matcher = PermissionMatcher() + self._allowed_dirs = list(allowed_dirs) + self._read_only_dirs = list(read_only_dirs) + self._sensitive_paths = list(config.sensitive_paths) + self._workspace_root = workspace_root + + path_safety_cfg = PathSafetyConfig( + max_command_chars=config.max_command_chars, + allowed_directories=tuple(self._allowed_dirs), + read_only_directories=tuple(self._read_only_dirs), + workspace_root=workspace_root, + ) + self._shell_validator = ShellPathValidator( + allowed_dirs=self._allowed_dirs, + safety_config=path_safety_cfg, + ) + + def check(self, tool_name: str, tool_args: dict[str, Any]) -> SafetyDecision: + # 1. Generic safety rules + for rule in self._config.patterns: + if self._matcher.match_with_content(rule, tool_name, tool_args): + return SafetyDecision(action='deny', reason=f'Blocked by safety rule: {rule}') + + # 2. Tool-specific checks + if tool_name.endswith('---shell_executor'): + command = tool_args.get('command', '') + return self._shell_validator.check(command) + + if tool_name.endswith('---write_file') or tool_name.endswith('---edit_file'): + return self._check_file_path(tool_args.get('path', ''), 'write') + + if tool_name.endswith('---read_file'): + return self._check_file_path(tool_args.get('path', ''), 'read') + + if tool_name.endswith('---grep') or tool_name.endswith('---glob'): + return self._check_file_path(tool_args.get('path', '.'), 'read') + + # 3. No rule matched → allow + return SafetyDecision(action='allow', reason='No safety rule matched') + + def _check_file_path(self, path: str, op_type: Literal['read', 'write']) -> SafetyDecision: + if not path: + return SafetyDecision(action='deny', reason='Empty file path') + + # Sensitive path check + if op_type == 'write': + expanded = os.path.expanduser(path) + for sensitive in self._sensitive_paths: + sensitive_expanded = os.path.expanduser(sensitive) + if fnmatch.fnmatch(expanded, sensitive_expanded): + return SafetyDecision( + action='deny', + reason=f'Write to sensitive path blocked: {path}', + ) + + cwd = self._workspace_root or os.getcwd() + result = validate_path(path, cwd, self._allowed_dirs, op_type, read_only_dirs=self._read_only_dirs) + if not result.allowed: + return SafetyDecision(action=result.action, reason=result.reason, category=result.category) + + return SafetyDecision(action='allow', reason='Path validation passed') diff --git a/ms_agent/permission/sed_validator.py b/ms_agent/permission/sed_validator.py new file mode 100644 index 000000000..37a8b78d3 --- /dev/null +++ b/ms_agent/permission/sed_validator.py @@ -0,0 +1,128 @@ +"""sed expression safety checks. + +Detects dangerous sed expressions (write commands, shell execution, etc.) +and determines whether a sed invocation is read-only. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Literal + +_PRINT_ONLY_EXPR = re.compile(r'^(\d+(,\d+)?)?p$') + + +@dataclass(frozen=True) +class SedSafetyResult: + safe: bool + reason: str + + +def _has_dangerous_sub_flags(expression: str) -> bool: + """Detect w/e flags in sed s-commands, supporting arbitrary delimiters.""" + i = 0 + n = len(expression) + while i < n: + if expression[i] != 's' or i + 1 >= n: + i += 1 + continue + delim = expression[i + 1] + if delim == '\\': + i += 1 + continue + pos = i + 2 + found = 0 + while pos < n and found < 2: + if expression[pos] == '\\' and pos + 1 < n: + pos += 2 + continue + if expression[pos] == delim: + found += 1 + pos += 1 + if found < 2: + break + while pos < n and expression[pos] not in ';\n': + if expression[pos] in 'we': + return True + pos += 1 + i = pos + return False + + +def is_sed_read_only(args: list[str]) -> bool: + """Check if sed invocation is read-only: ``-n`` flag, print-only expressions, no ``-i``.""" + has_n = False + has_i = False + expressions: list[str] = [] + + skip_next = False + script_found = False + + for i, arg in enumerate(args): + if skip_next: + skip_next = False + continue + if arg == '--': + break + if arg.startswith('-'): + if arg in ('-n', '--quiet', '--silent'): + has_n = True + if arg in ('-i', '--in-place') or arg.startswith('-i'): + has_i = True + if arg in ('-e', '--expression'): + if i + 1 < len(args): + expressions.append(args[i + 1]) + skip_next = True + script_found = True + elif arg in ('-f', '--file'): + skip_next = True + script_found = True + continue + if not script_found: + expressions.append(arg) + script_found = True + + if has_i: + return False + if not has_n: + return False + return all(_PRINT_ONLY_EXPR.match(e.strip()) for e in expressions if e) + + +def check_sed_expression_safety(expression: str) -> SedSafetyResult: + """Check a sed expression for dangerous patterns.""" + if not expression: + return SedSafetyResult(safe=True, reason='Empty expression') + + # Non-ASCII characters (homoglyph attacks) + try: + expression.encode('ascii') + except UnicodeEncodeError: + return SedSafetyResult(safe=False, reason='Non-ASCII characters in sed expression') + + # Newlines (multi-line command injection) + if '\n' in expression or '\r' in expression: + return SedSafetyResult(safe=False, reason='Newline in sed expression') + + # Curly braces (block commands — cannot be statically analysed) + if '{' in expression or '}' in expression: + return SedSafetyResult(safe=False, reason='Block commands ({}) in sed expression') + + # w/W command — writes to file + if re.search(r'(?\s*\(') +_REDIRECT_PATTERN = re.compile( + r'(?:&>>|&>|>>|>\||>)' + r'\s*' + r'(\S+)' +) +_FD_REDIRECT = re.compile(r'^\d*>&\d+$') + + +@dataclass(frozen=True) +class SafetyDecision: + action: Literal['allow', 'deny', 'ask'] + reason: str + category: str = '' + + +@dataclass(frozen=True) +class PathSafetyConfig: + max_command_chars: int = 8192 + allowed_directories: tuple[str, ...] = () + read_only_directories: tuple[str, ...] = () + workspace_root: str | None = None + + +class ShellPathValidator: + """Path-level security validator for shell_executor tool calls.""" + + def __init__( + self, + allowed_dirs: Sequence[str], + safety_config: PathSafetyConfig | None = None, + ) -> None: + self._allowed_dirs = list(allowed_dirs) + self._config = safety_config or PathSafetyConfig() + self._read_only_dirs = list(self._config.read_only_directories) + self._workspace_root = self._config.workspace_root or os.getcwd() + self._extractors = build_extractor_registry() + + def check(self, command: str) -> SafetyDecision: + if not command or not command.strip(): + return SafetyDecision(action='deny', reason='Empty shell command') + + if len(command) > self._config.max_command_chars: + return SafetyDecision( + action='deny', + reason=f'Command exceeds max length ({self._config.max_command_chars})', + ) + + # 1. Process substitution + if _PROCESS_OUTPUT_SUB.search(command): + return SafetyDecision( + action='ask', + reason='Command contains output process substitution >(…) — may bypass path validation', + category='process_output_sub', + ) + if _PROCESS_INPUT_SUB.search(command): + return SafetyDecision( + action='ask', + reason='Command contains input process substitution <(…) — cannot statically analyse', + category='process_input_sub', + ) + + # 2. Split compound commands + sub_commands = _split_compound(command) + + # Track cd presence for cd+write detection + has_cd = False + has_write_or_create = False + + for sub_cmd in sub_commands: + try: + tokens = shlex.split(sub_cmd) + except ValueError: + return SafetyDecision(action='ask', reason=f'Failed to parse command: {sub_cmd}', category='parse_failure') + + if not tokens: + continue + + # 3. Check output redirections on the raw sub-command string + redirect_result = self._check_redirects(sub_cmd) + if redirect_result.action != 'allow': + return redirect_result + + # 4. Strip safe wrappers + tokens = strip_safe_wrappers(tokens) + if not tokens: + continue + + base_cmd = os.path.basename(tokens[0]) + args = tokens[1:] + + if base_cmd == 'cd': + has_cd = True + + # 5. Command path extraction and validation + result = self._check_command(base_cmd, args) + if result.action != 'allow': + return result + + entry = self._extractors.get(base_cmd) + if entry and entry.op_type in ('write', 'create'): + has_write_or_create = True + + # 6. cd + write/create compound → ask + if has_cd and has_write_or_create: + return SafetyDecision( + action='ask', + reason='Command combines cd with write/create operations — ' + 'path validation may not reflect runtime working directory', + category='cd_write_compound', + ) + + return SafetyDecision(action='allow', reason='Shell command passed all checks') + + def _check_command(self, base_cmd: str, args: list[str]) -> SafetyDecision: + entry = self._extractors.get(base_cmd) + if entry is None: + return SafetyDecision(action='allow', reason=f'Unregistered command: {base_cmd}') + + # Command-level validator (e.g. mv/cp with flags) + if entry.command_validator is not None: + err = entry.command_validator(args) + if err: + return SafetyDecision(action='ask', reason=err, category='command_validator') + + # sed special handling + if base_cmd == 'sed': + return self._check_sed(args, entry) + + paths = entry.extractor(args) + if not paths: + return SafetyDecision(action='allow', reason=f'{base_cmd}: no paths to validate') + + return self._validate_paths(paths, entry.op_type, base_cmd) + + def _check_sed(self, args: list[str], entry: ExtractorEntry) -> SafetyDecision: + op_type = entry.op_type + if is_sed_read_only(args): + op_type = 'read' + + # Expression safety check + expressions = self._collect_sed_expressions(args) + for expr in expressions: + result = check_sed_expression_safety(expr) + if not result.safe: + return SafetyDecision(action='deny', reason=result.reason) + + paths = entry.extractor(args) + if not paths: + return SafetyDecision(action='allow', reason='sed: no file paths') + + return self._validate_paths(paths, op_type, 'sed') + + @staticmethod + def _collect_sed_expressions(args: list[str]) -> list[str]: + expressions: list[str] = [] + skip_next = False + script_found = False + + for i, arg in enumerate(args): + if skip_next: + skip_next = False + continue + if arg == '--': + break + if arg.startswith('-'): + if arg in ('-e', '--expression'): + if i + 1 < len(args): + expressions.append(args[i + 1]) + skip_next = True + script_found = True + elif arg in ('-f', '--file'): + skip_next = True + script_found = True + continue + if not script_found: + expressions.append(arg) + script_found = True + return expressions + + def _validate_paths( + self, + paths: list[str], + op_type: Literal['read', 'write', 'create'], + cmd_name: str, + ) -> SafetyDecision: + cwd = self._workspace_root + + for path in paths: + # Dangerous removal check for rm/rmdir + if cmd_name in ('rm', 'rmdir') and is_dangerous_removal_path(path): + return SafetyDecision( + action='deny', + reason=f'Dangerous removal path: {path}', + ) + + result = validate_path(path, cwd, self._allowed_dirs, op_type, read_only_dirs=self._read_only_dirs) + if not result.allowed: + return SafetyDecision(action=result.action, reason=result.reason, category=result.category) + + return SafetyDecision(action='allow', reason=f'{cmd_name}: all paths validated') + + def _check_redirects(self, sub_cmd: str) -> SafetyDecision: + for match in _REDIRECT_PATTERN.finditer(sub_cmd): + target = match.group(1) + if _FD_REDIRECT.match(target): + continue + if target == '/dev/null': + continue + if '$' in target or '%' in target: + return SafetyDecision( + action='deny', + reason=f'Redirect target contains variable expansion: {target}', + ) + + result = validate_path( + target, self._workspace_root, self._allowed_dirs, 'create', + read_only_dirs=self._read_only_dirs, + ) + if not result.allowed: + return SafetyDecision(action=result.action, reason=result.reason, category=result.category) + + return SafetyDecision(action='allow', reason='Redirects OK') + + +def _split_compound(command: str) -> list[str]: + """Split a compound command on ``&&``, ``||``, ``;``, ``|`` operators. + + Uses a simple approach that does not split inside quotes. + """ + parts: list[str] = [] + current: list[str] = [] + in_single = False + in_double = False + i = 0 + chars = command + + while i < len(chars): + c = chars[i] + + if c == '\\' and not in_single and i + 1 < len(chars): + current.append(c) + current.append(chars[i + 1]) + i += 2 + continue + + if c == "'" and not in_double: + in_single = not in_single + current.append(c) + i += 1 + continue + + if c == '"' and not in_single: + in_double = not in_double + current.append(c) + i += 1 + continue + + if in_single or in_double: + current.append(c) + i += 1 + continue + + # Check for compound operators + if c == ';': + parts.append(''.join(current).strip()) + current = [] + i += 1 + continue + if c == '|': + if i + 1 < len(chars) and chars[i + 1] == '|': + parts.append(''.join(current).strip()) + current = [] + i += 2 + continue + parts.append(''.join(current).strip()) + current = [] + i += 1 + continue + if c == '&': + if i + 1 < len(chars) and chars[i + 1] == '&': + parts.append(''.join(current).strip()) + current = [] + i += 2 + continue + + current.append(c) + i += 1 + + remainder = ''.join(current).strip() + if remainder: + parts.append(remainder) + + return [p for p in parts if p] diff --git a/ms_agent/permission/suggestions.py b/ms_agent/permission/suggestions.py new file mode 100644 index 000000000..52f119162 --- /dev/null +++ b/ms_agent/permission/suggestions.py @@ -0,0 +1,49 @@ +"""Auto-generate permission pattern suggestions for allow_always actions.""" + +from __future__ import annotations + +import shlex +from typing import Any + +from .matcher import CONTENT_SEP, TOOL_SPLITER +from .wrapper_strip import strip_safe_wrappers + + +def generate_suggestions(tool_name: str, tool_args: dict[str, Any]) -> list[str]: + """Generate suggested wildcard patterns based on tool name and arguments. + + Returns a list of patterns from most specific to most general. + """ + suggestions: list[str] = [] + + # Extract server name (everything before first TOOL_SPLITER) + parts = tool_name.split(TOOL_SPLITER, 1) + server = parts[0] if len(parts) > 1 else '' + + if tool_name.endswith(f'{TOOL_SPLITER}shell_executor'): + command = tool_args.get('command', '') + if command: + first_cmd = _extract_first_command(command) + if first_cmd: + suggestions.append(f'{tool_name}{CONTENT_SEP}{first_cmd} *') + suggestions.append(tool_name) + elif server == 'file_system': + suggestions.append(tool_name) + elif server == 'web_search': + suggestions.append(f'{server}{TOOL_SPLITER}*') + else: + suggestions.append(tool_name) + if server: + suggestions.append(f'{server}{TOOL_SPLITER}*') + + return suggestions + + +def _extract_first_command(command: str) -> str: + """Extract the base command name, stripping safe wrappers (timeout, nice, …).""" + try: + tokens = shlex.split(command) + except ValueError: + tokens = command.split() + stripped = strip_safe_wrappers(tokens) + return stripped[0] if stripped else '' diff --git a/ms_agent/permission/wrapper_strip.py b/ms_agent/permission/wrapper_strip.py new file mode 100644 index 000000000..fa5b8b5b0 --- /dev/null +++ b/ms_agent/permission/wrapper_strip.py @@ -0,0 +1,210 @@ +"""Safe wrapper stripping: remove harmless command wrappers so the real +command can be analysed for path extraction. + +Two-phase algorithm: + Phase 1 — strip safe environment variable assignments (VAR=val). + Phase 2 — strip wrapper commands (timeout, time, nice, nohup, stdbuf, env). +""" + +from __future__ import annotations + +import re + +# Environment variables safe to strip (do not affect paths or inject code). +SAFE_ENV_VARS: frozenset[str] = frozenset({ + # Go + 'GOEXPERIMENT', 'GOOS', 'GOARCH', 'CGO_ENABLED', 'GO111MODULE', + # Rust + 'RUST_BACKTRACE', 'RUST_LOG', + # Node + 'NODE_ENV', + # Python + 'PYTHONUNBUFFERED', 'PYTHONDONTWRITEBYTECODE', + # Pytest + 'PYTEST_DISABLE_PLUGIN_AUTOLOAD', 'PYTEST_DEBUG', + # Locale / encoding + 'LANG', 'LANGUAGE', 'LC_ALL', 'LC_CTYPE', 'LC_TIME', 'CHARSET', + # Terminal / display + 'TERM', 'COLORTERM', 'NO_COLOR', 'FORCE_COLOR', 'TZ', + # Color config + 'LS_COLORS', 'LSCOLORS', 'GREP_COLOR', 'GREP_COLORS', 'GCC_COLORS', + # Display format + 'TIME_STYLE', 'BLOCK_SIZE', 'BLOCKSIZE', +}) + +_SAFE_FLAG_VALUE = re.compile(r'^[A-Za-z0-9_.+\-]+$') +_ENV_ASSIGN = re.compile(r'^([A-Za-z_][A-Za-z0-9_]*)=') + + +def _strip_env_vars(tokens: list[str]) -> list[str]: + """Phase 1: strip leading safe environment variable assignments.""" + i = 0 + while i < len(tokens): + m = _ENV_ASSIGN.match(tokens[i]) + if not m: + break + var_name = m.group(1) + if var_name not in SAFE_ENV_VARS: + break + i += 1 + return tokens[i:] + + +def _strip_timeout(tokens: list[str]) -> list[str] | None: + """Strip ``timeout`` wrapper with its flags and duration argument.""" + if not tokens or tokens[0] != 'timeout': + return None + + no_value_flags = frozenset({'--foreground', '--preserve-status', '-v', '--verbose'}) + value_flags_long = frozenset({'--kill-after', '--signal'}) + value_flags_short = frozenset({'-k', '-s'}) + + i = 1 + while i < len(tokens): + arg = tokens[i] + if arg == '--': + i += 1 + break + if arg in no_value_flags: + i += 1 + continue + if arg in value_flags_long: + i += 2 # flag + value + continue + if any(arg.startswith(f'{f}=') for f in value_flags_long): + i += 1 + continue + for short in value_flags_short: + if arg == short: + i += 2 + break + if arg.startswith(short) and len(arg) > len(short): + val = arg[len(short):] + if not _SAFE_FLAG_VALUE.match(val): + return None # suspicious flag value + i += 1 + break + else: + if arg.startswith('-'): + i += 1 + continue + # This is the duration argument + i += 1 + break + return tokens[i:] + + +def _strip_time(tokens: list[str]) -> list[str] | None: + if not tokens or tokens[0] != 'time': + return None + i = 1 + while i < len(tokens) and tokens[i].startswith('-'): + i += 1 + return tokens[i:] + + +def _strip_nice(tokens: list[str]) -> list[str] | None: + """Strip ``nice`` in three forms: bare, ``-N``, ``-n N``.""" + if not tokens or tokens[0] != 'nice': + return None + i = 1 + if i < len(tokens): + if tokens[i] in ('-n', '--adjustment'): + i += 2 # -n N + elif tokens[i].startswith('-') and tokens[i][1:].lstrip('-').isdigit(): + i += 1 # -N (traditional) + return tokens[i:] + + +def _strip_nohup(tokens: list[str]) -> list[str] | None: + if not tokens or tokens[0] != 'nohup': + return None + return tokens[1:] + + +def _strip_stdbuf(tokens: list[str]) -> list[str] | None: + if not tokens or tokens[0] != 'stdbuf': + return None + # stdbuf flags: -i MODE, -o MODE, -e MODE (or combined: -iL, -o0, --input=MODE) + i = 1 + while i < len(tokens): + arg = tokens[i] + if arg == '--': + i += 1 + break + if arg.startswith('-'): + if '=' in arg or len(arg) > 2: + i += 1 # combined flag+value (e.g. -o0, --input=L) + else: + i += 2 # separate value (e.g. -o 0) + continue + break + return tokens[i:] + + +def _strip_env(tokens: list[str]) -> list[str] | None: + """Strip ``env`` wrapper with safe flags.""" + if not tokens or tokens[0] != 'env': + return None + + unsafe_flags = frozenset({'-S', '--split-string', '-C', '--chdir', '-P', '--path'}) + safe_no_value = frozenset({'-i', '--ignore-environment', '-0', '--null', '-v', '--verbose'}) + + i = 1 + while i < len(tokens): + arg = tokens[i] + if arg in unsafe_flags: + return None # cannot safely strip + if arg in safe_no_value: + i += 1 + continue + if arg in ('-u', '--unset'): + i += 2 + continue + if arg == '--': + i += 1 + break + if _ENV_ASSIGN.match(arg): + i += 1 + continue + if arg.startswith('-'): + return None # unknown flag + break + return tokens[i:] + + +_WRAPPER_STRIPPERS = [ + _strip_timeout, + _strip_time, + _strip_nice, + _strip_nohup, + _strip_stdbuf, + _strip_env, +] + + +def strip_safe_wrappers(tokens: list[str]) -> list[str]: + """Strip safe wrappers from a tokenized command. + + Phase 1: strip leading safe env var assignments. + Phase 2: iteratively strip wrapper commands. + """ + if not tokens: + return tokens + + # Phase 1: environment variables + tokens = _strip_env_vars(tokens) + if not tokens: + return tokens + + # Phase 2: wrapper commands (iterate until stable) + changed = True + while changed and tokens: + changed = False + for stripper in _WRAPPER_STRIPPERS: + result = stripper(tokens) + if result is not None: + tokens = result + changed = True + break + return tokens diff --git a/ms_agent/plugins/__init__.py b/ms_agent/plugins/__init__.py new file mode 100644 index 000000000..d663d3a5d --- /dev/null +++ b/ms_agent/plugins/__init__.py @@ -0,0 +1,18 @@ +"""Plugin compatibility layer for container-style community plugins. + +Keep this package initializer intentionally lightweight. Several core modules +import plugin submodules during config loading, and importing runtime/loader +here would pull hooks and agent modules early enough to create circular imports. +""" + +__all__ = [ + 'config_manager', + 'dependencies', + 'installer', + 'loader', + 'manifest', + 'registry', + 'runtime', + 'types', + 'user_config', +] diff --git a/ms_agent/plugins/agents.py b/ms_agent/plugins/agents.py new file mode 100644 index 000000000..0e68245c5 --- /dev/null +++ b/ms_agent/plugins/agents.py @@ -0,0 +1,354 @@ +"""Plugin agent registry and delegate for agents/*.md subagent templates.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from omegaconf import OmegaConf + +from ms_agent.plugins.types import AgentDef +from ms_agent.skill.schema import SkillSchemaParser + +_FRONTMATTER_RE = re.compile(r'^---\s*\n.*?\n---\s*\n', re.DOTALL) + +# Claude Code tool names -> ms-agent config.tools top-level keys to keep. +_CLAUDE_TOOL_TO_CONFIG_KEYS: dict[str, tuple[str, ...]] = { + 'Read': ('file_system',), + 'Write': ('file_system',), + 'Edit': ('file_system',), + 'MultiEdit': ('file_system',), + 'Bash': ('code_executor',), + 'Grep': ('file_system', 'localsearch', 'web_search'), + 'Glob': ('file_system',), + 'Skill': (), # skills are injected separately + 'TodoWrite': ('todo_list',), + 'AskUserQuestion': (), # no dedicated tool yet + 'Task': ('agent_tools',), +} + +_FORBIDDEN_AGENT_FRONTMATTER_KEYS = frozenset({ + 'hooks', + 'mcpServers', + 'permissionMode', +}) + +_CLAUDE_BUILTIN_SUBAGENT_TYPES = frozenset({ + 'general-purpose', + 'explore', + 'shell', + 'browser', + 'planner', + 'architect', +}) + + +@dataclass(frozen=True) +class RegisteredPluginAgent: + defn: AgentDef + namespaced_name: str + + +class PluginAgentRegistry: + """In-memory registry of plugin-defined subagent templates.""" + + def __init__(self) -> None: + self._by_namespaced: dict[str, RegisteredPluginAgent] = {} + self._by_short: dict[str, RegisteredPluginAgent] = {} + + def rebuild(self, agent_defs: list[AgentDef]) -> None: + self._by_namespaced.clear() + self._by_short.clear() + short_claimed: set[str] = set() + for defn in sorted(agent_defs, key=lambda item: (item.plugin_id, item.name)): + namespaced = f'{defn.plugin_id}:{defn.name}' + entry = RegisteredPluginAgent(defn=defn, namespaced_name=namespaced) + self._by_namespaced[namespaced] = entry + if defn.name not in short_claimed: + self._by_short[defn.name] = entry + short_claimed.add(defn.name) + + def remove_plugin(self, plugin_id: str) -> None: + for key in [ + key for key, entry in self._by_namespaced.items() + if entry.defn.plugin_id == plugin_id + ]: + entry = self._by_namespaced.pop(key) + short = self._by_short.get(entry.defn.name) + if short is not None and short.namespaced_name == entry.namespaced_name: + self._by_short.pop(entry.defn.name, None) + for name, entry in list(self._by_short.items()): + if entry.defn.plugin_id == plugin_id: + self._by_short.pop(name, None) + + def has_agents(self) -> bool: + return bool(self._by_namespaced) + + def list_all(self) -> list[dict[str, Any]]: + seen: set[str] = set() + items: list[dict[str, Any]] = [] + for entry in self._by_namespaced.values(): + if entry.namespaced_name in seen: + continue + seen.add(entry.namespaced_name) + defn = entry.defn + items.append({ + 'plugin_id': defn.plugin_id, + 'name': defn.name, + 'namespaced_name': entry.namespaced_name, + 'description': defn.description, + 'model': defn.model, + 'tools': list(defn.tools), + 'skills': list(defn.skills), + 'path': defn.path, + }) + return sorted(items, key=lambda item: item['namespaced_name']) + + def resolve(self, name: str | None) -> RegisteredPluginAgent | None: + if not name: + return None + if name in self._by_namespaced: + return self._by_namespaced[name] + if name in self._by_short: + return self._by_short[name] + if ':' in name: + plugin_id, short = name.split(':', 1) + namespaced = f'{plugin_id}:{short}' + return self._by_namespaced.get(namespaced) + return None + + +class AgentDelegate: + """Build runnable sub-agent specs from plugin agent markdown templates.""" + + @staticmethod + def read_agent_markdown(path: str | Path) -> tuple[dict[str, Any], str]: + content = Path(path).read_text(encoding='utf-8') + frontmatter = SkillSchemaParser.parse_yaml_frontmatter(content) or {} + body = _FRONTMATTER_RE.sub('', content, count=1).strip() + return frontmatter, body + + @staticmethod + def validate_frontmatter(frontmatter: dict[str, Any]) -> list[str]: + warnings: list[str] = [] + for key in _FORBIDDEN_AGENT_FRONTMATTER_KEYS: + if key in frontmatter: + warnings.append( + f'Plugin agents must not declare {key!r} in frontmatter') + return warnings + + @staticmethod + def build_inline_config( + defn: AgentDef, + parent_config: Any, + ) -> dict[str, Any]: + frontmatter, body = AgentDelegate.read_agent_markdown(defn.path) + warnings = AgentDelegate.validate_frontmatter(frontmatter) + if warnings: + raise ValueError( + f'Invalid plugin agent {defn.plugin_id}:{defn.name}: ' + + '; '.join(warnings)) + + inline: dict[str, Any] = { + 'prompt': {'system': body}, + 'ms_agent_subagent': True, + 'plugin_agent': { + 'plugin_id': defn.plugin_id, + 'name': defn.name, + 'path': defn.path, + }, + } + if hasattr(parent_config, 'local_dir') and parent_config.local_dir: + inline['local_dir'] = str(parent_config.local_dir) + model = defn.model or frontmatter.get('model') + if model and str(model).lower() != 'inherit': + parent_llm = {} + if hasattr(parent_config, 'llm') and parent_config.llm is not None: + parent_llm = OmegaConf.to_container(parent_config.llm, resolve=True) or {} + inline['llm'] = {**parent_llm, 'model': str(model)} + if defn.skills: + inline['skills'] = { + 'whitelist': list(defn.skills), + } + return inline + + @staticmethod + def compute_disallowed_tools( + defn: AgentDef, + parent_config: Any, + ) -> list[str] | None: + if defn.disallowed_tools: + return list(defn.disallowed_tools) + if not defn.tools: + return None + if not hasattr(parent_config, 'tools') or parent_config.tools is None: + return None + tools_dict = OmegaConf.to_container(parent_config.tools, resolve=True) or {} + if not isinstance(tools_dict, dict): + return None + + keep: set[str] = set() + for claude_name in defn.tools: + keep.update(_CLAUDE_TOOL_TO_CONFIG_KEYS.get(claude_name, ())) + + plugin_only_keys = {'agent_tools', 'split_task', 'task_control'} + disallowed = [ + key for key in tools_dict + if key not in keep and key not in plugin_only_keys + ] + return disallowed or None + + @staticmethod + def to_agent_tool_spec( + entry: RegisteredPluginAgent, + parent_config: Any, + *, + trust_remote_code: bool = True, + ): + from ms_agent.tools.agent_tool import _AgentToolSpec + + defn = entry.defn + description = ( + defn.description + or f'Plugin subagent {entry.namespaced_name} from {defn.plugin_id}' + ) + inline_config = AgentDelegate.build_inline_config(defn, parent_config) + disallowed_tools = AgentDelegate.compute_disallowed_tools(defn, parent_config) + return _AgentToolSpec( + tool_name=defn.name, + description=description, + parameters={ + 'type': 'object', + 'properties': { + 'prompt': { + 'type': 'string', + 'description': ( + f'Task prompt for plugin subagent {entry.namespaced_name}.' + ), + }, + 'request': { + 'type': 'string', + 'description': 'Alias of prompt for AgentTool compatibility.', + }, + 'description': { + 'type': 'string', + 'description': 'Short summary of the delegated task.', + }, + }, + 'required': [], + 'additionalProperties': True, + }, + config_path=None, + inline_config=inline_config, + server_name=f'plugin:{defn.plugin_id}', + tag_prefix=f'{defn.plugin_id}-{defn.name}-', + input_mode='text', + request_field='prompt', + input_template=None, + output_mode='final_message', + max_output_chars=100000, + trust_remote_code=trust_remote_code, + env=None, + run_in_thread=True, + run_in_process=True, + dynamic=False, + disallowed_tools=disallowed_tools, + ) + + @staticmethod + def build_task_tool_spec( + registry: PluginAgentRegistry, + *, + trust_remote_code: bool = True, + ): + from ms_agent.tools.agent_tool import _AgentToolSpec + + available = [item['namespaced_name'] for item in registry.list_all()] + return _AgentToolSpec( + tool_name='Task', + description=( + 'Launch a plugin-defined subagent. Provide `agent` (for example ' + f'{available[0] if available else "hookify:conversation-analyzer"}) ' + 'and `prompt`.' + ), + parameters={ + 'type': 'object', + 'properties': { + 'agent': { + 'type': 'string', + 'description': ( + 'Plugin subagent name, e.g. conversation-analyzer or ' + 'hookify:conversation-analyzer.' + ), + }, + 'subagent_type': { + 'type': 'string', + 'description': 'Alias of agent when it matches a plugin subagent.', + }, + 'prompt': { + 'type': 'string', + 'description': 'Prompt for the delegated subagent.', + }, + 'description': { + 'type': 'string', + 'description': 'Short summary of the delegated task.', + }, + }, + 'required': ['prompt'], + 'additionalProperties': True, + }, + config_path=None, + inline_config=None, + server_name='plugin_agents', + tag_prefix='plugin-task-', + input_mode='text', + request_field='prompt', + input_template=None, + output_mode='final_message', + max_output_chars=100000, + trust_remote_code=trust_remote_code, + env=None, + run_in_thread=True, + run_in_process=True, + dynamic=True, + disallowed_tools=None, + ) + + @staticmethod + def resolve_task_agent_name(tool_args: dict[str, Any]) -> str | None: + for key in ('agent', 'subagent_type', 'subagent', 'name'): + value = tool_args.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + @staticmethod + def resolve_task_entry( + registry: PluginAgentRegistry, + tool_args: dict[str, Any], + ) -> RegisteredPluginAgent | None: + """Resolve a plugin subagent from Task tool arguments.""" + agent_name = AgentDelegate.resolve_task_agent_name(tool_args) + entry = registry.resolve(agent_name) if agent_name else None + if entry is not None: + return entry + if agent_name: + for item in registry.list_all(): + namespaced = item['namespaced_name'] + if agent_name in {namespaced, item['name'], namespaced.split(':', 1)[-1]}: + resolved = registry.resolve(namespaced) + if resolved is not None: + return resolved + items = registry.list_all() + if len(items) == 1: + only = registry.resolve(items[0]['namespaced_name']) + if only is not None: + return only + if ( + agent_name in _CLAUDE_BUILTIN_SUBAGENT_TYPES + and len(items) == 1 + ): + return registry.resolve(items[0]['namespaced_name']) + return None diff --git a/ms_agent/plugins/commands.py b/ms_agent/plugins/commands.py new file mode 100644 index 000000000..807e54136 --- /dev/null +++ b/ms_agent/plugins/commands.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import re +from pathlib import Path + +from ms_agent.command.router import CommandRouter +from ms_agent.command.types import ( + CommandContext, + CommandDef as RouterCommandDef, + CommandResult, + CommandResultType, +) +from ms_agent.plugins.types import CommandDef + +_FRONTMATTER_RE = re.compile(r'^---\s*\n.*?\n---\s*\n', re.DOTALL) + + +def register_plugin_commands( + router: CommandRouter, + command_defs: list[CommandDef], +) -> None: + """Register plugin commands as slash commands. + + The namespaced form (`/:`) is always registered. The + unqualified form is registered only when it does not conflict with an + existing command. + """ + for cmd in command_defs: + namespaced = f'{cmd.plugin_id}:{cmd.name}' + router.register( + RouterCommandDef( + name=namespaced, + description=cmd.description or f'Plugin command {namespaced}', + category=f'plugin:{cmd.plugin_id}', + ), + _handler_for(cmd), + ) + if router.resolve(cmd.name) is None: + router.register( + RouterCommandDef( + name=cmd.name, + description=cmd.description or f'Plugin command {cmd.name}', + category=f'plugin:{cmd.plugin_id}', + ), + _handler_for(cmd), + ) + + +def _handler_for(cmd: CommandDef): + async def _handler(ctx: CommandContext) -> CommandResult: + try: + content = Path(cmd.path).read_text(encoding='utf-8') + except OSError as exc: + return CommandResult( + type=CommandResultType.MESSAGE, + content=f'Plugin command `{cmd.plugin_id}:{cmd.name}` failed: {exc}', + ) + body = _strip_frontmatter(content) + args = ctx.args or '' + body = body.replace('$ARGUMENTS', args).replace('${ARGUMENTS}', args) + prompt = ( + f'Run plugin command `/{cmd.plugin_id}:{cmd.name}` from `{cmd.path}`.\n\n' + f'{body}' + ) + if args and '$ARGUMENTS' not in content and '${ARGUMENTS}' not in content: + prompt = f'{prompt}\n\nUser arguments: {args}' + return CommandResult( + type=CommandResultType.SUBMIT_PROMPT, + content=prompt, + metadata={'plugin_id': cmd.plugin_id, 'command': cmd.name}, + ) + + return _handler + + +def _strip_frontmatter(content: str) -> str: + return _FRONTMATTER_RE.sub('', content, count=1).strip() diff --git a/ms_agent/plugins/config_manager.py b/ms_agent/plugins/config_manager.py new file mode 100644 index 000000000..a94d49a16 --- /dev/null +++ b/ms_agent/plugins/config_manager.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import json +from pathlib import Path +from threading import Lock +from typing import Literal + +from ms_agent.plugins.types import PluginRecord + +PluginScope = Literal['global', 'project', 'merged'] +PLUGIN_FILE = 'plugins.json' +PROJECT_META_DIR = '.ms-agent' + + +class PluginConfigManager: + """CRUD for global/project plugins.json with project override semantics.""" + + def __init__( + self, + global_dir: str | Path = '~/.ms_agent', + project_root: str | Path | None = None, + ) -> None: + self.global_root = Path(global_dir).expanduser() + self.project_root = ( + Path(project_root).expanduser() if project_root else None + ) + self._lock = Lock() + + @property + def global_plugins_path(self) -> Path: + return self.global_root / PLUGIN_FILE + + @property + def project_plugins_path(self) -> Path: + if self.project_root is None: + raise ValueError('project_root is required for project scope') + return self.project_root / PROJECT_META_DIR / PLUGIN_FILE + + @property + def global_plugins_dir(self) -> Path: + return self.global_root / 'plugins' + + @property + def project_plugins_dir(self) -> Path: + if self.project_root is None: + raise ValueError('project_root is required for project scope') + return self.project_root / PROJECT_META_DIR / 'plugins' + + @property + def global_plugin_data_root(self) -> Path: + return self.global_root / 'plugins' / 'data' + + def list(self, scope: PluginScope = 'merged') -> list[PluginRecord]: + with self._lock: + if scope == 'global': + return self._load_scope('global') + if scope == 'project': + return self._load_scope('project') + return merge_plugin_records( + self._load_scope('global'), + self._load_scope('project') if self.project_root else [], + ) + + def load_merged(self, project_path: str | None = None) -> list[PluginRecord]: + if project_path and self.project_root is None: + scoped = PluginConfigManager(self.global_root, project_path) + return scoped.list('merged') + return self.list('merged') + + def get( + self, + plugin_id: str, + scope: PluginScope = 'merged', + ) -> PluginRecord | None: + for record in self.list(scope): + if record.id == plugin_id: + return record + return None + + def upsert( + self, + record: PluginRecord, + scope: Literal['global', 'project'] = 'global', + ) -> None: + with self._lock: + records = self._load_scope(scope) + replaced = False + normalized = PluginRecord.from_dict(record.to_dict(), scope=scope) + for idx, item in enumerate(records): + if item.id == normalized.id: + records[idx] = normalized + replaced = True + break + if not replaced: + records.append(normalized) + self._save_scope(scope, records) + + def set_enabled( + self, + plugin_id: str, + enabled: bool, + scope: Literal['global', 'project'] = 'global', + ) -> None: + with self._lock: + records = self._load_scope(scope) + for record in records: + if record.id == plugin_id: + record.enabled = enabled + self._save_scope(scope, records) + return + raise KeyError(f'Plugin not found in {scope} scope: {plugin_id}') + + def remove( + self, + plugin_id: str, + scope: Literal['global', 'project'] = 'global', + ) -> None: + with self._lock: + records = [r for r in self._load_scope(scope) if r.id != plugin_id] + self._save_scope(scope, records) + + def _path_for_scope(self, scope: Literal['global', 'project']) -> Path: + return self.global_plugins_path if scope == 'global' else self.project_plugins_path + + def _load_scope(self, scope: Literal['global', 'project']) -> list[PluginRecord]: + path = self._path_for_scope(scope) + data = self._read_json(path) + raw_plugins = data.get('plugins', []) + if not isinstance(raw_plugins, list): + return [] + return [ + PluginRecord.from_dict(item, scope=scope) + for item in raw_plugins + if isinstance(item, dict) and item.get('id') + ] + + def _save_scope( + self, + scope: Literal['global', 'project'], + records: list[PluginRecord], + ) -> None: + path = self._path_for_scope(scope) + path.parent.mkdir(parents=True, exist_ok=True) + payload = {'plugins': [record.to_dict() for record in records]} + tmp = path.with_suffix('.tmp') + with open(tmp, 'w', encoding='utf-8') as f: + json.dump(payload, f, indent=2, ensure_ascii=False) + tmp.rename(path) + + @staticmethod + def _read_json(path: Path) -> dict: + if not path.is_file(): + return {} + try: + with open(path, encoding='utf-8') as f: + data = json.load(f) + return data if isinstance(data, dict) else {} + except (OSError, json.JSONDecodeError): + return {} + + +def merge_plugin_records( + global_records: list[PluginRecord], + project_records: list[PluginRecord], +) -> list[PluginRecord]: + merged: dict[str, PluginRecord] = {} + order: list[str] = [] + for record in global_records + project_records: + if record.id not in order: + order.append(record.id) + merged[record.id] = record + return [merged[plugin_id] for plugin_id in order] diff --git a/ms_agent/plugins/dependencies.py b/ms_agent/plugins/dependencies.py new file mode 100644 index 000000000..8386eb2d2 --- /dev/null +++ b/ms_agent/plugins/dependencies.py @@ -0,0 +1,89 @@ +"""Plugin dependency parsing and version constraint checks.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + +from ms_agent.plugins.manifest import normalize_plugin_id + +_VERSION_RE = re.compile(r'^(\d+)\.(\d+)\.(\d+)') + + +@dataclass(frozen=True) +class PluginDependency: + name: str + version: str | None = None + source: str | None = None + + @property + def plugin_id(self) -> str: + return normalize_plugin_id(self.name) + + +class PluginDependencyError(ValueError): + """Raised when a plugin dependency cannot be satisfied.""" + + +def parse_dependencies(raw: dict[str, Any] | None) -> list[PluginDependency]: + if not raw: + return [] + items = raw.get('dependencies') + if not items: + return [] + if not isinstance(items, list): + raise PluginDependencyError('manifest dependencies must be an array') + deps: list[PluginDependency] = [] + for item in items: + if isinstance(item, str): + deps.append(PluginDependency(name=item)) + continue + if not isinstance(item, dict): + raise PluginDependencyError('dependency entries must be objects or strings') + name = item.get('name') or item.get('id') + if not name: + raise PluginDependencyError('dependency entry requires name') + deps.append( + PluginDependency( + name=str(name), + version=item.get('version'), + source=item.get('source') or item.get('uri'), + )) + return deps + + +def version_satisfies(installed: str, constraint: str | None) -> bool: + if not constraint or constraint in {'*', 'latest'}: + return True + if installed in {'latest', ''}: + return True + installed_parts = _parse_version(installed) + if installed_parts is None: + return True + constraint = constraint.strip() + if constraint.startswith('~'): + base = _parse_version(constraint[1:]) + if base is None: + return True + return installed_parts[:2] == base[:2] and installed_parts >= base + if constraint.startswith('^'): + base = _parse_version(constraint[1:]) + if base is None: + return True + return installed_parts >= base and installed_parts[0] == base[0] + if constraint.startswith('>='): + base = _parse_version(constraint[2:]) + return base is None or installed_parts >= base + exact = _parse_version(constraint) + if exact is None: + return True + return installed_parts == exact + + +def _parse_version(value: str) -> tuple[int, int, int] | None: + val = str(value).strip().lstrip('vV') + match = _VERSION_RE.match(val) + if not match: + return None + return int(match.group(1)), int(match.group(2)), int(match.group(3)) diff --git a/ms_agent/plugins/installer.py b/ms_agent/plugins/installer.py new file mode 100644 index 000000000..b9c835a44 --- /dev/null +++ b/ms_agent/plugins/installer.py @@ -0,0 +1,528 @@ +from __future__ import annotations + +import json +import re +import shutil +import subprocess +import tarfile +import tempfile +from datetime import datetime, timezone +from pathlib import Path +from urllib.parse import parse_qs, unquote, urlparse +from urllib.request import urlopen + +from ms_agent.plugins.config_manager import PluginConfigManager +from ms_agent.plugins.dependencies import ( + PluginDependencyError, + parse_dependencies, + version_satisfies, +) +from ms_agent.plugins.manifest import PluginManifest +from ms_agent.plugins.registry import PluginRegistry +from ms_agent.plugins.types import InstallSource, PluginRecord + +try: + from modelscope import snapshot_download +except ImportError: # pragma: no cover - optional dependency path + snapshot_download = None + + +class UnsupportedPluginSource(ValueError): + """Raised for install sources outside the current local Phase 0 scope.""" + + +_MARKETPLACE_REPOS: dict[str, str] = { + 'claude-plugins-official': 'anthropics/claude-plugins-official', +} +_MARKETPLACE_ALIAS_RE = re.compile( + r'^[a-z0-9][a-z0-9._-]*@[a-z0-9][a-z0-9._-]*$', + re.IGNORECASE, +) + + +class PluginInstaller: + """Install plugins into the MS-Agent-owned plugin cache.""" + + def __init__( + self, + config_manager: PluginConfigManager | None = None, + *, + global_root: str | Path = '~/.ms_agent', + project_root: str | Path | None = None, + ) -> None: + self.config_manager = config_manager or PluginConfigManager( + global_root, project_root) + self.global_root = Path(global_root).expanduser() + self.project_root = ( + Path(project_root).expanduser() if project_root else None + ) + + def install( + self, + source: str, + *, + scope: str = 'global', + project_path: str | Path | None = None, + link: bool = False, + force: bool = False, + format_hint: str | None = None, + enabled: bool | None = None, + _installing: set[str] | None = None, + ) -> PluginManifest: + requested_source = source + source = normalize_install_source(source) + installing = set(_installing or ()) + with self._fetch_source(source) as fetched: + source_path = fetched.path + staged_manifest = PluginManifest.parse( + source_path, + format_hint=format_hint, + ) + self._ensure_dependencies( + staged_manifest, + scope=scope, + project_path=project_path, + link=link, + force=force, + format_hint=format_hint, + enabled=enabled, + installing=installing, + ) + target = self._target_dir( + staged_manifest.plugin_id, + scope=scope, + project_path=project_path, + ) + + if target.exists() or target.is_symlink(): + if not force: + # Idempotent reinstall: keep the existing managed copy and only + # refresh plugins.json from its locked manifest. + existing = PluginManifest.parse( + target, + format_hint=staged_manifest.format, + ) + return self._write_record( + existing, + source=requested_source, + fetch_source=source, + scope=scope, + enabled=enabled, + project_path=project_path, + resolved_sha=fetched.resolved_sha, + record_path=target, + ) + + install_path = self._stage_install_tree( + source_path, + target, + link=link and fetched.type == 'local', + ) + manifest = PluginManifest.parse( + install_path, + format_hint=staged_manifest.format, + ) + self._publish_staged_install(install_path, target) + manifest = PluginManifest.parse(target, format_hint=staged_manifest.format) + return self._write_record( + manifest, + source=requested_source, + fetch_source=source, + scope=scope, + enabled=enabled, + project_path=project_path, + resolved_sha=fetched.resolved_sha, + record_path=target, + ) + + @staticmethod + def _stage_install_tree(source_path: Path, target: Path, *, link: bool) -> Path: + target.parent.mkdir(parents=True, exist_ok=True) + staging_root = target.parent / '.staging' + staging_root.mkdir(parents=True, exist_ok=True) + staged = Path(tempfile.mkdtemp( + prefix=f'{target.name}_', + dir=staging_root, + )) + shutil.rmtree(staged) + if link: + staged.symlink_to(source_path, target_is_directory=True) + else: + shutil.copytree( + source_path, + staged, + ignore=shutil.ignore_patterns('.git'), + ) + return staged + + @staticmethod + def _publish_staged_install(staged: Path, target: Path) -> None: + backup: Path | None = None + if target.exists() or target.is_symlink(): + backup = Path(tempfile.mkdtemp( + prefix=f'{target.name}_backup_', + dir=target.parent / '.staging', + )) + shutil.rmtree(backup) + target.rename(backup) + try: + staged.rename(target) + except Exception: + if backup is not None and (backup.exists() or backup.is_symlink()): + backup.rename(target) + raise + else: + if backup is not None: + if backup.is_symlink() or backup.is_file(): + backup.unlink(missing_ok=True) + elif backup.is_dir(): + shutil.rmtree(backup) + + def _write_record( + self, + manifest: PluginManifest, + *, + source: str, + fetch_source: str | None = None, + scope: str, + enabled: bool | None, + project_path: str | Path | None, + resolved_sha: str | None = None, + record_path: Path | None = None, + ) -> PluginManifest: + source_type = _source_type(fetch_source or source) + record = PluginRecord( + id=manifest.plugin_id, + enabled=manifest.enabled if enabled is None else enabled, + managed_by='ms-agent', + format=manifest.format, + manifest_path=manifest.manifest_path, + source=InstallSource( + type=source_type, + uri=source, + resolved_sha=resolved_sha, + ), + path=str(record_path or manifest.root), + installed_at=datetime.now(timezone.utc).isoformat(), + ) + manager = self._manager_for_project(project_path) + manager.upsert(record, scope=scope) # type: ignore[arg-type] + return PluginManifest.parse(manifest.root, record=record) + + def _manager_for_project( + self, + project_path: str | Path | None, + ) -> PluginConfigManager: + if project_path is not None and self.config_manager.project_root is None: + return PluginConfigManager(self.global_root, project_path) + return self.config_manager + + def _target_dir( + self, + plugin_id: str, + *, + scope: str, + project_path: str | Path | None, + ) -> Path: + if scope == 'project': + root = Path(project_path or self.project_root or '') + if not str(root): + raise ValueError('project_path is required for project plugin install') + return root / '.ms-agent' / 'plugins' / plugin_id + return self.global_root / 'plugins' / plugin_id + + def _ensure_dependencies( + self, + manifest: PluginManifest, + *, + scope: str, + project_path: str | Path | None, + link: bool, + force: bool, + format_hint: str | None, + enabled: bool | None, + installing: set[str], + ) -> None: + registry = PluginRegistry(self.config_manager) + for dep in parse_dependencies(manifest.raw): + if dep.plugin_id in installing: + raise PluginDependencyError( + f'Circular plugin dependency: {dep.plugin_id}') + existing = registry.get_record(dep.plugin_id, 'merged') + if existing is not None: + dep_manifest = registry.get_manifest(dep.plugin_id, use_cache=False) + if dep_manifest is None: + raise PluginDependencyError( + f'Installed dependency {dep.plugin_id!r} is unreadable') + if not version_satisfies(dep_manifest.version, dep.version): + raise PluginDependencyError( + f'Dependency {dep.plugin_id!r} version ' + f'{dep_manifest.version!r} does not satisfy ' + f'{dep.version!r}') + continue + if not dep.source: + raise PluginDependencyError( + f'Dependency {dep.name!r} is not installed and has no source') + installing.add(dep.plugin_id) + try: + self.install( + dep.source, + scope=scope, + project_path=project_path, + link=link, + force=force, + format_hint=format_hint, + enabled=enabled, + _installing=installing, + ) + finally: + installing.discard(dep.plugin_id) + + @staticmethod + def _resolve_local_source(source: str) -> Path: + if source.startswith('ms-agent://'): + return Path(resolve_ms_agent_uri(source)).expanduser().resolve() + if source.startswith('file://'): + parsed = urlparse(source) + return Path(parsed.path).expanduser().resolve() + return Path(source).expanduser().resolve() + + def _fetch_source(self, source: str) -> '_FetchedSource': + if source.startswith('github://'): + return _fetch_github(source) + if source.startswith('modelscope://'): + return _fetch_modelscope(source) + local_path = self._resolve_local_source(source) + if _is_tarball(local_path): + return _fetch_tarball(local_path, source) + return _FetchedSource( + path=local_path, + source=source, + type='local', + ) + + +class _FetchedSource: + def __init__( + self, + *, + path: Path, + source: str, + type: str, + cleanup: tempfile.TemporaryDirectory | None = None, + resolved_sha: str | None = None, + ) -> None: + self.path = path + self.source = source + self.type = type + self.cleanup = cleanup + self.resolved_sha = resolved_sha + + def __enter__(self) -> '_FetchedSource': + return self + + def __exit__(self, exc_type, exc, tb) -> None: + if self.cleanup is not None: + self.cleanup.cleanup() + + +def resolve_ms_agent_uri(source: str) -> str: + """Resolve ``ms-agent://plugin/install?source=...`` to an inner install URI.""" + parsed = urlparse(source) + if parsed.scheme != 'ms-agent': + raise UnsupportedPluginSource(f'Not an ms-agent URI: {source}') + if parsed.netloc != 'plugin': + raise UnsupportedPluginSource(f'Unsupported ms-agent host: {parsed.netloc}') + path = (parsed.path or '/').lstrip('/') + if path != 'install': + raise UnsupportedPluginSource(f'Unsupported ms-agent path: {parsed.path}') + inner = parse_qs(parsed.query).get('source', [None])[0] + if not inner: + raise UnsupportedPluginSource( + 'ms-agent://plugin/install requires a source query parameter') + return unquote(inner) + + +def normalize_install_source(source: str) -> str: + """Resolve marketplace aliases such as ``hookify@claude-plugins-official``.""" + if source.startswith('ms-agent://'): + return normalize_install_source(resolve_ms_agent_uri(source)) + if source.startswith(('github://', 'modelscope://', 'file://')): + return source + if '/' in source or source.startswith('.'): + return source + if _MARKETPLACE_ALIAS_RE.match(source): + plugin_name, marketplace = source.rsplit('@', 1) + return resolve_marketplace_plugin_uri(plugin_name, marketplace) + return source + + +def resolve_marketplace_plugin_uri( + plugin_name: str, + marketplace: str, + *, + ref: str = 'main', +) -> str: + repo = _MARKETPLACE_REPOS.get(marketplace) + if repo is None: + raise UnsupportedPluginSource(f'Unknown marketplace: {marketplace}') + subdir = _lookup_marketplace_plugin_path(repo, plugin_name, ref=ref) + ref_part = f'@{ref}' if ref else '' + return f'github://{repo}{ref_part}#{subdir}' + + +def _lookup_marketplace_plugin_path( + repo: str, + plugin_name: str, + *, + ref: str = 'main', +) -> str: + url = ( + f'https://raw.githubusercontent.com/{repo}/{ref}' + f'/.claude-plugin/marketplace.json' + ) + try: + with urlopen(url, timeout=30) as resp: + data = json.load(resp) + except Exception as exc: + raise UnsupportedPluginSource( + f'Failed to load marketplace index for {repo}: {exc}') from exc + + for plugin in data.get('plugins', []): + if plugin.get('name') != plugin_name: + continue + source = plugin.get('source') + if isinstance(source, str): + return source.lstrip('./') + if isinstance(source, dict): + subdir = source.get('path') or source.get('subdir') + if isinstance(subdir, str) and subdir: + return subdir.lstrip('./') + break + raise UnsupportedPluginSource( + f'Plugin {plugin_name!r} not found in marketplace {repo}') + + +def _is_tarball(path: Path) -> bool: + name = path.name.lower() + return name.endswith(('.tar.gz', '.tgz', '.tar')) + + +def _fetch_tarball(path: Path, source: str) -> _FetchedSource: + if not path.is_file(): + raise UnsupportedPluginSource(f'Tarball not found: {path}') + tmp = tempfile.TemporaryDirectory(prefix='ms_agent_plugin_tar_') + extract_dir = Path(tmp.name) / 'extracted' + extract_dir.mkdir(parents=True, exist_ok=True) + mode = 'r:gz' if path.name.lower().endswith(('.tar.gz', '.tgz')) else 'r' + with tarfile.open(path, mode) as archive: + archive.extractall(extract_dir) + children = [child for child in extract_dir.iterdir() if child.name != '.DS_Store'] + root = ( + children[0] + if len(children) == 1 and children[0].is_dir() + else extract_dir + ) + return _FetchedSource( + path=root, + source=source, + type='local', + cleanup=tmp, + ) + + +def _source_type(source: str) -> str: + if source.startswith('github://'): + return 'github' + if source.startswith('modelscope://'): + return 'modelscope' + if source.startswith('ms-agent://'): + return 'ms-agent' + if _MARKETPLACE_ALIAS_RE.match(source): + return 'github' + return 'local' + + +_GIT_SHA_RE = re.compile(r'^[0-9a-f]{7,40}$', re.IGNORECASE) + + +def _fetch_github(source: str) -> _FetchedSource: + repo, ref, subdir = _parse_github_uri(source) + tmp = tempfile.TemporaryDirectory(prefix='ms_agent_plugin_git_') + clone_dir = Path(tmp.name) / 'repo' + is_sha = bool(ref and _GIT_SHA_RE.match(ref)) + clone_cmd = [ + 'git', + 'clone', + '--depth', + '1', + '--filter=blob:none', + ] + if ref and not is_sha: + clone_cmd.extend(['--branch', ref]) + clone_cmd.extend([f'https://github.com/{repo}.git', str(clone_dir)]) + subprocess.run(clone_cmd, check=True, capture_output=True, text=True) + if is_sha: + subprocess.run( + ['git', '-C', str(clone_dir), 'fetch', 'origin', ref], + check=False, + capture_output=True, + text=True, + ) + subprocess.run( + ['git', '-C', str(clone_dir), 'checkout', ref], + check=True, + capture_output=True, + text=True, + ) + if subdir: + subprocess.run( + ['git', '-C', str(clone_dir), 'sparse-checkout', 'set', subdir], + check=True, + capture_output=True, + text=True, + ) + sha = subprocess.run( + ['git', '-C', str(clone_dir), 'rev-parse', 'HEAD'], + check=True, + capture_output=True, + text=True, + ).stdout.strip() + return _FetchedSource( + path=clone_dir / subdir if subdir else clone_dir, + source=source, + type='github', + cleanup=tmp, + resolved_sha=sha or None, + ) + + +def _parse_github_uri(source: str) -> tuple[str, str | None, str | None]: + body = source[len('github://'):] + repo_part, _, subdir = body.partition('#') + repo, _, ref = repo_part.partition('@') + if repo.count('/') != 1: + raise UnsupportedPluginSource(f'Invalid github plugin URI: {source}') + return repo, ref or None, subdir or None + + +def _fetch_modelscope(source: str) -> _FetchedSource: + if snapshot_download is None: + raise UnsupportedPluginSource( + 'modelscope is required for modelscope:// plugin install') + repo, ref, subdir = _parse_modelscope_uri(source) + local_path = Path(snapshot_download(repo, revision=ref)).expanduser().resolve() + return _FetchedSource( + path=local_path / subdir if subdir else local_path, + source=source, + type='modelscope', + ) + + +def _parse_modelscope_uri(source: str) -> tuple[str, str | None, str | None]: + body = source[len('modelscope://'):] + repo_part, _, subdir = body.partition('#') + repo, _, ref = repo_part.partition('@') + if not repo or '/' not in repo: + raise UnsupportedPluginSource(f'Invalid modelscope plugin URI: {source}') + return repo, ref or None, subdir or None diff --git a/ms_agent/plugins/loader.py b/ms_agent/plugins/loader.py new file mode 100644 index 000000000..7210d7d33 --- /dev/null +++ b/ms_agent/plugins/loader.py @@ -0,0 +1,467 @@ +from __future__ import annotations + +import copy +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from ms_agent.hooks.loaders.claude import ClaudeSettingsLoader +from ms_agent.hooks.loaders.hermes import HermesShellLoader +from ms_agent.hooks.registry import HookRegistry +from ms_agent.plugins.manifest import PluginManifest +from ms_agent.plugins.types import AgentDef, CommandDef, UnsupportedCapability +from ms_agent.skill.schema import SkillSchemaParser +from ms_agent.skill.sources import SkillSource, SkillSourceType +from ms_agent.utils import get_logger + +logger = get_logger() + + +@dataclass(frozen=True) +class PluginLoadContext: + project_path: str + session_id: str + enabled_executors: frozenset[str] + plugin_data_root: Path + + +@dataclass(frozen=True) +class PluginHookContribution: + plugin_id: str + registry: HookRegistry + plugin_root: Path + plugin_data_dir: Path + + +@dataclass +class PluginLoadResult: + skill_sources: list[SkillSource] = field(default_factory=list) + hook_registries: list[PluginHookContribution] = field(default_factory=list) + mcp_servers: dict[str, dict[str, Any]] = field(default_factory=dict) + command_defs: list[CommandDef] = field(default_factory=list) + agent_defs: list[AgentDef] = field(default_factory=list) + settings_patch: dict[str, Any] = field(default_factory=dict) + bin_paths: list[Path] = field(default_factory=list) + user_config_schema: dict[str, Any] = field(default_factory=dict) + ui_metadata: dict[str, Any] = field(default_factory=dict) + unsupported: list[UnsupportedCapability] = field(default_factory=list) + + def merge(self, other: 'PluginLoadResult') -> 'PluginLoadResult': + self.skill_sources.extend(other.skill_sources) + self.hook_registries.extend(other.hook_registries) + for name, server in other.mcp_servers.items(): + candidate = name + if candidate in self.mcp_servers: + plugin_id = server.get('plugin_id') + base = f'plugin.{plugin_id}.{name}' if plugin_id else name + candidate = _unique_mcp_name(base, set(self.mcp_servers)) + self.mcp_servers[candidate] = server + self.command_defs.extend(other.command_defs) + self.agent_defs.extend(other.agent_defs) + self.settings_patch.update(other.settings_patch) + self.bin_paths.extend(other.bin_paths) + self.user_config_schema.update(other.user_config_schema) + self.ui_metadata.update(other.ui_metadata) + self.unsupported.extend(other.unsupported) + return self + + +class PluginLoader: + @staticmethod + def load(manifest: PluginManifest, ctx: PluginLoadContext) -> PluginLoadResult: + result = PluginLoadResult() + data_dir = ctx.plugin_data_root / manifest.plugin_id + data_dir.mkdir(parents=True, exist_ok=True) + user_config = _load_user_config(data_dir) + + result.skill_sources.extend(_load_skill_sources(manifest)) + result.command_defs.extend(_load_commands(manifest)) + result.skill_sources.extend( + _command_defs_to_skill_sources(manifest, result.command_defs)) + result.agent_defs.extend(_load_agents(manifest)) + + if 'hooks' in manifest.capabilities: + registry = _load_hook_registry(manifest, ctx, data_dir, user_config) + if not registry.is_empty: + registry = registry.with_plugin_source( + plugin_id=manifest.plugin_id, + plugin_root=str(manifest.root), + plugin_data_dir=str(data_dir), + ) + result.hook_registries.append( + PluginHookContribution( + plugin_id=manifest.plugin_id, + registry=registry, + plugin_root=manifest.root, + plugin_data_dir=data_dir, + )) + + result.mcp_servers.update(_load_mcp_servers(manifest, data_dir, ctx)) + result.settings_patch.update(_load_settings(manifest.root)) + result.bin_paths.extend(_load_bin_paths(manifest.root)) + result.user_config_schema.update((manifest.raw or {}).get('userConfig') or {}) + result.ui_metadata.update(_load_ui_metadata(manifest)) + result.unsupported.extend(_load_unsupported(manifest)) + return result + + @staticmethod + def load_all( + manifests: list[PluginManifest], + ctx: PluginLoadContext, + ) -> PluginLoadResult: + result = PluginLoadResult() + for manifest in sorted(manifests, key=lambda item: item.plugin_id): + try: + result.merge(PluginLoader.load(manifest, ctx)) + except Exception as exc: + logger.warning( + 'Failed to load plugin %s: %s', manifest.plugin_id, exc) + return result + + +def _load_skill_sources(manifest: PluginManifest) -> list[SkillSource]: + return [ + SkillSource( + type=SkillSourceType.LOCAL_DIR, + path=str(path), + origin='plugin', + plugin_id=manifest.plugin_id, + capability='skills', + ) + for path in manifest.resolve_paths('skills') + ] + + +def _command_defs_to_skill_sources( + manifest: PluginManifest, + command_defs: list[CommandDef], +) -> list[SkillSource]: + """Expose plugin commands as SkillCatalog sources (strategy A).""" + return [ + SkillSource( + type=SkillSourceType.LOCAL_DIR, + path=cmd.path, + origin='plugin', + plugin_id=manifest.plugin_id, + capability='commands', + ) + for cmd in command_defs + if cmd.plugin_id == manifest.plugin_id + ] + + +def _iter_command_files(path: Path) -> list[Path]: + if path.is_file(): + return [path] + return sorted(path.glob('*.md')) + + +def _iter_agent_files(path: Path) -> list[Path]: + if path.is_file(): + return [path] + files = sorted(path.glob('*.md')) + for child in sorted(path.iterdir()): + if not child.is_dir(): + continue + agent_md = child / 'AGENT.md' + if agent_md.is_file(): + logger.warning( + 'Deprecated agents/*/AGENT.md layout at %s; prefer agents/*.md', + agent_md, + ) + files.append(agent_md) + return files + + +def _load_commands(manifest: PluginManifest) -> list[CommandDef]: + defs: list[CommandDef] = [] + for path in manifest.resolve_paths('commands'): + for file_path in _iter_command_files(path): + frontmatter = _read_frontmatter(file_path) + defs.append( + CommandDef( + plugin_id=manifest.plugin_id, + name=frontmatter.get('name') or file_path.stem, + path=str(file_path), + description=frontmatter.get('description'), + argument_hint=frontmatter.get( + 'argument-hint', frontmatter.get('argument_hint')), + )) + return defs + + +def _load_agents(manifest: PluginManifest) -> list[AgentDef]: + defs: list[AgentDef] = [] + for path in manifest.resolve_paths('agents'): + for file_path in _iter_agent_files(path): + frontmatter = _read_frontmatter(file_path) + defs.append( + AgentDef( + plugin_id=manifest.plugin_id, + name=frontmatter.get('name') or file_path.stem, + path=str(file_path), + description=frontmatter.get('description'), + model=frontmatter.get('model'), + tools=_as_tuple(frontmatter.get('tools')), + skills=_as_tuple(frontmatter.get('skills')), + disallowed_tools=_as_tuple( + frontmatter.get('disallowedTools', frontmatter.get( + 'disallowed_tools'))), + )) + return defs + + +def _load_mcp_servers( + manifest: PluginManifest, + data_dir: Path, + ctx: PluginLoadContext, +) -> dict[str, dict[str, Any]]: + raw = manifest.raw or {} + candidates: list[Any] = [] + if isinstance(raw.get('mcpServers'), dict): + candidates.append(raw['mcpServers']) + for path in manifest.resolve_paths('mcp'): + if path.is_file(): + try: + with open(path, encoding='utf-8') as f: + candidates.append(json.load(f)) + except (OSError, json.JSONDecodeError): + continue + + servers: dict[str, dict[str, Any]] = {} + for item in candidates: + entries = item.get('mcpServers', item) if isinstance(item, dict) else {} + if not isinstance(entries, dict): + continue + for name, server in entries.items(): + if not isinstance(server, dict): + continue + server_name = str(name) + if server_name in servers: + server_name = _unique_mcp_name( + f'plugin.{manifest.plugin_id}.{server_name}', + set(servers), + ) + expanded = _expand_vars( + copy.deepcopy(server), + manifest.root, + data_dir, + Path(ctx.project_path), + ) + expanded['source'] = 'plugin' + expanded['plugin_id'] = manifest.plugin_id + expanded.setdefault('enabled', True) + servers[server_name] = expanded + return servers + + +def _load_hook_registry( + manifest: PluginManifest, + ctx: PluginLoadContext, + data_dir: Path, + user_config: dict[str, Any], +) -> HookRegistry: + registry = HookRegistry(_index={}) + raw_hooks = (manifest.raw or {}).get('hooks') + if isinstance(raw_hooks, dict): + try: + hooks = raw_hooks.get('hooks', raw_hooks) + registry = registry.merge( + ClaudeSettingsLoader.parse_hooks( + hooks, + ctx.project_path, + plugin_root=str(manifest.root), + plugin_data_dir=str(data_dir), + user_config=user_config, + enabled_executors=ctx.enabled_executors, + )) + except Exception as exc: + logger.warning( + 'Failed to load inline hooks for plugin %s: %s', + manifest.plugin_id, + exc, + ) + + loaded_paths = set() + for path in manifest.resolve_paths('hooks'): + loaded_paths.add(path.resolve()) + try: + if path.suffix in {'.yaml', '.yml'}: + loaded = HermesShellLoader.load_file( + path, + ctx.project_path, + plugin_root=str(manifest.root), + plugin_data_dir=str(data_dir), + user_config=user_config, + enabled_executors=ctx.enabled_executors, + ) + else: + loaded = ClaudeSettingsLoader.parse_hooks_file( + path, + plugin_root=str(manifest.root), + plugin_data_dir=str(data_dir), + user_config=user_config, + project_path=ctx.project_path, + enabled_executors=ctx.enabled_executors, + ) + registry = registry.merge(loaded) + except Exception as exc: + logger.warning( + 'Failed to load hooks for plugin %s from %s: %s', + manifest.plugin_id, + path, + exc, + ) + + for path in ( + manifest.root / 'hooks' / 'hermes.yaml', + manifest.root / 'hooks' / 'config.yaml', + ): + if path.resolve() in loaded_paths: + continue + if path.is_file(): + try: + registry = registry.merge( + HermesShellLoader.load_file( + path, + ctx.project_path, + plugin_root=str(manifest.root), + plugin_data_dir=str(data_dir), + user_config=user_config, + enabled_executors=ctx.enabled_executors, + )) + except Exception as exc: + logger.warning( + 'Failed to load Hermes hooks for plugin %s from %s: %s', + manifest.plugin_id, + path, + exc, + ) + return registry + + +def _unique_mcp_name(base: str, used: set[str]) -> str: + candidate = base + suffix = 1 + while candidate in used: + candidate = f'{base}.{suffix}' + suffix += 1 + return candidate + + +def _load_settings(root: Path) -> dict[str, Any]: + path = root / 'settings.json' + if not path.is_file(): + return {} + try: + with open(path, encoding='utf-8') as f: + data = json.load(f) + except (OSError, json.JSONDecodeError): + return {} + if not isinstance(data, dict): + return {} + allowed = {'agent', 'subagentStatusLine'} + return {key: value for key, value in data.items() if key in allowed} + + +def _load_bin_paths(root: Path) -> list[Path]: + path = root / 'bin' + if not path.is_dir(): + return [] + return [path] + + +def _load_ui_metadata(manifest: PluginManifest) -> dict[str, Any]: + raw = manifest.raw or {} + metadata: dict[str, Any] = {} + for key in ('author', 'homepage', 'repository', 'license', 'keywords', + 'displayName', 'interface'): + if key in raw: + metadata[key] = raw[key] + assets = manifest.root / 'assets' + if assets.is_dir(): + metadata['assets_path'] = str(assets) + return metadata + + +def _load_unsupported(manifest: PluginManifest) -> list[UnsupportedCapability]: + unsupported: list[UnsupportedCapability] = [] + for capability, scan in manifest.components.items(): + if scan.status in {'unsupported', 'detect_only'}: + unsupported.append( + UnsupportedCapability( + capability=capability, + status=scan.status, + hint=scan.hint, + )) + return unsupported + + +def _read_frontmatter(path: Path) -> dict[str, Any]: + try: + content = path.read_text(encoding='utf-8') + except OSError: + return {} + return SkillSchemaParser.parse_yaml_frontmatter(content) or {} + + +def _load_user_config(data_dir: Path) -> dict[str, Any]: + path = data_dir / 'config.json' + if not path.is_file(): + return {} + try: + with open(path, encoding='utf-8') as f: + data = json.load(f) + except (OSError, json.JSONDecodeError): + return {} + return data if isinstance(data, dict) else {} + + +def _as_tuple(value: Any) -> tuple[str, ...]: + if value is None: + return () + if isinstance(value, str): + return tuple(item.strip() for item in value.split(',') if item.strip()) + if isinstance(value, (list, tuple)): + return tuple(str(item) for item in value) + return () + + +def _expand_vars( + value: Any, + plugin_root: Path, + plugin_data_dir: Path, + project_path: Path, + user_config: dict[str, Any] | None = None, +) -> Any: + if user_config is None: + user_config = _load_user_config(plugin_data_dir) + if isinstance(value, str): + expanded = ( + value + .replace('${MS_AGENT_PLUGIN_ROOT}', str(plugin_root)) + .replace('${CLAUDE_PLUGIN_ROOT}', str(plugin_root)) + .replace('${MS_AGENT_PLUGIN_DATA}', str(plugin_data_dir)) + .replace('${CLAUDE_PLUGIN_DATA}', str(plugin_data_dir)) + .replace('${MS_AGENT_PROJECT_DIR}', str(project_path)) + .replace('${CLAUDE_PROJECT_DIR}', str(project_path)) + ) + for key, item in user_config.items(): + expanded = expanded.replace(f'${{user_config.{key}}}', str(item)) + expanded = expanded.replace( + f'${{CLAUDE_PLUGIN_OPTION_{key.upper()}}}', str(item)) + return expanded + if isinstance(value, list): + return [ + _expand_vars( + item, plugin_root, plugin_data_dir, project_path, user_config) + for item in value + ] + if isinstance(value, dict): + return { + key: _expand_vars( + item, plugin_root, plugin_data_dir, project_path, user_config) + for key, item in value.items() + } + return value diff --git a/ms_agent/plugins/manifest.py b/ms_agent/plugins/manifest.py new file mode 100644 index 000000000..fb9b1c7dd --- /dev/null +++ b/ms_agent/plugins/manifest.py @@ -0,0 +1,509 @@ +from __future__ import annotations + +import json +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from ms_agent.plugins.types import ( + LOADABLE_CAPABILITIES, + ComponentScan, + InstallSource, + PluginFormat, + PluginRecord, +) + +_PLUGIN_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9._-]{0,63}$') +_SEMVER_RE = re.compile(r'^\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?$') + +_MANIFEST_CANDIDATES: tuple[tuple[str, PluginFormat], ...] = ( + ('.ms-agent-plugin/plugin.json', PluginFormat.MS_AGENT), + ('plugin.json', PluginFormat.GENERIC), + ('.claude-plugin/plugin.json', PluginFormat.CLAUDE), + ('.codex-plugin/plugin.json', PluginFormat.CODEX), + ('.cursor-plugin/plugin.json', PluginFormat.CURSOR), + ('openclaw.plugin.json', PluginFormat.OPENCLAW), +) + + +class PluginError(ValueError): + """Base class for plugin parsing and validation errors.""" + + +class AmbiguousPluginManifest(PluginError): + """Raised when multiple non-native manifest formats are present.""" + + +class EmptyPluginError(PluginError): + """Raised when a plugin contains no loadable component.""" + + +class InvalidPluginManifest(PluginError): + """Raised when plugin.json is invalid or violates required fields.""" + + +@dataclass(frozen=True) +class ManifestCandidate: + path: Path + format: PluginFormat + raw: dict[str, Any] + + @property + def rel_path(self) -> str: + return self.path.as_posix() + + +@dataclass(frozen=True) +class PluginManifest: + plugin_id: str + name: str + version: str + description: str + root: Path + format: PluginFormat + manifest_path: str + capabilities: frozenset[str] + components: dict[str, ComponentScan] + source: InstallSource + installed_at: str | None = None + enabled: bool = True + raw: dict[str, Any] | None = None + + @classmethod + def parse( + cls, + root: str | Path, + *, + record: PluginRecord | dict[str, Any] | None = None, + format_hint: str | PluginFormat | None = None, + ) -> 'PluginManifest': + root_path = Path(root).expanduser().resolve() + plugin_record: PluginRecord | None = None + if record is not None: + plugin_record = ( + record if isinstance(record, PluginRecord) + else PluginRecord.from_dict(record) + ) + + if plugin_record and plugin_record.manifest_path: + manifest_path = _locked_manifest_path( + root_path, + plugin_record.manifest_path, + ) + raw = _read_manifest(root_path / manifest_path) + fmt = _coerce_format(plugin_record.format) or _format_for_path( + manifest_path.as_posix(), raw) + candidate = ManifestCandidate(manifest_path, fmt, raw) + else: + candidate = detect_manifest(root_path, format_hint=format_hint) + + raw = candidate.raw + name = str(raw.get('name') or '').strip() + if not name: + raise InvalidPluginManifest('Plugin manifest requires "name"') + plugin_id = normalize_plugin_id(name) + if not _PLUGIN_NAME_RE.match(plugin_id): + raise InvalidPluginManifest(f'Invalid plugin name: {name}') + + version = str(raw.get('version') or 'latest') + if version != 'latest' and not _SEMVER_RE.match(version): + raise InvalidPluginManifest(f'Invalid plugin version: {version}') + + components = scan_components(root_path, raw) + capabilities = frozenset( + key for key, scan in components.items() + if key in LOADABLE_CAPABILITIES and scan.status == 'ready' + ) + if not capabilities: + raise EmptyPluginError(f'Plugin has no loadable components: {root_path}') + + enabled = ( + plugin_record.enabled + if plugin_record + else bool(raw.get('defaultEnabled', True)) + ) + return cls( + plugin_id=plugin_id, + name=name, + version=version, + description=str(raw.get('description') or ''), + root=root_path, + format=candidate.format, + manifest_path=candidate.rel_path, + capabilities=capabilities, + components=components, + source=InstallSource.from_raw( + plugin_record.source if plugin_record else raw.get('source')), + installed_at=plugin_record.installed_at if plugin_record else None, + enabled=enabled, + raw=raw, + ) + + def resolve_paths(self, kind: str) -> list[Path]: + raw = self.raw or {} + if kind == 'skills': + paths = _paths_from_manifest_field(self.root, raw.get('skills')) + if not paths and (self.root / 'skills').is_dir(): + paths.append(self.root / 'skills') + if (self.root / 'SKILL.md').is_file(): + paths.append(self.root) + return _dedupe_paths(paths) + if kind == 'commands': + paths = _paths_from_manifest_field(self.root, raw.get('commands')) + if not paths and (self.root / 'commands').is_dir(): + paths.append(self.root / 'commands') + return _dedupe_paths(paths) + if kind == 'agents': + paths = _paths_from_manifest_field(self.root, raw.get('agents')) + if not paths and (self.root / 'agents').is_dir(): + paths.append(self.root / 'agents') + return _dedupe_paths(paths) + if kind == 'hooks': + paths = _paths_from_manifest_field(self.root, raw.get('hooks')) + hooks_json = self.root / 'hooks' / 'hooks.json' + if hooks_json.is_file(): + paths.append(hooks_json) + return _dedupe_paths(paths) + if kind == 'mcp': + paths = _paths_from_manifest_field(self.root, raw.get('mcpServers')) + for default_path in ( + self.root / '.mcp.json', + self.root / 'tools' / 'mcp.json', + self.root / 'openclaw.json', + ): + if default_path.is_file(): + paths.append(default_path) + return _dedupe_paths(paths) + return [] + + +def normalize_plugin_id(name: str) -> str: + return name.strip().lower().replace('/', '-') + + +def detect_manifest( + root: str | Path, + *, + format_hint: str | PluginFormat | None = None, +) -> ManifestCandidate: + root_path = Path(root).expanduser().resolve() + candidates = _scan_manifest_candidates(root_path) + if not candidates: + synthetic = _detect_manifestless_bundle(root_path) + if synthetic is not None: + return synthetic + raise InvalidPluginManifest(f'No plugin manifest found in {root_path}') + + if format_hint: + wanted = _coerce_format(format_hint) + for candidate in candidates: + if candidate.format == wanted: + return candidate + raise InvalidPluginManifest( + f'No {wanted.value if wanted else format_hint} manifest found') + + if len(candidates) == 1: + return candidates[0] + + native = _pick_ms_agent_native(candidates) + if native is not None: + return native + + raise AmbiguousPluginManifest( + 'Multiple plugin manifests found: ' + + ', '.join(c.rel_path for c in candidates)) + + +def scan_components( + root: str | Path, + manifest: dict[str, Any] | None = None, +) -> dict[str, ComponentScan]: + root_path = Path(root) + manifest = manifest or {} + components: dict[str, ComponentScan] = {} + + skill_count = _count_skill_dirs(root_path / 'skills') + if (root_path / 'SKILL.md').is_file(): + skill_count += 1 + if manifest.get('skills') and skill_count == 0: + skill_count = _count_paths(root_path, manifest['skills'], 'SKILL.md') + components['skills'] = _scan('ready', skill_count, root_path / 'skills') + + command_count = _count_markdown_files(root_path / 'commands') + if manifest.get('commands') and command_count == 0: + command_count = _count_paths(root_path, manifest['commands'], '*.md') + components['commands'] = _scan('ready', command_count, root_path / 'commands') + + agent_count = _count_markdown_files(root_path / 'agents') + agent_count += _count_agent_md_subdirs(root_path / 'agents') + if manifest.get('agents') and agent_count == 0: + agent_count = _count_paths(root_path, manifest['agents'], '*.md') + components['agents'] = _scan('ready', agent_count, root_path / 'agents') + + hook_count = 0 + if (root_path / 'hooks' / 'hooks.json').is_file(): + hook_count += 1 + if (root_path / 'hooks' / 'hermes.yaml').is_file(): + hook_count += 1 + if (root_path / 'hooks' / 'config.yaml').is_file(): + hook_count += 1 + hooks_field = manifest.get('hooks') + if hooks_field: + if isinstance(hooks_field, dict): + hook_count += 1 + else: + hook_count += sum( + 1 for path in _paths_from_manifest_field(root_path, hooks_field) + if path.exists() + ) + components['hooks'] = _scan('ready', hook_count, root_path / 'hooks') + + mcp_count = 0 + if (root_path / '.mcp.json').is_file(): + mcp_count += 1 + if (root_path / 'tools' / 'mcp.json').is_file(): + mcp_count += 1 + if (root_path / 'openclaw.json').is_file(): + mcp_count += 1 + mcp_field = manifest.get('mcpServers') + if mcp_field: + if isinstance(mcp_field, dict): + mcp_count += 1 + else: + mcp_count += sum( + 1 for path in _paths_from_manifest_field(root_path, mcp_field) + if path.exists() + ) + components['mcp'] = _scan('ready', mcp_count, root_path / '.mcp.json') + + settings_count = 1 if _non_empty_json(root_path / 'settings.json') else 0 + components['settings'] = _scan( + 'ready', settings_count, root_path / 'settings.json') + + bin_count = _count_executable_files(root_path / 'bin') + components['bin'] = _scan('ready', bin_count, root_path / 'bin') + + user_config_count = 1 if manifest.get('userConfig') else 0 + components['user_config'] = _scan('ready', user_config_count, None) + + components['assets'] = _detect_only(root_path / 'assets') + components['apps'] = _detect_only(root_path / '.app.json') + components['rules'] = _detect_only(root_path / 'rules') + components['lsp'] = _detect_only(root_path / '.lsp.json') + components['output_styles'] = _detect_only(root_path / 'output-styles') + components['themes'] = _detect_only(root_path / 'themes') + components['monitors'] = _detect_only(root_path / 'monitors') + components['channels'] = _scan( + 'detect_only' if manifest.get('channels') else 'skipped', + 1 if manifest.get('channels') else 0, + None, + ) + components['hooks_openclaw_internal'] = _scan( + 'unsupported' if list(root_path.glob('hooks/*/HOOK.md')) else 'skipped', + len(list(root_path.glob('hooks/*/HOOK.md'))), + root_path / 'hooks', + hint='OpenClaw in-process hooks are detect-only.', + ) + components['hooks_hermes_python'] = _scan('skipped', 0, None) + return components + + +def _scan_manifest_candidates(root: Path) -> list[ManifestCandidate]: + candidates: list[ManifestCandidate] = [] + for rel, default_format in _MANIFEST_CANDIDATES: + path = root / rel + if path.is_file(): + raw = _read_manifest(path) + fmt = _format_for_path(rel, raw, default_format) + candidates.append(ManifestCandidate(Path(rel), fmt, raw)) + return candidates + + +def _detect_manifestless_bundle(root: Path) -> ManifestCandidate | None: + package_json = root / 'package.json' + if package_json.is_file(): + try: + raw_package = _read_manifest(package_json) + except InvalidPluginManifest: + raw_package = {} + openclaw_cfg = raw_package.get('openclaw') or raw_package.get( + 'openclaw.hooks') + if openclaw_cfg or list(root.glob('hooks/*/HOOK.md')): + raw = { + 'name': raw_package.get('name') or root.name, + 'version': raw_package.get('version', 'latest'), + 'description': raw_package.get('description', ''), + } + return ManifestCandidate(Path('package.json'), PluginFormat.OPENCLAW, raw) + + for rel in ('hooks/hermes.yaml', 'hooks/config.yaml'): + if (root / rel).is_file(): + raw = { + 'name': root.name, + 'version': 'latest', + 'description': 'Hermes shell hook bundle', + } + return ManifestCandidate(Path(rel), PluginFormat.HERMES, raw) + return None + + +def _pick_ms_agent_native( + candidates: list[ManifestCandidate], +) -> ManifestCandidate | None: + for candidate in candidates: + if candidate.format == PluginFormat.MS_AGENT: + return candidate + return None + + +def _format_for_path( + rel_path: str, + raw: dict[str, Any], + default: PluginFormat | None = None, +) -> PluginFormat: + if rel_path == 'plugin.json' and raw.get('ms_agent'): + return PluginFormat.MS_AGENT + if default and default != PluginFormat.GENERIC: + return default + return default or PluginFormat.GENERIC + + +def _coerce_format(value: str | PluginFormat | None) -> PluginFormat | None: + if value is None: + return None + if isinstance(value, PluginFormat): + return value + return PluginFormat(str(value)) + + +def _read_manifest(path: Path) -> dict[str, Any]: + try: + with open(path, encoding='utf-8') as f: + data = json.load(f) + except (OSError, json.JSONDecodeError) as exc: + raise InvalidPluginManifest(f'Invalid plugin manifest: {path}') from exc + if not isinstance(data, dict): + raise InvalidPluginManifest(f'Plugin manifest must be an object: {path}') + return data + + +def _locked_manifest_path(root: Path, value: str) -> Path: + raw_path = Path(value).expanduser() + resolved = raw_path if raw_path.is_absolute() else root / raw_path + try: + return resolved.resolve().relative_to(root) + except ValueError as exc: + raise InvalidPluginManifest( + f'Locked manifest path escapes plugin root: {value}') from exc + + +def _paths_from_manifest_field(root: Path, raw: Any) -> list[Path]: + if not raw or isinstance(raw, dict): + return [] + values = raw if isinstance(raw, list) else [raw] + paths: list[Path] = [] + for value in values: + if not isinstance(value, str): + continue + paths.append(_resolve_plugin_child(root, value)) + return paths + + +def _resolve_plugin_child(root: Path, value: str) -> Path: + root_resolved = root.expanduser().resolve() + raw_path = Path(value).expanduser() + path = raw_path if raw_path.is_absolute() else root_resolved / raw_path + resolved = path.resolve() + try: + resolved.relative_to(root_resolved) + except ValueError as exc: + raise InvalidPluginManifest( + f'Plugin component path escapes plugin root: {value}') from exc + return resolved + + +def _dedupe_paths(paths: list[Path]) -> list[Path]: + seen: set[str] = set() + result: list[Path] = [] + for path in paths: + key = str(path) + if key not in seen and path.exists(): + seen.add(key) + result.append(path) + return result + + +def _scan( + ready_status: str, + count: int, + path: Path | None, + hint: str | None = None, +) -> ComponentScan: + if count <= 0: + return ComponentScan(status='skipped', count=0) + return ComponentScan( + status=ready_status, + count=count, + path=str(path) if path else None, + hint=hint, + ) + + +def _detect_only(path: Path) -> ComponentScan: + if path.is_dir(): + count = sum(1 for _ in path.iterdir()) + return ComponentScan(status='detect_only', count=count, path=str(path)) + if path.is_file(): + return ComponentScan(status='detect_only', count=1, path=str(path)) + return ComponentScan(status='skipped', count=0) + + +def _count_skill_dirs(path: Path) -> int: + if not path.is_dir(): + return 0 + return sum(1 for child in path.iterdir() if (child / 'SKILL.md').is_file()) + + +def _count_markdown_files(path: Path) -> int: + if not path.is_dir(): + return 0 + return sum(1 for child in path.glob('*.md') if child.is_file()) + + +def _count_agent_md_subdirs(path: Path) -> int: + if not path.is_dir(): + return 0 + return sum( + 1 for child in path.iterdir() + if child.is_dir() and (child / 'AGENT.md').is_file() + ) + + +def _count_executable_files(path: Path) -> int: + if not path.is_dir(): + return 0 + return sum(1 for child in path.iterdir() if child.is_file()) + + +def _count_paths(root: Path, raw: Any, marker: str) -> int: + count = 0 + for path in _paths_from_manifest_field(root, raw): + if path.is_file(): + count += 1 + elif path.is_dir() and marker == 'SKILL.md': + count += _count_skill_dirs(path) + elif path.is_dir(): + count += len(list(path.glob(marker))) + return count + + +def _non_empty_json(path: Path) -> bool: + if not path.is_file(): + return False + try: + with open(path, encoding='utf-8') as f: + data = json.load(f) + return bool(data) + except (OSError, json.JSONDecodeError): + return False diff --git a/ms_agent/plugins/registry.py b/ms_agent/plugins/registry.py new file mode 100644 index 000000000..c747bc752 --- /dev/null +++ b/ms_agent/plugins/registry.py @@ -0,0 +1,80 @@ +"""Installed plugin index — disk records plus in-memory manifest cache.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal + +from ms_agent.plugins.config_manager import PluginConfigManager, PluginScope +from ms_agent.plugins.manifest import PluginManifest, PluginError, normalize_plugin_id +from ms_agent.plugins.types import PluginRecord + +PluginScopeArg = Literal['global', 'project', 'merged'] + + +class PluginRegistry: + """Facade over ``plugins.json`` with optional parsed manifest cache.""" + + def __init__( + self, + config_manager: PluginConfigManager | None = None, + *, + global_root: str | Path = '~/.ms_agent', + project_root: str | Path | None = None, + ) -> None: + self.global_root = Path(global_root).expanduser() + self.config_manager = config_manager or PluginConfigManager( + self.global_root, + project_root, + ) + self._manifest_cache: dict[str, PluginManifest] = {} + + def list_records(self, scope: PluginScopeArg = 'merged') -> list[PluginRecord]: + return self.config_manager.list(scope) # type: ignore[arg-type] + + def get_record( + self, + plugin_id: str, + scope: PluginScopeArg = 'merged', + ) -> PluginRecord | None: + return self.config_manager.get(plugin_id, scope) # type: ignore[arg-type] + + def is_installed(self, plugin_id: str, scope: PluginScopeArg = 'merged') -> bool: + return self.get_record(plugin_id, scope) is not None + + def get_manifest( + self, + plugin_id: str, + *, + scope: PluginScopeArg = 'merged', + use_cache: bool = True, + ) -> PluginManifest | None: + if use_cache and plugin_id in self._manifest_cache: + return self._manifest_cache[plugin_id] + record = self.get_record(plugin_id, scope) + if record is None or not record.path: + return None + try: + manifest = PluginManifest.parse(record.path, record=record) + except PluginError: + return None + self._manifest_cache[plugin_id] = manifest + return manifest + + def invalidate(self, plugin_id: str | None = None) -> None: + if plugin_id is None: + self._manifest_cache.clear() + return + self._manifest_cache.pop(plugin_id, None) + + def managed_plugin_paths(self, project_path: str | None = None) -> set[str]: + """Resolved install paths for deduplicating legacy ``config.plugins``.""" + records = self.config_manager.load_merged(project_path) + paths: set[str] = set() + for record in records: + if record.path: + paths.add(str(Path(record.path).expanduser().resolve())) + return paths + + def managed_plugin_ids(self, project_path: str | None = None) -> set[str]: + return {record.id for record in self.config_manager.load_merged(project_path)} diff --git a/ms_agent/plugins/runtime.py b/ms_agent/plugins/runtime.py new file mode 100644 index 000000000..f871ee20f --- /dev/null +++ b/ms_agent/plugins/runtime.py @@ -0,0 +1,635 @@ +from __future__ import annotations + +import asyncio +from copy import deepcopy +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from omegaconf import OmegaConf + +from ms_agent.plugins.config_manager import PluginConfigManager +from ms_agent.plugins.agents import PluginAgentRegistry +from ms_agent.plugins.installer import PluginInstaller +from ms_agent.plugins.loader import ( + PluginLoadContext, + PluginLoadResult, + PluginLoader, +) +from ms_agent.plugins.manifest import PluginManifest, PluginError +from ms_agent.plugins.registry import PluginRegistry +from ms_agent.plugins.types import PluginRecord, component_status_dict +from ms_agent.utils import get_logger + +logger = get_logger() +_MISSING = object() + + +@dataclass +class PluginRuntime: + config_manager: PluginConfigManager | None = None + registry: PluginRegistry | None = None + global_root: str | Path = '~/.ms_agent' + skill_runtime: Any | None = None + hook_runtime_factory: Any | None = None + mcp_runtime: Any | None = None + manifests: list[PluginManifest] = field(default_factory=list) + load_result: PluginLoadResult = field(default_factory=PluginLoadResult) + agent_registry: PluginAgentRegistry = field(default_factory=PluginAgentRegistry) + _applied_skill_paths: set[str] = field(default_factory=set, init=False) + _applied_mcp_names: set[str] = field(default_factory=set, init=False) + _applied_bin_paths: set[str] = field(default_factory=set, init=False) + _applied_settings_originals: dict[str, Any] = field(default_factory=dict, init=False) + _project_path: str | None = field(default=None, init=False) + _session_id: str = field(default='', init=False) + _config: Any | None = field(default=None, init=False) + _configured_plugin_ids: set[str] = field(default_factory=set, init=False) + _enabled_executors: frozenset[str] = field( + default_factory=lambda: frozenset({'command'}), + init=False, + ) + + def __post_init__(self) -> None: + self.global_root = Path(self.global_root).expanduser() + if self.config_manager is None: + self.config_manager = PluginConfigManager(self.global_root) + if self.registry is None: + self.registry = PluginRegistry(self.config_manager) + self._reload_lock = asyncio.Lock() + + async def start( + self, + project_path: str, + session_id: str, + *, + config: Any | None = None, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> None: + async with self._reload_lock: + self._start_unlocked( + project_path, + session_id, + config=config, + enabled_executors=enabled_executors, + ) + + def start_sync( + self, + project_path: str, + session_id: str, + *, + config: Any | None = None, + enabled_executors: frozenset[str] = frozenset({'command'}), + ) -> None: + self._start_unlocked( + project_path, + session_id, + config=config, + enabled_executors=enabled_executors, + ) + + def _start_unlocked( + self, + project_path: str, + session_id: str, + *, + config: Any | None, + enabled_executors: frozenset[str], + ) -> None: + records = self._records_from_config(config, project_path) + self.registry.invalidate() # type: ignore[union-attr] + self._project_path = project_path + self._session_id = session_id + self._config = config + self._enabled_executors = enabled_executors + self._configured_plugin_ids = {record.id for record in records} + manifests: list[PluginManifest] = [] + for record in records: + if not record.enabled: + continue + try: + manifests.append(PluginManifest.parse(record.path, record=record)) + except PluginError as exc: + logger.warning('Failed to parse plugin %s: %s', record.id, exc) + + ctx = PluginLoadContext( + project_path=project_path, + session_id=session_id, + enabled_executors=enabled_executors, + plugin_data_root=self.global_root / 'plugins' / 'data', + ) + self.manifests = manifests + self.load_result = PluginLoader.load_all(manifests, ctx) + self.agent_registry.rebuild(self.load_result.agent_defs) + if config is not None: + self.apply_to_config(config) + self._sync_skill_runtime(config) + + def apply_to_config(self, config: Any) -> None: + self._remove_applied_skill_sources(config) + self._remove_applied_mcp_servers(config) + self._remove_plugin_owned_mcp_servers(config, self._configured_plugin_ids) + self._remove_applied_bin_paths(config) + self._revert_applied_settings(config) + + if self.load_result.skill_sources: + if not hasattr(config, 'skills') or config.skills is None: + config.skills = OmegaConf.create({'sources': []}) + if not hasattr(config.skills, 'sources') or config.skills.sources is None: + config.skills.sources = [] + existing = { + _skill_source_path(source) + for source in config.skills.sources + } + for source in self.load_result.skill_sources: + if source.path in existing: + continue + config.skills.sources.append({ + 'type': source.type.value, + 'path': source.path, + 'enabled': source.enabled, + 'origin': source.origin, + 'plugin_id': source.plugin_id, + 'capability': source.capability, + }) + existing.add(str(source.path)) + self._applied_skill_paths.add(str(source.path)) + + if self.load_result.mcp_servers: + if not hasattr(config, 'tools') or config.tools is None: + config.tools = OmegaConf.create({}) + current = {} + if hasattr(config, '_merged_mcp') and config._merged_mcp: + current = OmegaConf.to_container( + config._merged_mcp, resolve=True) or {} + existing_names = set(config.tools.keys()) | set( + (current.get('servers') or {}).keys()) + plugin_servers = dedupe_mcp_server_names( + self.load_result.mcp_servers, + existing_names, + ) + self.load_result.mcp_servers = plugin_servers + for name, server in plugin_servers.items(): + config.tools[name] = OmegaConf.create(server) + + servers = dict(current.get('servers', {})) + servers.update(plugin_servers) + OmegaConf.update(config, '_merged_mcp', {'servers': servers}, merge=True) + self._applied_mcp_names = set(plugin_servers) + + for key, value in self.load_result.settings_patch.items(): + self._applied_settings_originals[key] = _snapshot_config_key(config, key) + OmegaConf.update(config, key, value, merge=True) + + if self.load_result.bin_paths: + if not hasattr(config, 'tools') or config.tools is None: + config.tools = OmegaConf.create({}) + if not hasattr(config.tools, 'code_executor') or config.tools.code_executor is None: + config.tools.code_executor = OmegaConf.create({}) + existing_bins = [] + if hasattr(config.tools.code_executor, 'plugin_bin_paths'): + existing_bins = [ + str(path) for path in config.tools.code_executor.plugin_bin_paths + ] + for path in self.load_result.bin_paths: + if str(path) not in existing_bins: + existing_bins.append(str(path)) + self._applied_bin_paths.add(str(path)) + config.tools.code_executor.plugin_bin_paths = existing_bins + + def _remove_applied_skill_sources(self, config: Any) -> None: + if not self._applied_skill_paths: + return + if not hasattr(config, 'skills') or not getattr(config.skills, 'sources', None): + self._applied_skill_paths.clear() + return + config.skills.sources = [ + source for source in config.skills.sources + if _skill_source_path(source) not in self._applied_skill_paths + ] + self._applied_skill_paths.clear() + + def _remove_applied_mcp_servers(self, config: Any) -> None: + if not self._applied_mcp_names: + return + self._remove_mcp_servers_by_name(config, self._applied_mcp_names) + self._applied_mcp_names.clear() + + def _remove_plugin_owned_mcp_servers( + self, + config: Any, + plugin_ids: set[str], + ) -> None: + if not plugin_ids: + return + names: set[str] = set() + if hasattr(config, 'tools') and config.tools is not None: + for name, server in config.tools.items(): + server_data = _to_plain_container(server) + if _is_plugin_server(server_data, plugin_ids): + names.add(str(name)) + if hasattr(config, '_merged_mcp') and config._merged_mcp: + current = OmegaConf.to_container(config._merged_mcp, resolve=True) or {} + for name, server in (current.get('servers') or {}).items(): + if _is_plugin_server(server, plugin_ids): + names.add(str(name)) + self._remove_mcp_servers_by_name(config, names) + + @staticmethod + def _remove_mcp_servers_by_name(config: Any, names: set[str]) -> None: + if not names: + return + if hasattr(config, 'tools') and config.tools is not None: + for name in names: + if name in config.tools: + del config.tools[name] + if hasattr(config, '_merged_mcp') and config._merged_mcp: + current = OmegaConf.to_container(config._merged_mcp, resolve=True) or {} + servers = dict(current.get('servers', {})) + for name in names: + servers.pop(name, None) + OmegaConf.update(config, '_merged_mcp', {'servers': servers}, merge=False) + + def _remove_applied_bin_paths(self, config: Any) -> None: + if not self._applied_bin_paths: + return + code_executor = getattr(getattr(config, 'tools', None), 'code_executor', None) + if code_executor is not None and hasattr(code_executor, 'plugin_bin_paths'): + code_executor.plugin_bin_paths = [ + path for path in code_executor.plugin_bin_paths + if str(path) not in self._applied_bin_paths + ] + self._applied_bin_paths.clear() + + def _revert_applied_settings(self, config: Any) -> None: + for key, original in self._applied_settings_originals.items(): + _restore_config_key(config, key, original) + self._applied_settings_originals.clear() + + def _sync_skill_runtime(self, config: Any) -> None: + if self.skill_runtime is None: + return + if not hasattr(config, 'skills') or not config.skills: + return + catalog = getattr(self.skill_runtime, 'catalog', None) + if catalog is None: + return + plugin_sources = list(self.load_result.skill_sources) + if plugin_sources and hasattr(catalog, 'reload_sources'): + catalog.reload_sources(plugin_sources) + else: + catalog.load_from_config(config.skills) + if hasattr(self.skill_runtime, '_version'): + self.skill_runtime._version += 1 + + def list_all(self) -> list[dict[str, Any]]: + loaded = {manifest.plugin_id: manifest for manifest in self.manifests} + items: list[dict[str, Any]] = [] + for record in self.config_manager.list('merged'): # type: ignore[union-attr] + manifest = loaded.get(record.id) + if manifest is None: + items.append({ + 'plugin_id': record.id, + 'enabled': record.enabled, + 'scope': record.scope, + 'path': record.path, + 'status': 'disabled' if not record.enabled else 'error', + 'capabilities': [], + 'capabilities_status': {}, + }) + continue + status = 'ready' if record.enabled else 'disabled' + items.append({ + 'plugin_id': manifest.plugin_id, + 'name': manifest.name, + 'version': manifest.version, + 'description': manifest.description, + 'enabled': record.enabled, + 'scope': record.scope, + 'path': str(manifest.root), + 'format': manifest.format.value, + 'capabilities': sorted(manifest.capabilities), + 'status': status, + 'capabilities_status': component_status_dict(manifest.components), + 'source': record.to_dict().get('source', {}), + 'installed_at': record.installed_at, + 'commands': [ + cmd.__dict__ for cmd in self.load_result.command_defs + if cmd.plugin_id == manifest.plugin_id + ], + 'agents': [ + agent for agent in self.agent_registry.list_all() + if agent['plugin_id'] == manifest.plugin_id + ], + 'agent_defs': [ + agent.__dict__ for agent in self.load_result.agent_defs + if agent.plugin_id == manifest.plugin_id + ], + 'bin_paths': [ + str(path) for path in self.load_result.bin_paths + if str(path).startswith(str(manifest.root)) + ], + 'settings_patch': self.load_result.settings_patch, + 'user_config_schema': (manifest.raw or {}).get('userConfig') or {}, + 'unsupported': [ + item.__dict__ for item in self.load_result.unsupported + if item.capability in manifest.components + ], + }) + return items + + async def toggle( + self, + plugin_id: str, + enabled: bool, + *, + scope: str = 'global', + project_path: str | None = None, + ) -> None: + self.config_manager.set_enabled( # type: ignore[union-attr] + plugin_id, + enabled, + scope=scope, # type: ignore[arg-type] + ) + reload_path = project_path or self._project_path + if reload_path is not None: + await self.reload( + plugin_id, + project_path=reload_path, + session_id=self._session_id, + config=self._config, + ) + + async def reload( + self, + plugin_id: str | None = None, + *, + project_path: str, + session_id: str = '', + config: Any | None = None, + ) -> None: + del plugin_id + await self.start( + project_path, + session_id or self._session_id, + config=config if config is not None else self._config, + enabled_executors=self._enabled_executors, + ) + + async def install( + self, + source: str, + *, + scope: str = 'global', + project_path: str | None = None, + **opts: Any, + ) -> PluginManifest: + installer = PluginInstaller( + config_manager=self.config_manager, + global_root=self.global_root, + project_root=project_path, + ) + return installer.install(source, scope=scope, project_path=project_path, **opts) + + def get_user_config(self, plugin_id: str) -> dict[str, Any]: + manifest = self._manifest_for_plugin(plugin_id) + schema = (manifest.raw or {}).get('userConfig') or {} + data_dir = self.global_root / 'plugins' / 'data' / plugin_id + from ms_agent.plugins.user_config import default_values, load_user_config + values = load_user_config(data_dir) + if not values and schema: + values = default_values(schema) + return { + 'plugin_id': plugin_id, + 'schema': schema, + 'values': values, + 'data_dir': str(data_dir), + } + + def set_user_config( + self, + plugin_id: str, + values: dict[str, Any], + ) -> dict[str, Any]: + manifest = self._manifest_for_plugin(plugin_id) + schema = (manifest.raw or {}).get('userConfig') or {} + if not schema: + raise ValueError(f'Plugin {plugin_id} has no userConfig schema') + data_dir = self.global_root / 'plugins' / 'data' / plugin_id + from ms_agent.plugins.user_config import save_user_config + saved = save_user_config(data_dir, schema, values) + if self._config is not None and self._project_path is not None: + self._start_unlocked( + self._project_path, + self._session_id, + config=self._config, + enabled_executors=self._enabled_executors, + ) + return {'plugin_id': plugin_id, 'values': saved} + + def _manifest_for_plugin(self, plugin_id: str) -> PluginManifest: + manifest = next( + (item for item in self.manifests if item.plugin_id == plugin_id), + None, + ) + if manifest is not None: + return manifest + record = self.config_manager.get(plugin_id, 'merged') # type: ignore[union-attr] + if record is None: + raise KeyError(f'Plugin not found: {plugin_id}') + return PluginManifest.parse(record.path, record=record) + + async def uninstall( + self, + plugin_id: str, + *, + scope: str = 'global', + purge: bool = False, + ) -> None: + record = self.config_manager.get( # type: ignore[union-attr] + plugin_id, + scope=scope, # type: ignore[arg-type] + ) + self.config_manager.remove( # type: ignore[union-attr] + plugin_id, + scope=scope, # type: ignore[arg-type] + ) + if purge and record is not None: + path = Path(record.path) + if not _is_managed_plugin_path( + path, + self.global_root, + self.config_manager.project_root if self.config_manager else None, + ): + raise ValueError(f'Refusing to purge unmanaged plugin path: {path}') + if path.is_symlink() or path.is_file(): + path.unlink(missing_ok=True) + elif path.is_dir(): + import shutil + shutil.rmtree(path) + + def _records_from_config( + self, + config: Any | None, + project_path: str, + ) -> list[PluginRecord]: + raw_records = [] + if config is not None and hasattr(config, '_merged_plugins'): + merged = OmegaConf.to_container(config._merged_plugins, resolve=True) + if isinstance(merged, dict): + raw_records = merged.get('plugins', []) + if raw_records: + records = [ + PluginRecord.from_dict(item, scope=item.get('scope')) + for item in raw_records + if isinstance(item, dict) + ] + return records + _legacy_plugin_records( + config, + records, + project_path=project_path, + global_root=self.global_root, + ) + + records = self.config_manager.load_merged(project_path) # type: ignore[union-attr] + if records: + return records + _legacy_plugin_records( + config, + records, + project_path=project_path, + global_root=self.global_root, + ) + + legacy_records = _legacy_plugin_records( + config, + [], + project_path=project_path, + global_root=self.global_root, + ) + if legacy_records: + return legacy_records + return [] + + +def _skill_source_path(source: Any) -> str: + if isinstance(source, str): + return source + if isinstance(source, dict): + return str(source.get('path', '')) + return str(getattr(source, 'path', '')) + + +def _legacy_plugin_records( + config: Any | None, + existing: list[PluginRecord], + *, + project_path: str | None = None, + global_root: Path | None = None, +) -> list[PluginRecord]: + if config is None or not hasattr(config, 'plugins') or not config.plugins: + return [] + registry = PluginRegistry( + PluginConfigManager(global_root or Path('~/.ms_agent').expanduser()), + ) + managed_paths = registry.managed_plugin_paths(project_path) + managed_ids = registry.managed_plugin_ids(project_path) + existing_paths = {str(Path(record.path).expanduser().resolve()) + for record in existing if record.path} + existing_paths |= managed_paths + records: list[PluginRecord] = [] + for raw_path in config.plugins: + path = Path(str(raw_path)).expanduser().resolve() + if str(path) in existing_paths: + continue + if path.name in managed_ids: + continue + records.append( + PluginRecord( + id=path.name, + path=str(path), + enabled=True, + source={'type': 'local', 'uri': str(raw_path)}, + )) + return records + + +def _is_managed_plugin_path( + path: Path, + global_root: Path, + project_root: Path | None = None, +) -> bool: + resolved_parent = path.expanduser().resolve().parent + allowed_roots = [ + (global_root / 'plugins').expanduser().resolve(), + ] + if project_root is not None: + allowed_roots.append( + (project_root / '.ms-agent' / 'plugins').expanduser().resolve()) + for root in allowed_roots: + try: + resolved_parent.relative_to(root) + return True + except ValueError: + continue + # Symlink paths resolve to the source, so also allow lexical storage paths. + raw_parent = path.expanduser().absolute().parent + for root in allowed_roots: + try: + raw_parent.relative_to(root) + return True + except ValueError: + continue + return False + + +def dedupe_mcp_server_names( + plugin_servers: dict[str, dict[str, Any]], + existing_names: set[str], +) -> dict[str, dict[str, Any]]: + result: dict[str, dict[str, Any]] = {} + used = set(existing_names) + for name, server in plugin_servers.items(): + candidate = name + if candidate in used: + plugin_id = server.get('plugin_id') + if plugin_id: + candidate = f'plugin.{plugin_id}.{name}' + base = candidate + suffix = 1 + while candidate in used or candidate in result: + candidate = f'{base}.{suffix}' + suffix += 1 + result[candidate] = server + used.add(candidate) + return result + + +def _snapshot_config_key(config: Any, key: str) -> Any: + if not hasattr(config, key): + return _MISSING + value = getattr(config, key) + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=False) + return deepcopy(value) + + +def _restore_config_key(config: Any, key: str, value: Any) -> None: + if value is _MISSING: + if key in config: + del config[key] + return + OmegaConf.update(config, key, value, merge=False) + + +def _to_plain_container(value: Any) -> Any: + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=True) + return value + + +def _is_plugin_server(server: Any, plugin_ids: set[str]) -> bool: + return ( + isinstance(server, dict) + and server.get('source') == 'plugin' + and server.get('plugin_id') in plugin_ids + ) diff --git a/ms_agent/plugins/types.py b/ms_agent/plugins/types.py new file mode 100644 index 000000000..48bebcda7 --- /dev/null +++ b/ms_agent/plugins/types.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any + + +class PluginFormat(str, Enum): + MS_AGENT = 'ms-agent' + CLAUDE = 'claude' + CODEX = 'codex' + CURSOR = 'cursor' + OPENCLAW = 'openclaw' + HERMES = 'hermes' + GENERIC = 'generic' + MIXED = 'mixed' + + +LOADABLE_CAPABILITIES = frozenset({ + 'skills', + 'commands', + 'agents', + 'hooks', + 'mcp', + 'settings', + 'bin', + 'user_config', +}) + +CAPABILITY_STATUS_KEYS = ( + 'skills', + 'commands', + 'agents', + 'hooks', + 'mcp', + 'settings', + 'bin', + 'user_config', + 'assets', + 'apps', + 'rules', + 'lsp', + 'output_styles', + 'themes', + 'monitors', + 'channels', + 'hooks_openclaw_internal', + 'hooks_hermes_python', +) + + +@dataclass(frozen=True) +class ComponentScan: + status: str + count: int = 0 + path: str | None = None + hint: str | None = None + + def to_dict(self) -> dict[str, Any]: + return { + k: v for k, v in asdict(self).items() + if v not in (None, 0) or k in {'status', 'count'} + } + + +@dataclass(frozen=True) +class InstallSource: + type: str = 'local' + uri: str | None = None + resolved_sha: str | None = None + + @classmethod + def from_raw(cls, raw: Any) -> 'InstallSource': + if isinstance(raw, InstallSource): + return raw + if isinstance(raw, dict): + return cls( + type=str(raw.get('type', 'local')), + uri=raw.get('uri'), + resolved_sha=raw.get('resolved_sha'), + ) + if isinstance(raw, str): + return cls(type='local', uri=raw) + return cls() + + def to_dict(self) -> dict[str, Any]: + return { + k: v for k, v in asdict(self).items() + if v is not None + } + + +@dataclass +class PluginRecord: + id: str + path: str + enabled: bool = True + managed_by: str = 'ms-agent' + format: str | PluginFormat | None = None + manifest_path: str | None = None + source: InstallSource | dict[str, Any] | str | None = None + installed_at: str | None = None + scope: str | None = None + + @classmethod + def from_dict(cls, raw: dict[str, Any], *, scope: str | None = None) -> 'PluginRecord': + return cls( + id=str(raw.get('id') or raw.get('plugin_id') or raw.get('name')), + enabled=bool(raw.get('enabled', True)), + managed_by=str(raw.get('managed_by', 'ms-agent')), + format=raw.get('format'), + manifest_path=raw.get('manifest_path'), + source=InstallSource.from_raw(raw.get('source')), + path=str(raw.get('path') or ''), + installed_at=raw.get('installed_at'), + scope=scope or raw.get('scope'), + ) + + def to_dict(self) -> dict[str, Any]: + fmt = self.format.value if isinstance(self.format, PluginFormat) else self.format + source = InstallSource.from_raw(self.source).to_dict() + data = { + 'id': self.id, + 'enabled': self.enabled, + 'managed_by': self.managed_by, + 'format': fmt, + 'manifest_path': self.manifest_path, + 'source': source, + 'path': self.path, + 'installed_at': self.installed_at, + } + return {k: v for k, v in data.items() if v not in (None, {}, '')} + + +@dataclass(frozen=True) +class CommandDef: + plugin_id: str + name: str + path: str + description: str | None = None + argument_hint: str | None = None + + +@dataclass(frozen=True) +class AgentDef: + plugin_id: str + name: str + path: str + description: str | None = None + model: str | None = None + tools: tuple[str, ...] = () + skills: tuple[str, ...] = () + disallowed_tools: tuple[str, ...] = () + + +@dataclass(frozen=True) +class UnsupportedCapability: + capability: str + status: str = 'unsupported' + hint: str | None = None + + +def component_status_dict( + components: dict[str, ComponentScan], +) -> dict[str, dict[str, Any]]: + return { + key: components.get(key, ComponentScan(status='skipped')).to_dict() + for key in CAPABILITY_STATUS_KEYS + } diff --git a/ms_agent/plugins/user_config.py b/ms_agent/plugins/user_config.py new file mode 100644 index 000000000..de3d0bec7 --- /dev/null +++ b/ms_agent/plugins/user_config.py @@ -0,0 +1,105 @@ +"""userConfig schema validation and persistence for plugin data dirs.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +_ALLOWED_TYPES = frozenset({'string', 'boolean', 'number', 'integer', 'array', 'object'}) + + +class UserConfigError(ValueError): + """Raised when userConfig schema or values are invalid.""" + + +def validate_schema(schema: dict[str, Any]) -> list[str]: + """Validate manifest ``userConfig`` field definitions.""" + errors: list[str] = [] + if not isinstance(schema, dict): + return ['userConfig must be an object'] + for key, field in schema.items(): + if not isinstance(key, str) or not key.strip(): + errors.append('userConfig keys must be non-empty strings') + continue + if not isinstance(field, dict): + errors.append(f'userConfig.{key} must be an object') + continue + field_type = field.get('type') + if field_type not in _ALLOWED_TYPES: + errors.append( + f'userConfig.{key}.type must be one of {sorted(_ALLOWED_TYPES)}') + return errors + + +def validate_values( + schema: dict[str, Any], + values: dict[str, Any], +) -> list[str]: + """Validate submitted config values against a userConfig schema.""" + errors = validate_schema(schema) + if errors: + return errors + if not isinstance(values, dict): + return ['config values must be an object'] + for key, field in schema.items(): + if key not in values: + if field.get('required'): + errors.append(f'Missing required userConfig field: {key}') + continue + value = values[key] + field_type = field.get('type') + if field_type == 'string' and not isinstance(value, str): + errors.append(f'userConfig.{key} must be a string') + elif field_type == 'boolean' and not isinstance(value, bool): + errors.append(f'userConfig.{key} must be a boolean') + elif field_type in {'number', 'integer'} and not isinstance(value, (int, float)): + errors.append(f'userConfig.{key} must be a number') + elif field_type == 'array' and not isinstance(value, list): + errors.append(f'userConfig.{key} must be an array') + elif field_type == 'object' and not isinstance(value, dict): + errors.append(f'userConfig.{key} must be an object') + for key in values: + if key not in schema: + errors.append(f'Unknown userConfig field: {key}') + return errors + + +def load_user_config(data_dir: str | Path) -> dict[str, Any]: + path = Path(data_dir) / 'config.json' + if not path.is_file(): + return {} + try: + with open(path, encoding='utf-8') as f: + data = json.load(f) + except (OSError, json.JSONDecodeError): + return {} + return data if isinstance(data, dict) else {} + + +def save_user_config( + data_dir: str | Path, + schema: dict[str, Any], + values: dict[str, Any], +) -> dict[str, Any]: + errors = validate_values(schema, values) + if errors: + raise UserConfigError('; '.join(errors)) + path = Path(data_dir) + path.mkdir(parents=True, exist_ok=True) + config_path = path / 'config.json' + tmp = config_path.with_suffix('.tmp') + with open(tmp, 'w', encoding='utf-8') as f: + json.dump(values, f, indent=2, ensure_ascii=False) + tmp.rename(config_path) + return values + + +def default_values(schema: dict[str, Any]) -> dict[str, Any]: + values: dict[str, Any] = {} + for key, field in schema.items(): + if not isinstance(field, dict): + continue + if 'default' in field: + values[key] = field['default'] + return values diff --git a/ms_agent/skill/catalog.py b/ms_agent/skill/catalog.py index e1ed211db..1ceacad13 100644 --- a/ms_agent/skill/catalog.py +++ b/ms_agent/skill/catalog.py @@ -134,6 +134,9 @@ def load_from_config(self, skills_config) -> None: revision=getattr(src_cfg, "revision", None), subdir=getattr(src_cfg, "subdir", None), enabled=getattr(src_cfg, "enabled", True), + origin=getattr(src_cfg, "origin", "config"), + plugin_id=getattr(src_cfg, "plugin_id", None), + capability=getattr(src_cfg, "capability", None), )) # 3b. Simple path list (backward compat) elif hasattr(skills_config, "path") and skills_config.path: @@ -172,12 +175,21 @@ def load_from_sources(self, sources: List[SkillSource]) -> None: try: skills = self._materialize_and_load(source) for skill in skills.values(): - self._register_skill(skill) + self._register_skill(skill, source) except Exception as e: logger.warning(f"Failed to load skill source {source}: {e}") def _materialize_and_load( self, source: SkillSource) -> Dict[str, SkillSchema]: + if ( + source.capability == 'commands' + and source.path + and str(source.path).endswith('.md') + ): + return self._loader.load_command_markdown( + source.path, + plugin_id=source.plugin_id, + ) if source.type == SkillSourceType.LOCAL_DIR: return self._loader.load_skills(source.path) elif source.type == SkillSourceType.MODELSCOPE: @@ -236,6 +248,8 @@ def _init_safety(self, config) -> None: @staticmethod def _infer_trust_level(skill: SkillSchema, source=None) -> str: """Determine trust level from the skill's source path.""" + if source is not None and getattr(source, 'origin', None) == 'plugin': + return 'plugin' skill_path_str = str(skill.skill_path) builtin_str = str(BUILTIN_SKILLS_DIR) user_str = str(USER_SKILLS_DIR) @@ -246,12 +260,16 @@ def _infer_trust_level(skill: SkillSchema, source=None) -> str: return 'local' return 'community' - def _register_skill(self, skill: SkillSchema) -> None: + def _register_skill(self, skill: SkillSchema, source=None) -> None: """Register a skill; later registrations override earlier ones. Runs safety scanning (when enabled) and applies trust policy. """ - skill._trust_level = self._infer_trust_level(skill) + skill._trust_level = self._infer_trust_level(skill, source) + if source is not None: + skill._origin = getattr(source, 'origin', 'config') + skill._plugin_id = getattr(source, 'plugin_id', None) + skill._capability = getattr(source, 'capability', None) if self._safety_scanner: try: @@ -304,6 +322,45 @@ def get_skill(self, skill_id: str) -> Optional[SkillSchema]: # Hot reload # ------------------------------------------------------------------ # + def reload_sources(self, sources: List[SkillSource]) -> None: + """Reload only skills contributed by the given sources.""" + if not sources: + return + target_paths = { + str(Path(source.path).expanduser().resolve()) + for source in sources + if source.path + } + target_keys = { + (source.plugin_id, source.capability) + for source in sources + if source.plugin_id + } + remove_ids: List[str] = [] + for sid, skill in self._skills.items(): + plugin_id = getattr(skill, '_plugin_id', None) + capability = getattr(skill, '_capability', None) + if plugin_id and (plugin_id, capability) in target_keys: + remove_ids.append(sid) + continue + for file_info in skill.files: + file_path = str(Path(file_info.path).expanduser().resolve()) + if file_path in target_paths: + remove_ids.append(sid) + break + for sid in remove_ids: + self._skills.pop(sid, None) + for source in sources: + if not source.enabled: + continue + try: + skills = self._materialize_and_load(source) + for skill in skills.values(): + self._register_skill(skill, source) + except Exception as e: + logger.warning(f'Failed to reload skill source {source}: {e}') + self._invalidate_cache() + def reload(self) -> None: self._skills.clear() self.load_from_sources(self._sources) diff --git a/ms_agent/skill/loader.py b/ms_agent/skill/loader.py index 763128366..fcadb37ff 100644 --- a/ms_agent/skill/loader.py +++ b/ms_agent/skill/loader.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os +import re from pathlib import Path from typing import Dict, List, Optional, Union @@ -190,6 +191,48 @@ def get_all_skills(self) -> Dict[str, SkillSchema]: """ return self.loaded_skills.copy() + def load_command_markdown( + self, + command_path: str | Path, + *, + plugin_id: str | None = None, + ) -> Dict[str, SkillSchema]: + """Load a plugin command ``*.md`` file as a virtual skill entry.""" + from .schema import SkillFile, SkillSchema + + path = Path(command_path) + if not path.is_file(): + return {} + try: + content = path.read_text(encoding='utf-8') + except OSError: + return {} + frontmatter = self.parser.parse_yaml_frontmatter(content) or {} + name = str(frontmatter.get('name') or path.stem) + description = str( + frontmatter.get('description') or f'Plugin command {name}') + skill_id = ( + f'{plugin_id}:{name}' if plugin_id else f'command:{name}') + body_text = re.sub( + r'^---\s*\n.*?\n---\s*\n', + '', + content, + count=1, + flags=re.DOTALL, + ).strip() + skill = SkillSchema( + skill_id=skill_id, + name=name, + description=description, + content=body_text, + files=[SkillFile(name='SKILL.md', type='.md', path=path)], + skill_path=path.parent, + version='latest', + tags=['plugin-command'], + ) + key = self._get_skill_key(skill=skill) + return {key: skill} + def reload_skill(self, skill_path: str) -> Optional[SkillSchema]: """ Reload a skill from its directory. diff --git a/ms_agent/skill/runtime.py b/ms_agent/skill/runtime.py index 39c6423f3..65c285d2e 100644 --- a/ms_agent/skill/runtime.py +++ b/ms_agent/skill/runtime.py @@ -92,6 +92,9 @@ def list_all(self) -> List[Dict[str, Any]]: 'tags': skill.tags, 'has_scripts': bool(skill.scripts), 'version': skill.version, + 'origin': getattr(skill, '_origin', 'config'), + 'plugin_id': getattr(skill, '_plugin_id', None), + 'capability': getattr(skill, '_capability', None), }) return result diff --git a/ms_agent/skill/sources.py b/ms_agent/skill/sources.py index 23e9de67a..5ade485d1 100644 --- a/ms_agent/skill/sources.py +++ b/ms_agent/skill/sources.py @@ -22,6 +22,9 @@ class SkillSource: revision: Optional[str] = None subdir: Optional[str] = None enabled: bool = True + origin: str = "config" + plugin_id: Optional[str] = None + capability: Optional[str] = None _MODELSCOPE_URI_RE = re.compile( diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 926bfdb01..51fbfd121 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -215,6 +215,8 @@ def __init__(self, config: DictConfig, **kwargs): self._active_sync_tasks: Dict[str, Any] = {} # effective_call_id -> stream file path (set during _run_agent, consumed by call_tool) self._stream_paths: Dict[str, str] = {} + self._plugin_agent_registry = None + self._plugin_spec_keys: set[str] = set() self._load_specs() @property @@ -476,6 +478,47 @@ def _emit_chunk_event(self, event_type: str, data: Dict[str, Any]) -> None: def set_task_manager(self, task_manager) -> None: self._task_manager = task_manager + def sync_plugin_agents(self, registry) -> None: + """Register plugin-defined subagents and the Claude-compatible Task tool.""" + from ms_agent.plugins.agents import AgentDelegate + + self._plugin_agent_registry = registry + for key in self._plugin_spec_keys: + self._specs.pop(key, None) + self._plugin_spec_keys.clear() + + if registry is None or not registry.has_agents(): + self._build_server_index() + return + + for entry in { + item['namespaced_name']: registry.resolve(item['namespaced_name']) + for item in registry.list_all() + }.values(): + if entry is None: + continue + spec = AgentDelegate.to_agent_tool_spec( + entry, + self.config, + trust_remote_code=self._trust_remote_code, + ) + self._specs[spec.tool_name] = spec + self._plugin_spec_keys.add(spec.tool_name) + namespaced_tool = entry.namespaced_name.replace(':', '---') + if namespaced_tool != spec.tool_name: + from dataclasses import replace + namespaced_spec = replace(spec, tool_name=namespaced_tool) + self._specs[namespaced_spec.tool_name] = namespaced_spec + self._plugin_spec_keys.add(namespaced_spec.tool_name) + + task_spec = AgentDelegate.build_task_tool_spec( + registry, + trust_remote_code=self._trust_remote_code, + ) + self._specs[task_spec.tool_name] = task_spec + self._plugin_spec_keys.add(task_spec.tool_name) + self._build_server_index() + # ── stream-file helpers ──────────────────────────────────────────────── def _stream_file_enabled(self) -> bool: @@ -747,6 +790,9 @@ async def _watcher(): async def _call_dynamic(self, tool_args: dict, spec: '_AgentToolSpec') -> str: + if spec.tool_name == 'Task' and self._plugin_agent_registry is not None: + return await self._call_plugin_task(tool_args, spec) + tasks = tool_args.get('tasks', []) execution_mode = tool_args.get('execution_mode', 'sequential') @@ -817,6 +863,48 @@ async def _run_one(i: int, task: dict) -> str: formatted += f'SubTask{i}:{content}\n' return formatted + async def _call_plugin_task(self, tool_args: dict, + spec: '_AgentToolSpec') -> str: + from ms_agent.plugins.agents import AgentDelegate + + registry = self._plugin_agent_registry + if registry is None or not registry.has_agents(): + return json.dumps({ + 'error': 'No plugin subagents are registered.', + }, ensure_ascii=False) + + entry = AgentDelegate.resolve_task_entry(registry, tool_args) + if entry is None: + agent_name = AgentDelegate.resolve_task_agent_name(tool_args) + available = ', '.join( + item['namespaced_name'] for item in registry.list_all()) + return json.dumps({ + 'error': ( + f'Unknown plugin subagent {agent_name!r}. ' + f'Available: {available}' + ), + }, ensure_ascii=False) + + delegate_spec = AgentDelegate.to_agent_tool_spec( + entry, + self.config, + trust_remote_code=self._trust_remote_code, + ) + prompt = ( + tool_args.get('prompt') + or tool_args.get('request') + or tool_args.get('description') + or '' + ) + if not isinstance(prompt, str): + prompt = json.dumps(prompt, ensure_ascii=False) + + use_subprocess = ( + delegate_spec.run_in_thread and delegate_spec.run_in_process) + agent = None if use_subprocess else self._build_agent(delegate_spec) + messages = await self._run_agent(agent, prompt, delegate_spec) + return self._format_output(messages, delegate_spec) + @staticmethod def _terminate_process(proc: Optional[mp.Process], *, reason: str) -> None: if proc is None: @@ -1140,7 +1228,8 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: if not self._enable_stats: result = await runner() - self._save_transcript(result, runtime_agent_tag) + if not spec.run_in_process: + self._save_transcript(result, runtime_agent_tag) if _writer is not None: # Store with the same key used by call_tool() to pop it. store_key = call_id if call_id is not None else _effective_call_id @@ -1153,7 +1242,8 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: result = None try: result = await runner() - self._save_transcript(result, runtime_agent_tag) + if not spec.run_in_process: + self._save_transcript(result, runtime_agent_tag) if _writer is not None: store_key = call_id if call_id is not None else _effective_call_id self._stream_paths[store_key] = _writer.stream_path diff --git a/ms_agent/tools/code/local_code_executor.py b/ms_agent/tools/code/local_code_executor.py index a95baaf96..54ed4fc21 100644 --- a/ms_agent/tools/code/local_code_executor.py +++ b/ms_agent/tools/code/local_code_executor.py @@ -15,10 +15,8 @@ from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger from ms_agent.utils.artifact_manager import ArtifactManager -from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from ms_agent.utils.utils import install_package -from ms_agent.utils.workspace_policy import (WorkspacePolicyError, - WorkspacePolicyKernel) +from ms_agent.utils.workspace_context import WorkspaceContext logger = get_logger() @@ -234,9 +232,8 @@ class LocalCodeExecutionTool(ToolBase): def __init__(self, config): super().__init__(config) - self.output_dir = Path( - getattr(config, 'output_dir', - DEFAULT_OUTPUT_DIR)).expanduser().resolve() + self._ws = WorkspaceContext.from_config(config) + self.output_dir = self._ws.root self.output_dir.mkdir(parents=True, exist_ok=True) self.tool_config = getattr( @@ -258,36 +255,13 @@ def __init__(self, config): self._task_manager = None self._watcher_tasks: Set[asyncio.Task] = set() - wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) - extra_allow: List[str] = [] - deny_globs = None - if wp is not None: - extra_allow = list(getattr(wp, 'allow_roots', []) or []) - dg = getattr(wp, 'deny_globs', None) - if dg: - deny_globs = list(dg) shell_cfg = getattr(self.tool_config, 'shell', None) if self.tool_config else None - shell_mode = getattr( - shell_cfg, 'default_mode', - 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', - False)) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', - 8192)) if shell_cfg else 8192 - self._policy = WorkspacePolicyKernel( - self.output_dir, - extra_allow_roots=extra_allow, - deny_globs=deny_globs, - shell_default_mode=str(shell_mode), - shell_network_enabled=net, - max_command_chars=max_cmd, - ) max_kb = 256 if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): max_kb = int(shell_cfg.max_output_kb) self._artifacts = ArtifactManager( - self.output_dir, max_combined_bytes=max_kb * 1024) + self._ws.root, max_combined_bytes=max_kb * 1024) self.exclude_func( getattr(getattr(config, 'tools', None), 'code_executor', None)) @@ -374,6 +348,12 @@ def _build_env(self, field: str, inherit: bool = False) -> Dict[str, str]: if value is None: continue env[key] = str(value) + plugin_bins = getattr(self.tool_config, 'plugin_bin_paths', + None) if self.tool_config else None + if plugin_bins: + paths = [str(path) for path in plugin_bins if path] + if paths: + env['PATH'] = os.pathsep.join(paths + [env.get('PATH', '')]) return env async def connect(self) -> None: @@ -718,18 +698,6 @@ async def shell_executor(self, exec_timeout = timeout or self._shell_timeout call_id = call_id or f'shell-{os.urandom(4).hex()}' - try: - self._policy.assert_shell_command_allowed(command) - except WorkspacePolicyError as e: - return json.dumps( - { - 'success': False, - 'error': str(e) - }, - ensure_ascii=False, - indent=2, - ) - shell_cmd = self._prepare_shell_command(command) if run_in_background: @@ -749,7 +717,7 @@ async def shell_executor(self, shell_cmd, stdout=ai_subprocess.PIPE, stderr=ai_subprocess.PIPE, - cwd=str(self._policy.workspace_root), + cwd=str(self._ws.root), env=self.shell_env, ) except FileNotFoundError as exc: @@ -825,7 +793,7 @@ async def _watcher() -> None: shell_cmd, stdout=ai_subprocess.PIPE, stderr=ai_subprocess.PIPE, - cwd=str(self._policy.workspace_root), + cwd=str(self._ws.root), env=self.shell_env, ) except FileNotFoundError as exc: diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index b8a81e465..b4de986eb 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -15,9 +15,8 @@ from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger from ms_agent.utils.artifact_manager import ArtifactManager -from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_OUTPUT_DIR -from ms_agent.utils.workspace_policy import (WorkspacePolicyError, - WorkspacePolicyKernel) +from ms_agent.utils.constants import DEFAULT_INDEX_DIR +from ms_agent.utils.workspace_context import WorkspaceContext logger = get_logger() @@ -109,7 +108,8 @@ def __init__(self, config, **kwargs): self.exclude_functions = [ _FS_TOOL_ALIASES.get(n, n) for n in self.exclude_functions ] - self.output_dir = getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR) + self._ws = WorkspaceContext.from_config(config) + self.output_dir = str(self._ws.root) self.trust_remote_code = kwargs.get('trust_remote_code', False) self.allow_read_all_files = getattr( getattr(config.tools, 'file_system', {}), 'allow_read_all_files', @@ -133,38 +133,18 @@ def __init__(self, config, **kwargs): self._glob_max_files = int( getattr(fs_cfg, 'glob_max_files', 100) or 100) - wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) - extra = list(getattr(wp, 'allow_roots', []) or []) if wp else [] - deny = list(getattr(wp, 'deny_globs', []) or []) if wp else [] - - shell_cfg = getattr( - getattr(config.tools, 'code_executor', None), 'shell', None) - shell_mode = getattr( - shell_cfg, 'default_mode', - 'workspace_write') if shell_cfg else 'workspace_write' - net = bool(getattr(shell_cfg, 'network_enabled', - False)) if shell_cfg else False - max_cmd = int(getattr(shell_cfg, 'max_command_chars', - 8192)) if shell_cfg else 8192 - - _out_p = Path(self.output_dir).expanduser().resolve() try: - _out_p.mkdir(parents=True, exist_ok=True) + self._ws.root.mkdir(parents=True, exist_ok=True) except OSError: pass - self._fs_policy = WorkspacePolicyKernel( - _out_p, - extra_allow_roots=extra, - deny_globs=deny if deny else None, - shell_default_mode=str(shell_mode), - shell_network_enabled=net, - max_command_chars=max_cmd, - ) + + shell_cfg = getattr( + getattr(config.tools, 'code_executor', None), 'shell', None) max_kb = 256 if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): max_kb = int(shell_cfg.max_output_kb) self._fs_artifacts = ArtifactManager( - _out_p, max_combined_bytes=max_kb * 1024) + self._ws.root, max_combined_bytes=max_kb * 1024) async def connect(self): logger.warning_once( @@ -419,10 +399,8 @@ async def grep( head_limit if head_limit is not None else self._default_grep_head) offset = offset or 0 path = path or '.' - try: - root = self._fs_policy.resolve_under_roots(path) - except WorkspacePolicyError as e: - return json.dumps({'success': False, 'error': str(e)}, indent=2) + raw = Path(path).expanduser() + root = raw.resolve() if raw.is_absolute() else (self._ws.root / raw).resolve() if pattern is None or (isinstance(pattern, str) and not pattern.strip()): @@ -516,7 +494,7 @@ async def _grep_rg_file( *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=str(self._fs_policy.workspace_root), + cwd=str(self._ws.root), ) out_b, err_b = await asyncio.wait_for( proc.communicate(), timeout=self._grep_timeout) @@ -553,7 +531,7 @@ async def _grep_rg_dir( *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=str(self._fs_policy.workspace_root), + cwd=str(self._ws.root), ) out_b, err_b = await asyncio.wait_for( proc.communicate(), timeout=self._grep_timeout) @@ -599,15 +577,15 @@ def consider_file(fp: Path) -> bool: if root.is_file(): files = [root] else: - for fp in _walk_files_limited(root, self._fs_policy.deny_globs, + for fp in _walk_files_limited(root, self._ws.deny_globs, 50_000): if consider_file(fp): files.append(fp) for fp in files: rel = str(fp.relative_to( - self._fs_policy.workspace_root)) if _is_relative( - fp, self._fs_policy.workspace_root) else str(fp) + self._ws.root)) if _is_relative( + fp, self._ws.root) else str(fp) try: if output_mode == 'files_with_matches': with fp.open(encoding='utf-8', errors='replace') as f: @@ -640,10 +618,8 @@ def consider_file(fp: Path) -> bool: async def glob(self, pattern: str, path: str = '') -> str: call_id = f'glob-{pattern[:40]}' - try: - base = self._fs_policy.resolve_under_roots(path or '.') - except WorkspacePolicyError as e: - return json.dumps({'success': False, 'error': str(e)}, indent=2) + raw = Path(path or '.').expanduser() + base = raw.resolve() if raw.is_absolute() else (self._ws.root / raw).resolve() if not base.is_dir(): return json.dumps( @@ -656,20 +632,18 @@ async def glob(self, pattern: str, path: str = '') -> str: matches: List[str] = [] truncated = False - deny = self._fs_policy.deny_globs + deny = self._ws.deny_globs try: for p in sorted(base.glob(pattern)): if not p.is_file(): continue rp = p.resolve() - if not self._fs_policy.path_is_allowed(rp): - continue if _is_denied_path(rp, base, deny): continue rel = str(p.relative_to( - self._fs_policy.workspace_root)) if _is_relative( - p, self._fs_policy.workspace_root) else str(p) + self._ws.root)) if _is_relative( + p, self._ws.root) else str(p) matches.append(rel) if len(matches) >= self._glob_max_files: truncated = True diff --git a/ms_agent/tools/mcp_client.py b/ms_agent/tools/mcp_client.py index 7fbfe148c..0bd6e95a3 100644 --- a/ms_agent/tools/mcp_client.py +++ b/ms_agent/tools/mcp_client.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import copy import os from contextlib import AsyncExitStack from datetime import timedelta @@ -7,7 +8,7 @@ from mcp.client.stdio import stdio_client from omegaconf import DictConfig from types import TracebackType -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, List, Literal, Optional from ms_agent.config import Config from ms_agent.config.env import Env @@ -47,6 +48,7 @@ def __init__( ): super().__init__(config) self.sessions: Dict[str, ClientSession] = {} + self._server_stacks: Dict[str, AsyncExitStack] = {} self.exit_stack = AsyncExitStack() self.mcp_config: Dict[str, Dict[str, Any]] = {'mcpServers': {}} if config is not None: @@ -89,40 +91,52 @@ async def call_tool(self, server_name: str, tool_name: str, return '\n\n'.join(texts) + def _filter_session_tools( + self, + server_name: str, + response: ListToolsResult, + ) -> List[Tool]: + exclude: list[str] = [] + include: list[str] = [] + if self.include_functions and server_name in self.include_functions: + include = self.include_functions[server_name] + elif self.exclude_functions and server_name in self.exclude_functions: + exclude = self.exclude_functions[server_name] + session_tools = [t for t in response.tools if t.name not in exclude] + if include: + session_tools = [t for t in session_tools if t.name in include] + return [ + Tool( + tool_name=t.name, + server_name=server_name, + description=t.description, + parameters=t.inputSchema, + ) + for t in session_tools + ] + + async def get_tools_for_server(self, server_name: str) -> List[Tool]: + """List tools for a single connected server (failures are isolated).""" + session = self.sessions.get(server_name) + if session is None: + return [] + try: + response = await session.list_tools() + except Exception as e: + new_eg = enhance_error( + e, f'MCP `{server_name}` list tool failed, details: ') + raise new_eg from e + return self._filter_session_tools(server_name, response) + async def get_tools(self) -> Dict: - tools = {} - for key, session in self.sessions.items(): - tools[key] = [] + tools: Dict[str, List[Tool]] = {} + for key in self.sessions: try: - response = await session.list_tools() + tools[key] = await self.get_tools_for_server(key) except Exception as e: - new_eg = enhance_error( - e, f'MCP `{key}` list tool failed, details: ') - raise new_eg from e - _session_tools = response.tools - exclude = [] - include = [] - if self.include_functions: - if key in self.include_functions: - include = self.include_functions[key] - elif self.exclude_functions: - if key in self.exclude_functions: - exclude = self.exclude_functions[key] - _session_tools = [ - t for t in _session_tools if t.name not in exclude - ] - if include: - _session_tools = [ - t for t in _session_tools if t.name in include - ] - _session_tools = [ - Tool( - tool_name=t.name, - server_name=key, - description=t.description, - parameters=t.inputSchema) for t in _session_tools - ] - tools[key].extend(_session_tools) + logger.warning( + 'Skipping MCP server %s in get_tools: %s', key, e) + tools[key] = [] return tools @staticmethod @@ -140,11 +154,67 @@ def print_tools(server_name: str, tools: ListToolsResult): logger.info(f'\nConnected to server "{server_name}" ' f'with tools: \n{sep.join(tools)}.') + @staticmethod + def resolve_server_env(server: Dict[str, Any]) -> Dict[str, str]: + envs = Env.load_env() + env_dict = copy.deepcopy(server.get('env') or {}) + return { + key: value if value else envs.get(key, '') + for key, value in env_dict.items() + } + + def list_connected_servers(self) -> list[str]: + return list(self.sessions.keys()) + + def is_connected(self, server_name: str) -> bool: + return server_name in self.sessions + + async def disconnect_server(self, server_name: str) -> None: + """Disconnect a single MCP server.""" + stack = self._server_stacks.pop(server_name, None) + self.sessions.pop(server_name, None) + self.exclude_functions.pop(server_name, None) + self.include_functions.pop(server_name, None) + if stack is not None: + await stack.aclose() + + async def connect_single_server( + self, + server_name: str, + server_config: Dict[str, Any], + timeout: int = CONNECTION_TIMEOUT, + ) -> str: + """Connect one server from a normalized config entry.""" + if self.is_connected(server_name): + return server_name + server = copy.deepcopy(server_config) + env_dict = self.resolve_server_env(server) + if 'exclude' in server: + self.exclude_functions[server_name] = server.pop('exclude') + if 'include' in server: + self.include_functions[server_name] = server.pop('include') + assert (not self.include_functions.get(server_name)) or ( + not self.exclude_functions.get(server_name) + ), 'Set either `include` or `exclude` in tools config.' + timeout = server.pop('timeout', timeout) + for drop_key in ('enabled', 'source', 'plugin_id', 'meta'): + server.pop(drop_key, None) + return await self.connect_to_server( + server_name=server_name, + env=env_dict, + timeout=timeout, + **server, + ) + async def connect_to_server(self, server_name: str, timeout: int = CONNECTION_TIMEOUT, **kwargs): + if self.is_connected(server_name): + return server_name logger.info(f'connect to {server_name}') + stack = AsyncExitStack() + self._server_stacks[server_name] = stack # transport: stdio, sse, streamable_http, websocket transport = kwargs.get('transport') or kwargs.get('type') command = kwargs.get('command') @@ -155,7 +225,7 @@ async def connect_to_server(self, logger.info( '`transport` or `type` is configured as "sse", using sse transport.' ) - sse_transport = await self.exit_stack.enter_async_context( + sse_transport = await stack.enter_async_context( sse_client( url, kwargs.get('headers'), kwargs.get('timeout', DEFAULT_HTTP_TIMEOUT), @@ -175,7 +245,7 @@ async def connect_to_server(self, 'To use Websocket connections, please install the required dependency with: ' "'pip install mcp[ws]' or 'pip install websockets'" ) from None - websocket_transport = await self.exit_stack.enter_async_context( + websocket_transport = await stack.enter_async_context( websocket_client(url)) read, write = websocket_transport @@ -195,7 +265,7 @@ async def connect_to_server(self, other_kwargs = {} if httpx_client_factory is not None: other_kwargs['httpx_client_factory'] = httpx_client_factory - streamable_transport = await self.exit_stack.enter_async_context( + streamable_transport = await stack.enter_async_context( streamablehttp_client( url, headers=kwargs.get('headers'), @@ -210,7 +280,7 @@ async def connect_to_server(self, session_kwargs = session_kwargs or {} timeout = max( session_kwargs.pop('read_timeout_seconds', timeout), 1) - session = await self.exit_stack.enter_async_context( + session = await stack.enter_async_context( ClientSession( read, write, @@ -232,9 +302,9 @@ async def connect_to_server(self, 'encoding_error_handler', DEFAULT_ENCODING_ERROR_HANDLER), ) - stdio, write = await self.exit_stack.enter_async_context( + stdio, write = await stack.enter_async_context( stdio_client(server_params)) - session = await self.exit_stack.enter_async_context( + session = await stack.enter_async_context( ClientSession(stdio, write)) else: raise ValueError( @@ -295,6 +365,8 @@ async def add_mcp_config(self, mcp_config: Dict[str, Dict[str, Any]]): async def cleanup(self): """Clean up resources""" + for name in list(self._server_stacks): + await self.disconnect_server(name) await self.exit_stack.aclose() async def __aenter__(self) -> 'MCPClient': @@ -302,7 +374,7 @@ async def __aenter__(self) -> 'MCPClient': await self.connect() return self except Exception: - await self.exit_stack.aclose() + await self.cleanup() raise async def __aexit__( @@ -311,4 +383,4 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - await self.exit_stack.aclose() + await self.cleanup() diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index bb5006c9e..f6d67be47 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -9,7 +9,7 @@ import uuid from copy import copy from types import TracebackType -from typing import Any, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional from ms_agent.llm.utils import Tool, ToolCall from ms_agent.tools.agent_tool import AgentTool @@ -94,9 +94,30 @@ def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None, mcp_client: Optional[MCPClient] = None, + permission_enforcer=None, + safety_guard=None, + permission_mode: str = 'auto', + read_policy: str = 'loose', + hook_runtime=None, + permission_config=None, + mcp_callable_check: Optional[Callable[[str], bool]] = None, + mcp_failure_handler: Optional[Callable[ + [str, str, str, Optional[str]], Awaitable[None]]] = None, + mcp_unavailable_detail: Optional[Callable[[str], dict]] = None, + mcp_success_handler: Optional[Callable[[str], Awaitable[None]]] = None, **kwargs): self.config = config self.trust_remote_code = kwargs.get('trust_remote_code', False) + self._permission_enforcer = permission_enforcer + self._permission_config = permission_config + self._safety_guard = safety_guard + self._permission_mode = permission_mode + self._read_policy = read_policy + self._hook_runtime = hook_runtime + self.mcp_callable_check = mcp_callable_check + self.mcp_failure_handler = mcp_failure_handler + self.mcp_unavailable_detail = mcp_unavailable_detail + self.mcp_success_handler = mcp_success_handler self.extra_tools: List[ToolBase] = [] self.has_split_task_tool = False @@ -184,6 +205,23 @@ def __init__(self, if issubclass(cls, ToolBase) and cls.__module__ == _plugin: self.register_tool(cls(self.config)) self._tool_index = {} + self._mcp_index_keys: set[str] = set() + self._skip_mcp_reindex = False + + def ensure_plugin_agent_tools(self, registry) -> None: + """Attach plugin-defined subagents to AgentTool before connect().""" + if registry is None or not registry.has_agents(): + return + agent_tool = None + for tool in self.extra_tools: + if isinstance(tool, AgentTool): + agent_tool = tool + break + if agent_tool is None: + agent_tool = AgentTool( + self.config, trust_remote_code=self.trust_remote_code) + self.extra_tools.append(agent_tool) + agent_tool.sync_plugin_agents(registry) # Used temporarily during async initialization; the actual client is managed in self.servers self.mcp_client = mcp_client @@ -194,21 +232,28 @@ def __init__(self, # Initialize concurrency limiter (will be set in connect) self._concurrent_limiter = None self._init_lock = None + self._sync_lock = asyncio.Lock() def register_tool(self, tool: ToolBase): self.extra_tools.append(tool) async def connect(self): - if self.mcp_client and MCPClient and isinstance(self.mcp_client, MCPClient): + if self.mcp_client is not None: self.servers = self.mcp_client - await self.servers.add_mcp_config(self.mcp_config) - self.mcp_config = self.servers.mcp_config + has_add = hasattr(self.servers, 'add_mcp_config') + is_mcp = MCPClient is not None and isinstance(self.mcp_client, MCPClient) + if self.mcp_config and self.mcp_config.get('mcpServers') and (is_mcp or has_add): + await self.servers.add_mcp_config(self.mcp_config) + if hasattr(self.servers, 'mcp_config'): + self.mcp_config = self.servers.mcp_config elif MCPClient is not None: self.servers = MCPClient(self.mcp_config, self.config) await self.servers.connect() for tool in self.extra_tools: await tool.connect() - await self.reindex_tool() + + if not self._skip_mcp_reindex: + await self.reindex_tool() # Initialize concurrency limiter self._concurrent_limiter = asyncio.Semaphore(MAX_CONCURRENT_TOOLS) @@ -227,6 +272,97 @@ async def cleanup(self): except Exception: # noqa pass + def _clear_mcp_index_entries(self) -> None: + for key in self._mcp_index_keys: + self._tool_index.pop(key, None) + self._mcp_index_keys.clear() + + async def _report_mcp_failure( + self, + server_name: str, + phase: str, + message: str, + *, + tool_name: str | None = None, + exc: BaseException | None = None, + ) -> None: + if self.mcp_failure_handler is None: + return + from ms_agent.mcp.runtime import classify_failure_message, is_connection_error + if exc is not None: + if not is_connection_error(exc): + return + elif classify_failure_message(message) == 'none': + return + await self.mcp_failure_handler( + server_name, + phase, + message, + tool_name=tool_name, + exc=exc, + ) + + def _extend_mcp_tool_index( + self, + tool_ins: ToolBase, + server_name: str, + tool_list: List[Tool], + ) -> None: + for tool in tool_list: + max_server_len = MAX_TOOL_NAME_LEN - len( + tool['tool_name']) - len(self.TOOL_SPLITER) + if len(server_name) > max_server_len: + key = ( + f"{server_name[:max(0, max_server_len)]}" + f"{self.TOOL_SPLITER}{tool['tool_name']}") + else: + key = f"{server_name}{self.TOOL_SPLITER}{tool['tool_name']}" + assert key not in self._tool_index, ( + f'Tool name duplicated {tool["tool_name"]}') + indexed = copy(tool) + indexed['tool_name'] = key + self._tool_index[key] = (tool_ins, server_name, indexed) + self._mcp_index_keys.add(key) + + async def sync_mcp_tools( + self, + *, + visible_servers: set[str], + indexable_servers: set[str], + callable_servers: set[str], + cached_tools_by_server: dict[str, list[dict]] | None = None, + ) -> list[tuple[str, BaseException]]: + """Rebuild MCP entries in ``_tool_index`` (called by MCPRuntime). + + Returns transport failures from per-server ``list_tools`` calls. + """ + del visible_servers, callable_servers, cached_tools_by_server + failures: list[tuple[str, BaseException]] = [] + async with self._sync_lock: + self._clear_mcp_index_entries() + if self.servers is None: + return failures + for server_name in indexable_servers: + try: + if hasattr(self.servers, 'get_tools_for_server'): + tool_list = await self.servers.get_tools_for_server( + server_name) + else: + live_mcps = await self.servers.get_tools() + tool_list = live_mcps.get(server_name, []) + except Exception as exc: + logger.warning( + 'Failed to list tools for MCP server %s: %s', + server_name, + exc, + ) + failures.append((server_name, exc)) + continue + if tool_list: + self._extend_mcp_tool_index( + self.servers, server_name, tool_list) + return failures + async def reindex_tool(self): def extend_tool(tool_ins: ToolBase, server_name: str, @@ -247,7 +383,7 @@ def extend_tool(tool_ins: ToolBase, server_name: str, if self.servers is not None: mcps = await self.servers.get_tools() for server_name, tool_list in mcps.items(): - extend_tool(self.servers, server_name, tool_list) + self._extend_mcp_tool_index(self.servers, server_name, tool_list) for extra_tool in self.extra_tools: tools = await extra_tool.get_tools() for server_name, tool_list in tools.items(): @@ -272,6 +408,9 @@ async def single_call_tool(self, tool_info: ToolCall): brief_info = json.dumps(tool_info, ensure_ascii=False) if len(brief_info) > 1024: brief_info = brief_info[:1024] + '...' + wait_sec = self.tool_call_timeout + tool_ins = None + server_name = '' try: tool_name = tool_info['tool_name'] tool_args = tool_info['arguments'] @@ -281,7 +420,72 @@ async def single_call_tool(self, tool_info: ToolCall): except Exception: # noqa return f'The input {tool_args} is not a valid JSON, fix your arguments and try again' assert tool_name in self._tool_index, f'Tool name {tool_name} not found' - tool_ins, server_name, _ = self._tool_index[tool_name] + index_snapshot = self._tool_index[tool_name] + tool_ins, server_name, _ = index_snapshot + + # --- MCP availability (before SafetyGuard / PreToolUse) --- + if (tool_ins is self.servers and self.mcp_callable_check is not None + and not self.mcp_callable_check(server_name)): + detail = ( + self.mcp_unavailable_detail(server_name) + if self.mcp_unavailable_detail is not None else { + 'success': False, + 'error': 'mcp_unavailable', + 'server_name': server_name, + 'message': f'MCP server {server_name} is not callable', + }) + return json.dumps(detail, ensure_ascii=False) + + # --- Permission checks --- + args_dict = dict(tool_args) if isinstance(tool_args, dict) else {} + if self._safety_guard is not None: + from ms_agent.permission.ask_resolver import resolve_ask + safety_decision = self._safety_guard.check(tool_name, args_dict) + if safety_decision.action == 'deny': + return f'Blocked by safety policy: {safety_decision.reason}' + if safety_decision.action == 'ask': + resolved = resolve_ask(safety_decision, self._permission_mode, self._read_policy) + if resolved.action == 'deny': + return f'Blocked by safety policy: {resolved.reason}' + if resolved.action == 'ask': + if self._permission_enforcer is None: + return f'Blocked by safety policy (requires confirmation): {resolved.reason}' + # interactive mode: fall through to enforcer/handler + + # --- PreToolUse hooks --- + hook_result = None + pre_attachments: list = [] + if self._hook_runtime is not None and not self._hook_runtime.is_empty: + from ms_agent.utils.workspace_context import resolve_workspace_root + project_path = str(resolve_workspace_root(self.config)) + hook_result, pre_attachments = await self._hook_runtime.run_pre_tool_use( + tool_name=tool_name, + tool_args=args_dict, + project_path=project_path, + ) + if hook_result.updated_args is not None: + tool_args = hook_result.updated_args + args_dict = dict(hook_result.updated_args) + tool_info['arguments'] = tool_args + + from ms_agent.hooks.permission_resolve import resolve_hook_permission_decision + + perm_out = await resolve_hook_permission_decision( + hook_result=hook_result, + tool_name=tool_name, + tool_args=args_dict, + permission_enforcer=self._permission_enforcer, + permission_config=self._permission_config, + hook_runtime=self._hook_runtime, + ) + if isinstance(perm_out, str): + return perm_out + if perm_out.action == 'deny': + return f'Tool call denied: {perm_out.reason}' + if perm_out.updated_args is not None: + tool_args = perm_out.updated_args + tool_info['arguments'] = tool_args + raw_args = dict(tool_args) if isinstance(tool_args, dict) else {} wait_sec = effective_tool_wait_seconds( raw_args, @@ -309,20 +513,67 @@ async def single_call_tool(self, tool_info: ToolCall): tool_name, self.TOOL_SPLITER), tool_args=call_args), timeout=wait_sec) + + if (self.mcp_success_handler is not None + and tool_ins is self.servers): + await self.mcp_success_handler(server_name) + + # --- PostToolUse hooks --- + hook_attachments = list(pre_attachments) + if self._hook_runtime is not None and not self._hook_runtime.is_empty: + response_text = ( + response if isinstance(response, str) + else str(response.get('result', response)) + if isinstance(response, dict) else str(response)) + _, post_attachments = await self._hook_runtime.run_post_tool_use( + tool_name=tool_name, + tool_args=args_dict, + tool_result=response_text, + tool_call_id=tool_info.get('id'), + ) + hook_attachments.extend(post_attachments) + if hook_attachments: + if isinstance(response, dict): + response = dict(response) + response['hook_attachments'] = hook_attachments + else: + response = { + 'result': response, + 'hook_attachments': hook_attachments, + } return response except asyncio.TimeoutError: import traceback logger.warning(traceback.format_exc()) tn = tool_info.get('tool_name', '(unknown)') - return ( + timeout_msg = ( f'Tool call timed out after {wait_sec:.0f}s (tool: {tn}). ' f'Default limit is {self.tool_call_timeout:.0f}s; ' f'set numeric field "timeout" in the tool arguments to wait longer ' f'(seconds, maximum {self.tool_call_timeout_max:.0f}s). ' f'Original call (truncated): {brief_info}') + if tool_ins is not None and tool_ins is self.servers: + await self._report_mcp_failure( + server_name, + 'call_tool', + timeout_msg, + tool_name=self._registered_tool_suffix( + tool_info.get('tool_name', ''), self.TOOL_SPLITER), + exc=asyncio.TimeoutError(timeout_msg), + ) + return timeout_msg except Exception as e: import traceback logger.warning(traceback.format_exc()) + if tool_ins is not None and tool_ins is self.servers: + await self._report_mcp_failure( + server_name, + 'call_tool', + str(e), + tool_name=self._registered_tool_suffix( + tool_info.get('tool_name', ''), self.TOOL_SPLITER), + exc=e, + ) return f'Tool calling failed: {brief_info}, details: {str(e)}' async def parallel_call_tool(self, tool_list: List[ToolCall]): diff --git a/ms_agent/utils/constants.py b/ms_agent/utils/constants.py index e068e654e..cbd37e271 100644 --- a/ms_agent/utils/constants.py +++ b/ms_agent/utils/constants.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from typing import Dict, Optional -# The default output dir +# The default output dir when explicitly referenced in legacy configs. +# When ``output_dir`` is omitted, ``resolve_workspace_root()`` uses cwd instead. DEFAULT_OUTPUT_DIR = './output' DEFAULT_INDEX_DIR = '.index' diff --git a/ms_agent/utils/pattern_matcher.py b/ms_agent/utils/pattern_matcher.py new file mode 100644 index 000000000..b5d167cc3 --- /dev/null +++ b/ms_agent/utils/pattern_matcher.py @@ -0,0 +1,19 @@ +"""Shared fnmatch pattern matching for hooks and permission modules.""" + +from __future__ import annotations + +import fnmatch + + +def match_pattern(pattern: str, target: str) -> bool: + """Match *target* against *pattern* using fnmatch with ``|`` alternatives. + + Examples: + match_pattern("file_system---*", "file_system---read_file") -> True + match_pattern("read_file|write_file", "read_file") -> True + """ + for alt in pattern.split('|'): + alt = alt.strip() + if alt and fnmatch.fnmatch(target, alt): + return True + return False diff --git a/ms_agent/utils/workspace_context.py b/ms_agent/utils/workspace_context.py new file mode 100644 index 000000000..937dd1e4b --- /dev/null +++ b/ms_agent/utils/workspace_context.py @@ -0,0 +1,43 @@ +"""Lightweight workspace context: root directory and deny globs for file traversal.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +_DEFAULT_DENY_GLOBS: tuple[str, ...] = ('**/.git/**',) +_MISSING = object() + + +def resolve_workspace_root(config: Any) -> Path: + """Resolve the agent workspace root (``output_dir``). + + When ``output_dir`` is omitted or empty in config, defaults to the process + current working directory (the user's workspace). Explicit values are + expanded and resolved to an absolute path. + """ + raw = getattr(config, 'output_dir', _MISSING) + if raw is _MISSING or raw is None: + return Path.cwd().resolve() + text = str(raw).strip() + if not text: + return Path.cwd().resolve() + return Path(text).expanduser().resolve() + + +@dataclass(frozen=True) +class WorkspaceContext: + """Runtime context for tools — no security checks, only cwd and traversal filtering.""" + + root: Path + deny_globs: tuple[str, ...] = _DEFAULT_DENY_GLOBS + + @classmethod + def from_config(cls, config: Any) -> WorkspaceContext: + wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) + raw_deny = list(getattr(wp, 'deny_globs', []) or []) if wp else [] + deny = tuple(raw_deny) if raw_deny else _DEFAULT_DENY_GLOBS + + return cls(root=resolve_workspace_root(config), deny_globs=deny) diff --git a/ms_agent/utils/workspace_policy.py b/ms_agent/utils/workspace_policy.py deleted file mode 100644 index c11edf4db..000000000 --- a/ms_agent/utils/workspace_policy.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Workspace path policy: allow-roots (default output_dir) and optional deny globs.""" - -from __future__ import annotations - -import fnmatch -import os -import re -from pathlib import Path -from typing import Iterable, Sequence - - -class WorkspacePolicyError(ValueError): - """Raised when a path or command violates workspace policy.""" - - -class WorkspacePolicyKernel: - """Resolve user paths under allowed workspace roots; optional shell read-only rules.""" - - def __init__( - self, - output_dir: Path | str, - *, - extra_allow_roots: Sequence[str | Path] | None = None, - deny_globs: Sequence[str] | None = None, - shell_default_mode: str = 'workspace_write', - shell_network_enabled: bool = False, - max_command_chars: int = 8192, - ) -> None: - self._output = Path(output_dir).expanduser().resolve() - self._roots: list[Path] = [self._output] - if extra_allow_roots: - for r in extra_allow_roots: - p = Path(r).expanduser().resolve() - if p not in self._roots: - self._roots.append(p) - if deny_globs is None or len(tuple(deny_globs)) == 0: - self._deny_globs: tuple[str, ...] = ('**/.git/**', ) - else: - self._deny_globs = tuple(deny_globs) - self.shell_default_mode = shell_default_mode - self.shell_network_enabled = shell_network_enabled - self.max_command_chars = max_command_chars - - @property - def workspace_root(self) -> Path: - return self._output - - @property - def allow_roots(self) -> tuple[Path, ...]: - return tuple(self._roots) - - @property - def deny_globs(self) -> tuple[str, ...]: - return self._deny_globs - - def resolve_under_roots(self, user_path: str | Path) -> Path: - """Resolve *user_path* to an absolute path that must lie under one allow root.""" - raw = Path(user_path).expanduser() - if raw.is_absolute(): - resolved = raw.resolve() - else: - resolved = (self._output / raw).resolve() - for root in self._roots: - try: - resolved.relative_to(root) - break - except ValueError: - continue - else: - raise WorkspacePolicyError( - f'Path is outside allowed workspace roots: {resolved}') - if self._is_denied(resolved): - raise WorkspacePolicyError( - f'Path matches a deny_globs pattern: {resolved}') - return resolved - - def _is_denied(self, path: Path) -> bool: - if not self._deny_globs: - return False - rel = None - try: - rel = path.relative_to(self._output) - except ValueError: - rel = path - rel_s = rel.as_posix() - for pat in self._deny_globs: - if fnmatch.fnmatch(rel_s, pat) or fnmatch.fnmatch(path.name, pat): - return True - if fnmatch.fnmatch(str(path), pat): - return True - return False - - def path_is_allowed(self, path: Path) -> bool: - path = path.expanduser().resolve() - for root in self._roots: - try: - path.relative_to(root) - break - except ValueError: - continue - else: - return False - return not self._is_denied(path) - - def assert_shell_command_allowed(self, command: str) -> None: - """Length and mode-based checks before executing shell.""" - if not command or not command.strip(): - raise WorkspacePolicyError('Empty shell command') - if len(command) > self.max_command_chars: - raise WorkspacePolicyError( - f'Shell command exceeds max length ({self.max_command_chars})') - - mode = self.shell_default_mode - if mode == 'read_only': - if _shell_looks_mutating_or_network(command, allow_network=False): - raise WorkspacePolicyError( - 'Shell is in read_only mode: mutating or network commands are not allowed' - ) - elif mode == 'workspace_write': - if not self.shell_network_enabled and _shell_looks_network( - command): - raise WorkspacePolicyError( - 'Network commands are disabled for shell (enable tools.code_executor.shell.network_enabled)' - ) - # future: explicit 'network' mode could allow curl etc. - - -def _shell_looks_network(command: str) -> bool: - lowered = command.lower() - tokens = ( - 'curl ', - 'wget ', - 'ssh ', - 'scp ', - 'rsync ', - 'ftp ', - 'nc ', - 'netcat ', - 'pip install', - 'pip3 install', - 'npm install', - 'yarn add', - 'pnpm add', - ) - return any(t in lowered for t in tokens) - - -def _shell_looks_mutating_or_network(command: str, *, - allow_network: bool) -> bool: - if not allow_network and _shell_looks_network(command): - return True - # redirection that creates/overwrites files - if re.search(r'[>]{1,2}\s*[^\s]', command): - return True - if re.search(r'\b(rm|rmdir|mv|cp|chmod|chown|chgrp|mkdir|touch|tee)\b', - command): - return True - return False - - -def iter_files_under( - root: Path, - *, - deny_globs: Iterable[str] = (), - max_files: int = 100_000, -) -> Iterable[Path]: - """Yield files under *root* (depth-first), skipping directories matching deny globs.""" - deny = tuple(deny_globs) - count = 0 - root = root.resolve() - - def dir_skipped(dirpath: Path) -> bool: - try: - rel = dirpath.relative_to(root).as_posix() - except ValueError: - return True - for pat in deny: - if fnmatch.fnmatch(rel, pat) or fnmatch.fnmatch(rel + '/', pat): - return True - parts = rel.split('/') - for i in range(len(parts)): - sub = '/'.join(parts[:i + 1]) - if fnmatch.fnmatch(sub, pat.rstrip('/')) or fnmatch.fnmatch( - sub + '/', pat): - return True - return False - - for dirpath, dirnames, filenames in os.walk( - root, topdown=True, followlinks=False): - dp = Path(dirpath) - if dir_skipped(dp): - dirnames[:] = [] - continue - # prune skipped subdirs - keep: list[str] = [] - for d in dirnames: - child = dp / d - if dir_skipped(child): - continue - keep.append(d) - dirnames[:] = keep - for name in filenames: - count += 1 - if count > max_files: - return - yield dp / name diff --git a/projects/deep_research/v2/run_benchmark.sh b/projects/deep_research/v2/run_benchmark.sh index b781afaf0..d4c147f84 100755 --- a/projects/deep_research/v2/run_benchmark.sh +++ b/projects/deep_research/v2/run_benchmark.sh @@ -36,10 +36,6 @@ PYTHON_BIN="/Users/luyan/software/miniconda3/bin/python" # Force unbuffered output so progress lines like "[xx] OK" show up in logs promptly. export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}" -# When stdout is redirected (e.g., nohup > file), Python is block-buffered by default. -# Force unbuffered output so progress lines like "[xx] OK" show up in logs promptly. -export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}" - # Use caffeinate on macOS when available; otherwise run normally. RUN_PREFIX=() if command -v caffeinate >/dev/null 2>&1; then diff --git a/tests/config/test_mcp_resolver.py b/tests/config/test_mcp_resolver.py new file mode 100644 index 000000000..dd3322764 --- /dev/null +++ b/tests/config/test_mcp_resolver.py @@ -0,0 +1,135 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for MCP config merge rules (design doc §5.4).""" +from __future__ import annotations + +from pathlib import Path + +import pytest +from omegaconf import OmegaConf + +from ms_agent.config.mcp_manager import MCPConfigManager +from ms_agent.config.mcp_schema import merge_mcp_layers, normalize_mcp_server_entry +from ms_agent.config.resolver import ConfigResolver + + +@pytest.fixture +def tmp_roots(tmp_path: Path): + global_root = tmp_path / 'global' + project_root = tmp_path / 'project' + global_root.mkdir() + project_root.mkdir() + return global_root, project_root + + +class TestNormalizeMcpServerEntry: + def test_strips_agent_yaml_metadata(self): + entry = { + 'mcp': True, + 'command': 'npx', + 'args': ['-y', 'pkg'], + 'implementation': 'builtin', + 'trust_remote_code': True, + } + normalized = normalize_mcp_server_entry(entry, source='agent_yaml') + assert normalized is not None + assert normalized['command'] == 'npx' + assert 'implementation' not in normalized + assert normalized['source'] == 'agent_yaml' + assert normalized['enabled'] is True + + def test_mcp_false_excluded(self): + assert normalize_mcp_server_entry({'mcp': False, 'command': 'x'}) is None + + +class TestConfigMergeCases: + def test_case1_global_only(self, tmp_roots): + global_root, project_root = tmp_roots + mgr = MCPConfigManager(global_root, project_root) + mgr.add('fetch', {'command': 'A'}, scope='global') + resolver = ConfigResolver(global_root, project_root) + resolved = resolver.resolve_mcp() + assert resolved.mcp_servers['fetch']['enabled'] is True + assert resolved.mcp_servers['fetch']['command'] == 'A' + + def test_case2_project_reenables(self, tmp_roots): + global_root, project_root = tmp_roots + mgr = MCPConfigManager(global_root, project_root) + mgr.add('fetch', {'command': 'A', 'enabled': False}, scope='global') + mgr.set_enabled('fetch', True, scope='project') + resolver = ConfigResolver(global_root, project_root) + resolved = resolver.resolve_mcp() + assert resolved.mcp_servers['fetch']['enabled'] is True + assert resolved.mcp_servers['fetch']['command'] == 'A' + + def test_case3_agent_yaml_overrides_command(self, tmp_roots): + global_root, project_root = tmp_roots + mgr = MCPConfigManager(global_root, project_root) + mgr.add('fetch', {'command': 'A'}, scope='global') + agent_cfg = OmegaConf.create({ + 'tools': { + 'fetch': {'mcp': True, 'command': 'B'}, + }, + }) + resolver = ConfigResolver(global_root, project_root, agent_config=agent_cfg) + resolved = resolver.resolve_mcp() + assert resolved.mcp_servers['fetch']['command'] == 'B' + + def test_case4_project_overrides_command(self, tmp_roots): + global_root, project_root = tmp_roots + mgr = MCPConfigManager(global_root, project_root) + mgr.add('fetch', {'command': 'A'}, scope='global') + mgr.add('fetch', {'command': 'C'}, scope='project') + resolver = ConfigResolver(global_root, project_root) + resolved = resolver.resolve_mcp() + assert resolved.mcp_servers['fetch']['command'] == 'C' + + def test_case5_project_remove_masks_global(self, tmp_roots): + global_root, project_root = tmp_roots + mgr = MCPConfigManager(global_root, project_root) + mgr.add('fetch', {'command': 'A'}, scope='global') + mgr.remove('fetch', scope='project') + resolver = ConfigResolver(global_root, project_root) + resolved = resolver.resolve_mcp() + assert resolved.mcp_servers['fetch']['enabled'] is False + + def test_case6_session_reenables(self, tmp_roots): + global_root, project_root = tmp_roots + mgr = MCPConfigManager(global_root, project_root) + mgr.add('fetch', {'command': 'A', 'enabled': False}, scope='global') + resolver = ConfigResolver(global_root, project_root) + resolved = resolver.resolve_mcp( + session_override={'fetch': {'enabled': True}}) + assert resolved.mcp_servers['fetch']['enabled'] is True + + def test_case7_mcp_false_not_in_mcp_servers(self, tmp_roots): + global_root, project_root = tmp_roots + mgr = MCPConfigManager(global_root, project_root) + mgr.add('filesystem', {'command': 'A'}, scope='global') + agent_cfg = OmegaConf.create({ + 'tools': { + 'filesystem': {'mcp': False, 'command': 'B'}, + }, + }) + resolver = ConfigResolver(global_root, project_root, agent_config=agent_cfg) + resolved = resolver.resolve_mcp() + assert 'filesystem' not in resolved.mcp_servers + + def test_merge_enabled_inheritance(self): + base = {'command': 'A', 'enabled': False} + override = {'command': 'B'} + merged = merge_mcp_layers({'fetch': base}, {'fetch': override}) + assert merged['fetch']['command'] == 'B' + assert merged['fetch']['enabled'] is False + + def test_resolve_mcp_all_layers_builtin_shadow(self, tmp_roots): + global_root, project_root = tmp_roots + mgr = MCPConfigManager(global_root, project_root) + mgr.add('filesystem', {'command': 'A'}, scope='global') + agent_cfg = OmegaConf.create({ + 'tools': { + 'filesystem': {'mcp': False, 'command': 'B'}, + }, + }) + resolver = ConfigResolver(global_root, project_root, agent_config=agent_cfg) + merged = resolver.resolve_mcp_all_layers() + assert 'filesystem' not in merged diff --git a/tests/config/test_resolver.py b/tests/config/test_resolver.py index 103f0c04f..5f77468fe 100644 --- a/tests/config/test_resolver.py +++ b/tests/config/test_resolver.py @@ -243,6 +243,28 @@ def test_skills_merged_into_config(self, tmp_path): assert hasattr(config, '_merged_skills') assert 'bad-skill' in list(config._merged_skills.disabled) + def test_plugins_merged_into_config(self, tmp_path): + global_dir = tmp_path / '.ms_agent' + global_dir.mkdir() + plugin_path = str(global_dir / 'plugins' / 'demo') + plugins_data = { + 'plugins': [{ + 'id': 'demo', + 'enabled': True, + 'format': 'claude', + 'manifest_path': '.claude-plugin/plugin.json', + 'path': plugin_path, + }], + } + (global_dir / 'plugins.json').write_text(json.dumps(plugins_data)) + + resolver = ConfigResolver(global_dir=str(global_dir)) + config = resolver.resolve() + + assert hasattr(config, '_merged_plugins') + assert config._merged_plugins.plugins[0].id == 'demo' + assert plugin_path in list(config.plugins) + class TestPersonalizationInResolver: diff --git a/tests/fixtures/hooks/allow.py b/tests/fixtures/hooks/allow.py new file mode 100644 index 000000000..be4158e65 --- /dev/null +++ b/tests/fixtures/hooks/allow.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +"""Allow hook for tests.""" +import json +import sys + +json.load(sys.stdin) +print(json.dumps({"decision": "allow", "reason": "allowed by test"})) diff --git a/tests/fixtures/hooks/block.sh b/tests/fixtures/hooks/block.sh new file mode 100755 index 000000000..d84d9bebf --- /dev/null +++ b/tests/fixtures/hooks/block.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +echo "blocked via exit 2" >&2 +exit 2 diff --git a/tests/fixtures/hooks/deny.py b/tests/fixtures/hooks/deny.py new file mode 100644 index 000000000..19baf558b --- /dev/null +++ b/tests/fixtures/hooks/deny.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +"""Deny hook for tests.""" +import json +import sys + +json.load(sys.stdin) +print(json.dumps({"decision": "deny", "reason": "blocked by test"})) diff --git a/tests/fixtures/hooks/pass.py b/tests/fixtures/hooks/pass.py new file mode 100755 index 000000000..6c596941e --- /dev/null +++ b/tests/fixtures/hooks/pass.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +"""Pass-through hook for tests.""" +import json +import sys + +event = json.load(sys.stdin) +print(json.dumps({})) diff --git a/tests/mcp/test_mcp_runtime.py b/tests/mcp/test_mcp_runtime.py new file mode 100644 index 000000000..ae6a8d4a9 --- /dev/null +++ b/tests/mcp/test_mcp_runtime.py @@ -0,0 +1,478 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for MCPRuntime and ToolManager integration (design doc §14).""" +from __future__ import annotations + +import asyncio +import json +from typing import Any, Dict, List +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from ms_agent.config.mcp_schema import ResolvedMCPConfig +from ms_agent.llm.utils import Tool +from ms_agent.mcp.runtime import ( + DEGRADED_FAILURE_THRESHOLD, + MCPRuntime, + classify_mcp_failure, + is_connection_error, +) +from ms_agent.tools.tool_manager import ToolManager + + +class FakeMCPClient: + """Minimal MCPClient stand-in for unit tests.""" + + def __init__(self, mcp_config: Dict[str, Any] | None = None): + self.mcp_config = mcp_config or {'mcpServers': {}} + self.sessions: Dict[str, Any] = {} + self.connect_calls: list[str] = [] + self.get_tools_calls = 0 + self.call_tool_calls = 0 + self.list_tools_raises_for: str | None = None + + def is_connected(self, server_name: str) -> bool: + return server_name in self.sessions + + def list_connected_servers(self) -> list[str]: + return list(self.sessions.keys()) + + async def connect_single_server(self, server_name: str, server_config: dict): + self.connect_calls.append(server_name) + if server_config.get('fail_connect') or server_config.get('command') == 'x': + raise ConnectionError(f'connect failed: {server_name}') + self.sessions[server_name] = object() + return server_name + + async def disconnect_server(self, server_name: str): + self.sessions.pop(server_name, None) + + async def get_tools_for_server(self, server_name: str) -> List[Tool]: + self.get_tools_calls += 1 + if self.list_tools_raises_for == server_name: + raise ConnectionError('session closed') + if server_name not in self.sessions: + return [] + return [ + Tool( + tool_name='demo_tool', + server_name=server_name, + description='demo', + parameters={}, + ) + ] + + async def get_tools(self) -> Dict[str, List[Tool]]: + tools: Dict[str, List[Tool]] = {} + for name in self.sessions: + try: + tools[name] = await self.get_tools_for_server(name) + except Exception: + tools[name] = [] + return tools + + async def call_tool(self, server_name: str, tool_name: str, tool_args: dict): + self.call_tool_calls += 1 + if getattr(self, 'call_raises', False): + raise ConnectionError('broken pipe') + if getattr(self, 'call_timeout', False): + raise TimeoutError('tool call timeout') + return 'ok' + + async def cleanup(self): + self.sessions.clear() + + +def _resolved(*servers: tuple[str, dict]) -> ResolvedMCPConfig: + return ResolvedMCPConfig( + mcp_servers={name: dict(cfg, enabled=cfg.get('enabled', True)) + for name, cfg in servers}) + + +@pytest.mark.asyncio +async def test_independent_client_injection(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + await runtime.start() + assert client.is_connected('fetch') + + +@pytest.mark.asyncio +async def test_disable_removes_tools_but_keeps_session(): + client = FakeMCPClient() + config = _resolved(('fetch', {'command': 'echo'})) + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=config, + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + tools = await tm.get_tools() + assert any('fetch---' in t['tool_name'] for t in tools) + + await runtime.disable_server('fetch') + tools = await tm.get_tools() + assert not any('fetch---' in t['tool_name'] for t in tools) + assert client.is_connected('fetch') + + +@pytest.mark.asyncio +async def test_sync_mcp_tools_clears_and_is_idempotent(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + await runtime.sync_tools() + keys = [t['tool_name'] for t in await tm.get_tools()] + assert keys.count('fetch---demo_tool') == 1 + + +@pytest.mark.asyncio +async def test_connect_skip_policy(): + client = FakeMCPClient() + config = ResolvedMCPConfig(mcp_servers={ + 'bad': {'command': 'x'}, + 'good': {'command': 'echo'}, + }) + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=config, + connect_policy='skip', + owns_client=False, + ) + await runtime.start() + assert runtime.get_server('bad').status == 'error' + assert runtime.get_server('good').status == 'connected' + + +@pytest.mark.asyncio +async def test_runtime_mode_a_no_double_connect(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + await runtime.start() + class _Cfg: + tools = type('T', (), {})() + + tm = ToolManager( + config=_Cfg(), + mcp_config={}, + mcp_client=client, # type: ignore[arg-type] + ) + tm._skip_mcp_reindex = True + await tm.connect() + assert client.connect_calls.count('fetch') == 1 + + +@pytest.mark.asyncio +async def test_degraded_hidden_from_llm_tools(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + + await runtime.record_failure( + 'fetch', 'call_tool', 'Connection closed', + exc=ConnectionError('Connection closed')) + assert runtime.get_server('fetch').status == 'degraded' + tools = await tm.get_tools() + assert not any('fetch---' in t['tool_name'] for t in tools) + assert runtime.is_callable('fetch') is False + + +@pytest.mark.asyncio +async def test_transient_failure_not_immediately_degraded(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + + for _ in range(DEGRADED_FAILURE_THRESHOLD - 1): + await runtime.record_failure( + 'fetch', 'call_tool', 'timeout', + exc=TimeoutError('timeout')) + + assert runtime.get_server('fetch').status == 'connected' + tools = await tm.get_tools() + assert any('fetch---' in t['tool_name'] for t in tools) + + +@pytest.mark.asyncio +async def test_transient_failure_threshold_degrades(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + + for _ in range(DEGRADED_FAILURE_THRESHOLD): + await runtime.record_failure( + 'fetch', 'call_tool', 'timeout', + exc=TimeoutError('timeout')) + + assert runtime.get_server('fetch').status == 'degraded' + tools = await tm.get_tools() + assert not any('fetch---' in t['tool_name'] for t in tools) + + +@pytest.mark.asyncio +async def test_get_tools_per_server_isolation(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved( + ('good', {'command': 'echo'}), + ('bad', {'command': 'echo'}), + ), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + client.list_tools_raises_for = 'bad' + await runtime.sync_tools() + tools = await tm.get_tools() + assert any('good---' in t['tool_name'] for t in tools) + assert not any('bad---' in t['tool_name'] for t in tools) + + +@pytest.mark.asyncio +async def test_mcp_failure_handler_on_connection_error(): + client = FakeMCPClient() + client.call_raises = True + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + + await tm.single_call_tool({ + 'tool_name': 'fetch---demo_tool', + 'arguments': {}, + }) + assert runtime.get_server('fetch').status == 'degraded' + + +@pytest.mark.asyncio +async def test_sync_mcp_tools_during_parallel_hooks(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + + async def call_and_sync(): + task = asyncio.create_task(tm.single_call_tool({ + 'tool_name': 'fetch---demo_tool', + 'arguments': {}, + })) + await asyncio.sleep(0) + await runtime.sync_tools() + return await task + + results = await asyncio.gather(call_and_sync(), runtime.sync_tools()) + assert results is not None + + +@pytest.mark.asyncio +async def test_record_success_resets_transient_counter(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + await runtime.start() + await runtime.record_failure( + 'fetch', 'call_tool', 'timeout', exc=TimeoutError('timeout')) + await runtime.record_failure( + 'fetch', 'call_tool', 'timeout', exc=TimeoutError('timeout')) + assert runtime.get_server('fetch').consecutive_failures == 2 + await runtime.record_success('fetch') + assert runtime.get_server('fetch').consecutive_failures == 0 + assert runtime.get_server('fetch').status == 'connected' + + +def test_classify_mcp_failure(): + assert classify_mcp_failure(TimeoutError()) == 'transient' + assert classify_mcp_failure(ConnectionError('x')) == 'hard' + assert classify_mcp_failure(BrokenPipeError()) == 'hard' + assert classify_mcp_failure(ValueError('bad arg')) == 'none' + + +def test_is_connection_error(): + assert is_connection_error(ConnectionError('x')) + assert is_connection_error(TimeoutError()) + assert not is_connection_error(ValueError('bad arg')) + + +class _DenyHookRuntime: + is_empty = False + + async def run_pre_tool_use(self, tool_name, tool_args, **kwargs): + from ms_agent.hooks.events import HookResult + return HookResult(action='deny', reason='blocked by test'), [] + + async def run_post_tool_use(self, **kwargs): + from ms_agent.hooks.events import HookResult + return HookResult(action='allow'), [] + + +@pytest.mark.asyncio +async def test_mcp_tool_pre_tool_use_deny(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime, hook_runtime=_DenyHookRuntime()) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + + result = await tm.single_call_tool({ + 'tool_name': 'fetch---demo_tool', + 'arguments': {}, + }) + assert 'Blocked by hook' in result + assert client.call_tool_calls == 0 + + +@pytest.mark.asyncio +async def test_timeout_triggers_transient_failure(): + client = FakeMCPClient() + + async def slow_call_tool(server_name, tool_name, tool_args): + await asyncio.sleep(2) + return 'ok' + + client.call_tool = slow_call_tool # type: ignore[method-assign] + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + await runtime.sync_tools() + + await tm.single_call_tool({ + 'tool_name': 'fetch---demo_tool', + 'arguments': {'timeout': 1}, + }) + state = runtime.get_server('fetch') + assert state.consecutive_failures == 1 + assert state.status == 'connected' + + +@pytest.mark.asyncio +async def test_sync_mcp_tools_list_failure_records(): + client = FakeMCPClient() + runtime = MCPRuntime( + mcp_client=client, # type: ignore[arg-type] + config=_resolved(('fetch', {'command': 'echo'})), + owns_client=False, + ) + tm = _make_tool_manager(client, runtime) + await tm.connect() + runtime.bind_tool_manager(tm) + await runtime.start() + client.list_tools_raises_for = 'fetch' + await runtime.sync_tools() + state = runtime.get_server('fetch') + assert state.consecutive_failures >= 1 + assert state.last_error is not None + + +@pytest.mark.asyncio +async def test_mcp_client_aexit_disconnects_server_stacks(): + from contextlib import AsyncExitStack + + from ms_agent.tools.mcp_client import MCPClient + + client = MCPClient({'mcpServers': {}}) + client.sessions['fake'] = object() + stack = AsyncExitStack() + await stack.__aenter__() + client._server_stacks['fake'] = stack + + await client.__aexit__(None, None, None) + assert 'fake' not in client.sessions + assert 'fake' not in client._server_stacks + + +def _make_tool_manager(client, runtime, hook_runtime=None): + class _Tools: + pass + + class _Config: + tool_call_timeout = 30 + tool_call_timeout_max = 600 + tools = _Tools() + + tm = ToolManager( + config=_Config(), + mcp_config={}, + mcp_client=client, # type: ignore[arg-type] + hook_runtime=hook_runtime, + mcp_callable_check=runtime.is_callable, + mcp_failure_handler=runtime.record_failure, + mcp_unavailable_detail=runtime.unavailable_detail, + mcp_success_handler=runtime.record_success, + ) + tm._skip_mcp_reindex = True + return tm diff --git a/tests/permission/__init__.py b/tests/permission/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/permission/test_ask_resolver.py b/tests/permission/test_ask_resolver.py new file mode 100644 index 000000000..2ccb17250 --- /dev/null +++ b/tests/permission/test_ask_resolver.py @@ -0,0 +1,111 @@ +"""Tests for ask_resolver: mode-based resolution of SafetyGuard ``ask`` decisions.""" + +import pytest + +from ms_agent.permission.ask_resolver import resolve_ask +from ms_agent.permission.shell_validator import SafetyDecision + + +class TestStrictMode: + """strict mode: all ask → deny.""" + + @pytest.mark.parametrize('category', [ + 'process_input_sub', + 'process_output_sub', + 'parse_failure', + 'cd_write_compound', + 'command_validator', + 'shell_expansion', + 'read_outside_dirs', + ]) + def test_all_ask_denied(self, category: str) -> None: + decision = SafetyDecision(action='ask', reason='test', category=category) + result = resolve_ask(decision, mode='strict') + assert result.action == 'deny' + assert 'strict mode' in result.reason + + +class TestInteractiveMode: + """interactive mode: ask unchanged.""" + + @pytest.mark.parametrize('category', [ + 'process_input_sub', + 'process_output_sub', + 'parse_failure', + 'cd_write_compound', + 'command_validator', + 'shell_expansion', + 'read_outside_dirs', + ]) + def test_all_ask_preserved(self, category: str) -> None: + decision = SafetyDecision(action='ask', reason='test reason', category=category) + result = resolve_ask(decision, mode='interactive') + assert result.action == 'ask' + assert result.reason == 'test reason' + + +class TestAutoMode: + """auto mode: per-category resolution.""" + + def test_process_input_sub_allowed(self) -> None: + decision = SafetyDecision(action='ask', reason='input sub', category='process_input_sub') + result = resolve_ask(decision, mode='auto') + assert result.action == 'allow' + + def test_process_output_sub_denied(self) -> None: + decision = SafetyDecision(action='ask', reason='output sub', category='process_output_sub') + result = resolve_ask(decision, mode='auto') + assert result.action == 'deny' + + def test_parse_failure_denied(self) -> None: + decision = SafetyDecision(action='ask', reason='bad parse', category='parse_failure') + result = resolve_ask(decision, mode='auto') + assert result.action == 'deny' + + def test_cd_write_compound_denied(self) -> None: + decision = SafetyDecision(action='ask', reason='cd+write', category='cd_write_compound') + result = resolve_ask(decision, mode='auto') + assert result.action == 'deny' + + def test_command_validator_denied(self) -> None: + decision = SafetyDecision(action='ask', reason='suspicious', category='command_validator') + result = resolve_ask(decision, mode='auto') + assert result.action == 'deny' + + def test_shell_expansion_denied(self) -> None: + decision = SafetyDecision(action='ask', reason='$VAR', category='shell_expansion') + result = resolve_ask(decision, mode='auto') + assert result.action == 'deny' + + def test_unknown_category_denied(self) -> None: + decision = SafetyDecision(action='ask', reason='unknown', category='something_new') + result = resolve_ask(decision, mode='auto') + assert result.action == 'deny' + + +class TestReadPolicy: + """read_outside_dirs resolved by read_policy in auto mode.""" + + def test_loose_allows_read_outside(self) -> None: + decision = SafetyDecision(action='ask', reason='outside', category='read_outside_dirs') + result = resolve_ask(decision, mode='auto', read_policy='loose') + assert result.action == 'allow' + + def test_strict_denies_read_outside(self) -> None: + decision = SafetyDecision(action='ask', reason='outside', category='read_outside_dirs') + result = resolve_ask(decision, mode='auto', read_policy='strict') + assert result.action == 'deny' + + +class TestPassthrough: + """Non-ask decisions pass through unchanged.""" + + def test_allow_unchanged(self) -> None: + decision = SafetyDecision(action='allow', reason='ok') + result = resolve_ask(decision, mode='strict') + assert result is decision + + def test_deny_unchanged(self) -> None: + decision = SafetyDecision(action='deny', reason='blocked') + result = resolve_ask(decision, mode='auto') + assert result is decision diff --git a/tests/permission/test_enforcer.py b/tests/permission/test_enforcer.py new file mode 100644 index 000000000..8d9830c4e --- /dev/null +++ b/tests/permission/test_enforcer.py @@ -0,0 +1,189 @@ +"""Tests for PermissionEnforcer.""" + +import pytest + +from ms_agent.permission.config import PermissionConfig +from ms_agent.permission.enforcer import PermissionEnforcer +from ms_agent.permission.handler import ( + AutoPermissionHandler, + PermissionAction, + PermissionResponse, +) +from ms_agent.permission.memory import PermissionMemory + + +def _interactive_config(**kwargs) -> PermissionConfig: + """Build interactive-mode config via from_dict (restricted → interactive alias).""" + raw = {'mode': 'restricted', **kwargs} + if 'whitelist' in raw: + raw['whitelist'] = list(raw['whitelist']) + if 'blacklist' in raw: + raw['blacklist'] = list(raw['blacklist']) + config = PermissionConfig.from_dict(raw) + assert config.mode == 'interactive' + return config + + +class MockDenyHandler: + async def ask(self, tool_name, tool_args, context, suggestions=None): + return PermissionResponse(action=PermissionAction.DENY, feedback='Denied by mock') + + +class MockAllowHandler: + async def ask(self, tool_name, tool_args, context, suggestions=None): + return PermissionResponse(action=PermissionAction.ALLOW_ONCE) + + +class MockAlwaysHandler: + async def ask(self, tool_name, tool_args, context, suggestions=None): + return PermissionResponse( + action=PermissionAction.ALLOW_ALWAYS, + pattern=tool_name, + ) + + +@pytest.fixture +def auto_enforcer(): + config = PermissionConfig(mode='auto') + return PermissionEnforcer(config=config) + + +@pytest.fixture +def interactive_enforcer(tmp_path): + config = _interactive_config( + whitelist=('file_system---read_file',), + blacklist=('code_executor---shell_executor:rm -rf *',), + ) + handler = MockAllowHandler() + memory = PermissionMemory(project_path=tmp_path) + return PermissionEnforcer(config=config, handler=handler, memory=memory) + + +class TestAutoMode: + @pytest.mark.asyncio + async def test_always_allows(self, auto_enforcer): + r = await auto_enforcer.check('any_tool', {}) + assert r.action == 'allow' + assert 'Auto mode' in r.reason + + +class TestInteractiveMode: + @pytest.mark.asyncio + async def test_whitelist_allows(self, interactive_enforcer): + r = await interactive_enforcer.check('file_system---read_file', {'path': '/test'}) + assert r.action == 'allow' + assert 'whitelist' in r.reason + + @pytest.mark.asyncio + async def test_blacklist_denies(self, interactive_enforcer): + r = await interactive_enforcer.check( + 'code_executor---shell_executor', + {'command': 'rm -rf /tmp'}, + ) + assert r.action == 'deny' + assert 'blacklist' in r.reason + + @pytest.mark.asyncio + async def test_unknown_asks_handler(self, interactive_enforcer): + r = await interactive_enforcer.check('unknown---tool', {'arg': 'val'}) + assert r.action == 'allow' # MockAllowHandler returns allow_once + + @pytest.mark.asyncio + async def test_deny_handler(self, tmp_path): + config = _interactive_config() + handler = MockDenyHandler() + memory = PermissionMemory(project_path=tmp_path) + enforcer = PermissionEnforcer(config=config, handler=handler, memory=memory) + + r = await enforcer.check('unknown---tool', {}) + assert r.action == 'deny' + + +class TestBlacklistPriority: + @pytest.mark.asyncio + async def test_blacklist_over_whitelist(self, tmp_path): + config = _interactive_config( + whitelist=('code_executor---*',), + blacklist=('code_executor---shell_executor:rm *',), + ) + enforcer = PermissionEnforcer( + config=config, + handler=MockAllowHandler(), + memory=PermissionMemory(project_path=tmp_path), + ) + r = await enforcer.check( + 'code_executor---shell_executor', + {'command': 'rm -rf /'}, + ) + assert r.action == 'deny' + + +class TestMemoryIntegration: + @pytest.mark.asyncio + async def test_session_memory(self, tmp_path): + config = _interactive_config() + memory = PermissionMemory(project_path=tmp_path) + memory.add_session('custom---tool') + enforcer = PermissionEnforcer( + config=config, + handler=MockDenyHandler(), + memory=memory, + ) + r = await enforcer.check('custom---tool', {}) + assert r.action == 'allow' + + @pytest.mark.asyncio + async def test_persistent_memory(self, tmp_path): + config = _interactive_config() + memory = PermissionMemory(project_path=tmp_path) + memory.add('custom---tool', scope='project') + enforcer = PermissionEnforcer( + config=config, + handler=MockDenyHandler(), + memory=memory, + ) + r = await enforcer.check('custom---tool', {}) + assert r.action == 'allow' + + @pytest.mark.asyncio + async def test_allow_always_persists(self, tmp_path): + config = _interactive_config() + memory = PermissionMemory(project_path=tmp_path) + enforcer = PermissionEnforcer( + config=config, + handler=MockAlwaysHandler(), + memory=memory, + ) + + r = await enforcer.check('new---tool', {}) + assert r.action == 'allow' + + # Second call should match from memory + enforcer2 = PermissionEnforcer( + config=config, + handler=MockDenyHandler(), + memory=memory, + ) + r2 = await enforcer2.check('new---tool', {}) + assert r2.action == 'allow' + + +class TestModifyAction: + @pytest.mark.asyncio + async def test_modify_returns_updated_args(self, tmp_path): + class MockModifyHandler: + async def ask(self, tool_name, tool_args, context, suggestions=None): + return PermissionResponse( + action=PermissionAction.MODIFY, + updated_args={'command': 'ls -la'}, + ) + + config = _interactive_config() + enforcer = PermissionEnforcer( + config=config, + handler=MockModifyHandler(), + memory=PermissionMemory(project_path=tmp_path), + ) + r = await enforcer.check('code_executor---shell_executor', {'command': 'rm -rf /'}) + assert r.action == 'allow' + assert r.updated_args == {'command': 'ls -la'} diff --git a/tests/permission/test_matcher.py b/tests/permission/test_matcher.py new file mode 100644 index 000000000..6e8e64bf6 --- /dev/null +++ b/tests/permission/test_matcher.py @@ -0,0 +1,84 @@ +"""Tests for PermissionMatcher.""" + +import pytest + +from ms_agent.permission.matcher import PermissionMatcher + + +@pytest.fixture +def matcher(): + return PermissionMatcher() + + +class TestMatch: + def test_exact_match(self, matcher): + assert matcher.match('file_system---read_file', 'file_system---read_file') + + def test_wildcard_star(self, matcher): + assert matcher.match('file_system---*', 'file_system---read_file') + assert matcher.match('*---read_file', 'file_system---read_file') + assert matcher.match('*', 'anything') + + def test_wildcard_question(self, matcher): + assert matcher.match('file_system---read_fil?', 'file_system---read_file') + assert not matcher.match('file_system---read_fil?', 'file_system---read_files') + + def test_no_match(self, matcher): + assert not matcher.match('file_system---write_file', 'file_system---read_file') + + def test_pipe_alternatives(self, matcher): + assert matcher.match('read_file|write_file', 'read_file') + assert matcher.match('read_file|write_file', 'write_file') + assert not matcher.match('read_file|write_file', 'edit_file') + + def test_pipe_with_wildcards(self, matcher): + assert matcher.match('file_system---*|web_search---*', 'web_search---fetch_page') + + def test_empty_pattern(self, matcher): + assert not matcher.match('', 'file_system---read_file') + + +class TestMatchWithContent: + def test_tool_name_only(self, matcher): + assert matcher.match_with_content( + 'file_system---read_file', + 'file_system---read_file', + {'path': '/tmp/test'}, + ) + + def test_content_pattern(self, matcher): + assert matcher.match_with_content( + 'code_executor---shell_executor:pip *', + 'code_executor---shell_executor', + {'command': 'pip install requests'}, + ) + + def test_content_no_match(self, matcher): + assert not matcher.match_with_content( + 'code_executor---shell_executor:npm *', + 'code_executor---shell_executor', + {'command': 'pip install requests'}, + ) + + def test_content_pattern_with_wildcard_tool(self, matcher): + assert matcher.match_with_content( + '*---shell_executor:ls *', + 'code_executor---shell_executor', + {'command': 'ls -la'}, + ) + + def test_no_content_available(self, matcher): + assert not matcher.match_with_content( + 'unknown---tool:pattern', + 'unknown---tool', + {'some_arg': 'value'}, + ) + + def test_non_string_content_is_coerced(self, matcher): + # Non-string args must not crash fnmatch (TypeError). + result = matcher.match_with_content( + 'file_system---read_file:/tmp/*', + 'file_system---read_file', + {'path': ['/tmp/a', '/tmp/b']}, + ) + assert isinstance(result, bool) diff --git a/tests/permission/test_memory.py b/tests/permission/test_memory.py new file mode 100644 index 000000000..d2af175a2 --- /dev/null +++ b/tests/permission/test_memory.py @@ -0,0 +1,105 @@ +"""Tests for PermissionMemory.""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from ms_agent.permission.memory import PermissionMemory + + +@pytest.fixture +def memory(tmp_path): + project_path = tmp_path / 'project' + project_path.mkdir() + global_path = tmp_path / 'global' / 'permission_memory.json' + return PermissionMemory(project_path=project_path, global_path=global_path) + + +class TestAdd: + def test_add_project(self, memory, tmp_path): + memory.add('file_system---read_*', scope='project') + entries = memory.list_all() + assert len(entries) == 1 + assert entries[0].pattern == 'file_system---read_*' + assert entries[0].scope == 'project' + + def test_add_global(self, memory): + memory.add('web_search---*', scope='global') + entries = memory.list_all() + assert len(entries) == 1 + assert entries[0].scope == 'global' + + +class TestMatches: + def test_match_persistent(self, memory): + memory.add('file_system---read_*', scope='project') + assert memory.matches('file_system---read_file', {}) + assert not memory.matches('file_system---write_file', {}) + + def test_match_session(self, memory): + memory.add_session('code_executor---shell_executor:ls *') + assert memory.matches('code_executor---shell_executor', {'command': 'ls -la'}) + assert not memory.matches('code_executor---shell_executor', {'command': 'rm file'}) + + def test_match_content_pattern(self, memory): + memory.add('code_executor---shell_executor:pip *', scope='project') + assert memory.matches('code_executor---shell_executor', {'command': 'pip install requests'}) + assert not memory.matches('code_executor---shell_executor', {'command': 'npm install'}) + + +class TestRevoke: + def test_revoke(self, memory): + memory.add('file_system---*', scope='project') + assert memory.matches('file_system---read_file', {}) + count = memory.revoke('file_system---*') + assert count == 1 + assert not memory.matches('file_system---read_file', {}) + + def test_revoke_nonexistent(self, memory): + count = memory.revoke('nonexistent') + assert count == 0 + + +class TestPersistence: + def test_reload(self, tmp_path): + project_path = tmp_path / 'project' + project_path.mkdir() + global_path = tmp_path / 'global' / 'permission_memory.json' + + mem1 = PermissionMemory(project_path=project_path, global_path=global_path) + mem1.add('file_system---*', scope='project') + mem1.add('web_search---*', scope='global') + + mem2 = PermissionMemory(project_path=project_path, global_path=global_path) + assert mem2.matches('file_system---read_file', {}) + assert mem2.matches('web_search---fetch_page', {}) + + def test_session_not_persisted(self, tmp_path): + project_path = tmp_path / 'project' + project_path.mkdir() + + mem1 = PermissionMemory(project_path=project_path) + mem1.add_session('temp_pattern') + + mem2 = PermissionMemory(project_path=project_path) + assert not mem2.matches('temp_pattern', {}) + + +class TestEdgeCases: + def test_no_project_path(self, tmp_path): + global_path = tmp_path / 'global' / 'permission_memory.json' + mem = PermissionMemory(project_path=None, global_path=global_path) + mem.add('test', scope='global') + assert mem.matches('test', {}) + + def test_corrupt_file(self, tmp_path): + project_path = tmp_path / 'project' + project_path.mkdir() + mem_file = project_path / '.ms_agent' / 'permission_memory.json' + mem_file.parent.mkdir(parents=True) + mem_file.write_text('not json') + + mem = PermissionMemory(project_path=project_path) + assert mem.list_all() == [] diff --git a/tests/permission/test_path_extractors.py b/tests/permission/test_path_extractors.py new file mode 100644 index 000000000..7f3b58885 --- /dev/null +++ b/tests/permission/test_path_extractors.py @@ -0,0 +1,160 @@ +"""Tests for PATH_EXTRACTORS registry.""" + +import pytest + +from ms_agent.permission.path_extractors import ( + build_extractor_registry, + extract_cd, + extract_find, + extract_git, + extract_jq, + extract_sed, + extract_tr, + filter_out_flags, + parse_pattern_command, +) + + +class TestFilterOutFlags: + def test_basic(self): + assert filter_out_flags(['-la', 'file1', 'file2']) == ['file1', 'file2'] + + def test_double_dash(self): + assert filter_out_flags(['--', '-file']) == ['-file'] + + def test_mixed(self): + assert filter_out_flags(['-r', 'src', '--force', '--', '-tricky']) == ['src', '-tricky'] + + def test_no_flags(self): + assert filter_out_flags(['a', 'b', 'c']) == ['a', 'b', 'c'] + + def test_empty(self): + assert filter_out_flags([]) == [] + + +class TestExtractCd: + def test_no_args(self): + import os + assert extract_cd([]) == [os.path.expanduser('~')] + + def test_single_dir(self): + assert extract_cd(['/tmp']) == ['/tmp'] + + def test_space_in_dir(self): + assert extract_cd(['/my', 'dir']) == ['/my dir'] + + +class TestExtractFind: + def test_basic(self): + assert extract_find(['.', '-name', '*.py']) == ['.'] + + def test_multiple_paths(self): + assert extract_find(['dir1', 'dir2', '-name', '*.py']) == ['dir1', 'dir2'] + + def test_no_args(self): + assert extract_find([]) == ['.'] + + def test_path_flag(self): + paths = extract_find(['.', '-newer', '/ref/file', '-name', '*.txt']) + assert '/ref/file' in paths + + def test_global_options(self): + assert extract_find(['-L', '/src', '-name', '*.py']) == ['/src'] + + +class TestParsePatternCommand: + def test_grep_basic(self): + flags = {'-e', '--regexp', '-f', '--file', '-A', '-B', '-C', '-m', '--max-count'} + paths = parse_pattern_command(['pattern', 'file1', 'file2'], flags) + assert paths == ['file1', 'file2'] + + def test_grep_with_e_flag(self): + flags = {'-e', '--regexp'} + paths = parse_pattern_command(['-e', 'pattern', 'file1'], flags) + assert paths == ['file1'] + + def test_grep_no_files(self): + flags = {'-e', '--regexp'} + paths = parse_pattern_command(['pattern'], flags, defaults=['.']) + assert paths == ['.'] + + +class TestExtractSed: + def test_inline_expression(self): + assert extract_sed(['s/a/b/', 'file.txt']) == ['file.txt'] + + def test_e_flag(self): + assert extract_sed(['-e', 's/a/b/', 'file.txt']) == ['file.txt'] + + def test_f_flag(self): + paths = extract_sed(['-f', 'script.sed', 'file.txt']) + assert 'script.sed' in paths + assert 'file.txt' in paths + + def test_multiple_expressions(self): + assert extract_sed(['-e', 's/a/b/', '-e', 's/c/d/', 'file.txt']) == ['file.txt'] + + +class TestExtractJq: + def test_basic(self): + assert extract_jq(['.data', 'file.json']) == ['file.json'] + + def test_no_files(self): + assert extract_jq(['.data']) == [] + + def test_with_flags(self): + assert extract_jq(['-r', '.data', 'file.json']) == ['file.json'] + + +class TestExtractGit: + def test_diff_no_index(self): + assert extract_git(['diff', '--no-index', 'a.txt', 'b.txt']) == ['a.txt', 'b.txt'] + + def test_other_subcommand(self): + assert extract_git(['status']) == [] + + def test_regular_diff(self): + assert extract_git(['diff', 'HEAD']) == [] + + +class TestExtractTr: + def test_basic(self): + assert extract_tr(['a-z', 'A-Z']) == [] + + def test_with_delete(self): + assert extract_tr(['-d', 'set1']) == [] + + def test_with_file(self): + assert extract_tr(['a-z', 'A-Z', 'file.txt']) == ['file.txt'] + + +class TestRegistry: + def test_registry_size(self): + reg = build_extractor_registry() + assert len(reg) >= 34 + + def test_all_commands_have_op_type(self): + reg = build_extractor_registry() + for cmd, entry in reg.items(): + assert entry.op_type in ('read', 'write', 'create'), f'{cmd} has invalid op_type' + + def test_rm_is_write(self): + reg = build_extractor_registry() + assert reg['rm'].op_type == 'write' + + def test_cat_is_read(self): + reg = build_extractor_registry() + assert reg['cat'].op_type == 'read' + + def test_mkdir_is_create(self): + reg = build_extractor_registry() + assert reg['mkdir'].op_type == 'create' + + def test_mv_has_validator(self): + reg = build_extractor_registry() + assert reg['mv'].command_validator is not None + assert reg['mv'].command_validator(['-t', '/dst', 'file']) is not None # should warn + + def test_mv_no_flags_ok(self): + reg = build_extractor_registry() + assert reg['mv'].command_validator(['src', 'dst']) is None diff --git a/tests/permission/test_path_validator.py b/tests/permission/test_path_validator.py new file mode 100644 index 000000000..f92a605b3 --- /dev/null +++ b/tests/permission/test_path_validator.py @@ -0,0 +1,191 @@ +"""Tests for path validation core.""" + +import os +import tempfile + +import pytest + +from ms_agent.permission.path_validator import ( + get_glob_base_directory, + is_dangerous_removal_path, + validate_path, +) + + +class TestValidatePath: + def test_relative_path_within_allowed(self): + with tempfile.TemporaryDirectory() as td: + r = validate_path('test.txt', td, [td], 'read') + assert r.allowed + assert r.action == 'allow' + + def test_absolute_path_within_allowed(self): + with tempfile.TemporaryDirectory() as td: + r = validate_path(os.path.join(td, 'test.txt'), td, [td], 'write') + assert r.allowed + + def test_path_outside_allowed_write(self): + r = validate_path('/etc/passwd', '/tmp', ['/tmp'], 'write') + assert not r.allowed + assert r.action == 'deny' + + def test_path_outside_allowed_read(self): + r = validate_path('/etc/passwd', '/tmp', ['/tmp'], 'read') + assert not r.allowed + assert r.action == 'ask' # read outside → ask, not deny + + def test_tilde_expansion(self): + home = os.path.expanduser('~') + r = validate_path('~/test.txt', '/tmp', [home], 'read') + assert r.allowed + + def test_tilde_user_rejected(self): + r = validate_path('~otheruser/file', '/tmp', ['/tmp'], 'read') + assert not r.allowed + assert 'Unsupported tilde expansion' in r.reason + + def test_tilde_plus_rejected(self): + r = validate_path('~+/file', '/tmp', ['/tmp'], 'read') + assert not r.allowed + + def test_shell_variable_rejected(self): + r = validate_path('$HOME/file', '/tmp', ['/tmp'], 'write') + assert not r.allowed + assert 'variable expansion' in r.reason + + def test_windows_variable_rejected(self): + r = validate_path('%TEMP%/file', '/tmp', ['/tmp'], 'write') + assert not r.allowed + + def test_zsh_equals_rejected(self): + r = validate_path('=ls', '/tmp', ['/tmp'], 'read') + assert not r.allowed + + def test_glob_in_write_rejected(self): + r = validate_path('*.txt', '/tmp', ['/tmp'], 'write') + assert not r.allowed + assert r.action == 'deny' + + def test_glob_in_read_uses_base_dir(self): + with tempfile.TemporaryDirectory() as td: + r = validate_path(os.path.join(td, '*.txt'), td, [td], 'read') + assert r.allowed + + def test_quoted_path(self): + with tempfile.TemporaryDirectory() as td: + r = validate_path(f'"{td}/test.txt"', td, [td], 'write') + assert r.allowed + + def test_multiple_allowed_dirs(self): + with tempfile.TemporaryDirectory() as td1: + with tempfile.TemporaryDirectory() as td2: + r = validate_path(os.path.join(td2, 'f'), td1, [td1, td2], 'write') + assert r.allowed + + +class TestReadOnlyDirectories: + def test_read_allowed_via_read_only_dir(self): + with tempfile.TemporaryDirectory() as write_dir: + with tempfile.TemporaryDirectory() as ro_dir: + r = validate_path( + os.path.join(ro_dir, 'data.csv'), write_dir, + [write_dir], 'read', read_only_dirs=[ro_dir], + ) + assert r.allowed + assert r.action == 'allow' + + def test_write_denied_in_read_only_dir(self): + with tempfile.TemporaryDirectory() as write_dir: + with tempfile.TemporaryDirectory() as ro_dir: + r = validate_path( + os.path.join(ro_dir, 'data.csv'), write_dir, + [write_dir], 'write', read_only_dirs=[ro_dir], + ) + assert not r.allowed + assert r.action == 'deny' + + def test_create_denied_in_read_only_dir(self): + with tempfile.TemporaryDirectory() as write_dir: + with tempfile.TemporaryDirectory() as ro_dir: + r = validate_path( + os.path.join(ro_dir, 'new.txt'), write_dir, + [write_dir], 'create', read_only_dirs=[ro_dir], + ) + assert not r.allowed + assert r.action == 'deny' + + def test_read_outside_both_dirs_returns_ask(self): + with tempfile.TemporaryDirectory() as write_dir: + with tempfile.TemporaryDirectory() as ro_dir: + r = validate_path( + '/etc/passwd', write_dir, + [write_dir], 'read', read_only_dirs=[ro_dir], + ) + assert not r.allowed + assert r.action == 'ask' + assert r.category == 'read_outside_dirs' + + def test_allowed_dir_takes_precedence_over_read_only(self): + with tempfile.TemporaryDirectory() as td: + r = validate_path( + os.path.join(td, 'file.txt'), td, + [td], 'read', read_only_dirs=[td], + ) + assert r.allowed + + def test_write_in_allowed_dir_still_works(self): + with tempfile.TemporaryDirectory() as td: + with tempfile.TemporaryDirectory() as ro_dir: + r = validate_path( + os.path.join(td, 'file.txt'), td, + [td], 'write', read_only_dirs=[ro_dir], + ) + assert r.allowed + + +class TestIsDangerousRemovalPath: + @pytest.mark.parametrize('path', [ + '/', + '*', + '/tmp/*', + '/usr', + os.path.expanduser('~'), + ]) + def test_dangerous_paths(self, path): + assert is_dangerous_removal_path(path) + + @pytest.mark.parametrize('path', [ + '/tmp/mydir', + '/usr/local/bin', + 'relative/path', + './test.txt', + ]) + def test_safe_paths(self, path): + assert not is_dangerous_removal_path(path) + + def test_windows_drive_root(self): + assert is_dangerous_removal_path('C:/') + assert is_dangerous_removal_path('D:\\') + + def test_windows_drive_child(self): + assert is_dangerous_removal_path('C:/Windows') + + def test_normalized_slashes(self): + assert is_dangerous_removal_path('///') + + +class TestGetGlobBaseDirectory: + def test_no_glob(self): + assert get_glob_base_directory('/tmp/test.txt') == '/tmp' + + def test_glob_at_end(self): + assert get_glob_base_directory('/tmp/*.txt') == '/tmp' + + def test_glob_in_middle(self): + assert get_glob_base_directory('/tmp/*/test.txt') == '/tmp' + + def test_relative_glob(self): + assert get_glob_base_directory('*.py') == '.' + + def test_root_glob(self): + assert get_glob_base_directory('/*') == '/' diff --git a/tests/permission/test_safety.py b/tests/permission/test_safety.py new file mode 100644 index 000000000..4970fa37c --- /dev/null +++ b/tests/permission/test_safety.py @@ -0,0 +1,305 @@ +"""Tests for SafetyGuard.""" + +import os +import tempfile + +import pytest + +from ms_agent.permission.config import SafetyConfig +from ms_agent.permission.safety import SafetyGuard + + +@pytest.fixture +def guard(tmp_path): + config = SafetyConfig() + return SafetyGuard(config=config, allowed_dirs=[str(tmp_path)]) + + +class TestSafetyRules: + def test_rm_rf_blocked(self, guard): + r = guard.check( + 'code_executor---shell_executor', + {'command': 'rm -rf /'}, + ) + assert r.action == 'deny' + + def test_mkfs_blocked(self, guard): + r = guard.check( + 'code_executor---shell_executor', + {'command': 'mkfs /dev/sda'}, + ) + assert r.action == 'deny' + + def test_dd_blocked(self, guard): + r = guard.check( + 'code_executor---shell_executor', + {'command': 'dd if=/dev/zero of=/dev/sda'}, + ) + assert r.action == 'deny' + + +class TestFilePathChecks: + def test_read_within_allowed(self, guard, tmp_path): + r = guard.check( + 'file_system---read_file', + {'path': str(tmp_path / 'test.txt')}, + ) + assert r.action == 'allow' + + def test_write_outside_allowed(self, guard): + r = guard.check( + 'file_system---write_file', + {'path': '/etc/passwd'}, + ) + assert r.action == 'deny' + + def test_edit_within_allowed(self, guard, tmp_path): + r = guard.check( + 'file_system---edit_file', + {'path': str(tmp_path / 'test.py')}, + ) + assert r.action == 'allow' + + def test_empty_path(self, guard): + r = guard.check('file_system---write_file', {'path': ''}) + assert r.action == 'deny' + + +class TestSensitivePaths: + def test_write_git_config(self, guard): + r = guard.check( + 'file_system---write_file', + {'path': '.git/config'}, + ) + assert r.action == 'deny' + + def test_write_ssh_key(self, guard): + home = os.path.expanduser('~') + r = guard.check( + 'file_system---write_file', + {'path': f'{home}/.ssh/id_rsa'}, + ) + assert r.action == 'deny' + + +class TestShellCommands: + def test_safe_command(self, guard, tmp_path): + r = guard.check( + 'code_executor---shell_executor', + {'command': f'ls {tmp_path}'}, + ) + assert r.action == 'allow' + + def test_empty_command(self, guard): + r = guard.check( + 'code_executor---shell_executor', + {'command': ''}, + ) + assert r.action == 'deny' + + +class TestUnknownTool: + def test_passthrough(self, guard): + r = guard.check('unknown---tool', {'arg': 'value'}) + assert r.action == 'allow' + + +class TestCustomConfig: + def test_custom_patterns_tool_level(self, tmp_path): + config = SafetyConfig( + patterns=('custom---dangerous_tool',), + ) + guard = SafetyGuard(config=config, allowed_dirs=[str(tmp_path)]) + r = guard.check('custom---dangerous_tool', {'arg': 'value'}) + assert r.action == 'deny' + + def test_custom_patterns_shell(self, tmp_path): + config = SafetyConfig( + patterns=('code_executor---shell_executor:curl *',), + ) + guard = SafetyGuard(config=config, allowed_dirs=[str(tmp_path)]) + r = guard.check( + 'code_executor---shell_executor', + {'command': 'curl https://evil.com'}, + ) + assert r.action == 'deny' + + +class TestProcessSubstitution: + """Process substitution split into input/output categories.""" + + def test_input_sub_category(self, guard): + r = guard.check( + 'code_executor---shell_executor', + {'command': 'diff <(sort a.txt) <(sort b.txt)'}, + ) + assert r.action == 'ask' + assert r.category == 'process_input_sub' + + def test_output_sub_category(self, guard): + r = guard.check( + 'code_executor---shell_executor', + {'command': 'echo secret > >(tee log.txt)'}, + ) + assert r.action == 'ask' + assert r.category == 'process_output_sub' + + def test_output_sub_takes_precedence(self, guard): + r = guard.check( + 'code_executor---shell_executor', + {'command': 'cat <(echo a) > >(tee b)'}, + ) + assert r.category == 'process_output_sub' + + +class TestReadOnlyDirs: + """SafetyGuard respects read_only_directories.""" + + def test_read_allowed_in_read_only_dir(self, tmp_path): + ro_dir = tmp_path / 'readonly' + ro_dir.mkdir() + config = SafetyConfig() + guard = SafetyGuard(config=config, allowed_dirs=[str(tmp_path)], read_only_dirs=[str(ro_dir)]) + # ro_dir is under tmp_path anyway; use a truly separate dir + import tempfile + with tempfile.TemporaryDirectory() as separate_ro: + guard2 = SafetyGuard(config=config, allowed_dirs=[str(tmp_path)], read_only_dirs=[separate_ro]) + r = guard2.check('file_system---read_file', {'path': f'{separate_ro}/data.csv'}) + assert r.action == 'allow' + + def test_write_denied_in_read_only_dir(self, tmp_path): + import tempfile + with tempfile.TemporaryDirectory() as separate_ro: + config = SafetyConfig() + guard = SafetyGuard(config=config, allowed_dirs=[str(tmp_path)], read_only_dirs=[separate_ro]) + r = guard.check('file_system---write_file', {'path': f'{separate_ro}/data.csv'}) + assert r.action == 'deny' + + def test_shell_read_in_read_only_dir(self, tmp_path): + import tempfile + with tempfile.TemporaryDirectory() as separate_ro: + config = SafetyConfig() + guard = SafetyGuard(config=config, allowed_dirs=[str(tmp_path)], read_only_dirs=[separate_ro]) + r = guard.check('code_executor---shell_executor', {'command': f'cat {separate_ro}/data.csv'}) + assert r.action == 'allow' + + def test_shell_write_in_read_only_dir(self, tmp_path): + import tempfile + with tempfile.TemporaryDirectory() as separate_ro: + config = SafetyConfig() + guard = SafetyGuard(config=config, allowed_dirs=[str(tmp_path)], read_only_dirs=[separate_ro]) + r = guard.check('code_executor---shell_executor', {'command': f'rm {separate_ro}/data.csv'}) + assert r.action == 'deny' + + +class TestCategoryPropagation: + """SafetyDecision carries category from validators.""" + + def test_parse_failure_category(self, guard): + r = guard.check( + 'code_executor---shell_executor', + {'command': "echo 'unterminated"}, + ) + assert r.action == 'ask' + assert r.category == 'parse_failure' + + def test_shell_expansion_category(self, guard): + r = guard.check( + 'file_system---read_file', + {'path': '$HOME/secrets.txt'}, + ) + assert r.action == 'ask' + assert r.category == 'shell_expansion' + + +class TestConfigParsing: + """SafetyConfig.from_dict parses read_only_directories.""" + + def test_read_only_directories_parsed(self, tmp_path): + d = {'read_only_directories': [str(tmp_path / 'data')]} + config = SafetyConfig.from_dict(d) + assert config.read_only_directories == (str(tmp_path / 'data'),) + + def test_read_only_directories_project_root(self): + d = {'read_only_directories': ['${PROJECT_ROOT}']} + config = SafetyConfig.from_dict(d, project_root='/my/project') + assert config.read_only_directories == ('/my/project',) + + def test_read_only_directories_default_empty(self): + config = SafetyConfig.from_dict({}) + assert config.read_only_directories == () + + def test_write_policy_removed(self): + assert not hasattr(SafetyConfig, 'write_policy') or 'write_policy' not in SafetyConfig.__dataclass_fields__ + + +class TestGrepGlobCoverage: + """SafetyGuard checks grep and glob path arguments.""" + + def test_grep_within_allowed(self, guard, tmp_path): + r = guard.check('file_system---grep', {'path': str(tmp_path / 'src')}) + assert r.action == 'allow' + + def test_grep_outside_allowed(self, guard): + r = guard.check('file_system---grep', {'path': '/etc'}) + assert r.action in ('deny', 'ask') + + def test_glob_within_allowed(self, guard, tmp_path): + r = guard.check('file_system---glob', {'path': str(tmp_path)}) + assert r.action == 'allow' + + def test_glob_outside_allowed(self, guard): + r = guard.check('file_system---glob', {'path': '/etc'}) + assert r.action in ('deny', 'ask') + + def test_grep_default_path(self, guard): + r = guard.check('file_system---grep', {'pattern': 'foo'}) + assert r.action in ('allow', 'ask') + + def test_glob_default_path(self, guard): + r = guard.check('file_system---glob', {'pattern': '*.py'}) + assert r.action in ('allow', 'ask') + + +class TestWorkspaceRoot: + """SafetyGuard respects workspace_root for relative path resolution.""" + + def test_relative_path_resolved_via_workspace_root(self, tmp_path): + config = SafetyConfig() + guard = SafetyGuard( + config=config, + allowed_dirs=[str(tmp_path)], + workspace_root=str(tmp_path), + ) + r = guard.check('file_system---read_file', {'path': 'test.txt'}) + assert r.action == 'allow' + + def test_relative_path_outside_workspace_root(self): + config = SafetyConfig() + guard = SafetyGuard( + config=config, + allowed_dirs=['/some/dir'], + workspace_root='/some/dir', + ) + r = guard.check('file_system---write_file', {'path': '../../etc/passwd'}) + assert r.action == 'deny' + + +class TestDefaultBlacklist: + """PermissionConfig includes default network command blacklist.""" + + def test_default_blacklist_contains_curl(self): + from ms_agent.permission.config import PermissionConfig + config = PermissionConfig() + assert any('curl' in p for p in config.blacklist) + + def test_default_blacklist_contains_wget(self): + from ms_agent.permission.config import PermissionConfig + config = PermissionConfig() + assert any('wget' in p for p in config.blacklist) + + def test_user_blacklist_merged(self): + from ms_agent.permission.config import PermissionConfig + config = PermissionConfig.from_dict({'blacklist': ['custom---tool']}) + assert any('curl' in p for p in config.blacklist) + assert 'custom---tool' in config.blacklist diff --git a/tests/permission/test_security_regression.py b/tests/permission/test_security_regression.py new file mode 100644 index 000000000..bde0dbb66 --- /dev/null +++ b/tests/permission/test_security_regression.py @@ -0,0 +1,111 @@ +"""Security regression tests — attack vectors from design doc Section 17.3.""" + +import os +import tempfile + +import pytest + +from ms_agent.permission.config import SafetyConfig +from ms_agent.permission.safety import SafetyGuard + + +@pytest.fixture +def guard(tmp_path): + config = SafetyConfig() + return SafetyGuard(config=config, allowed_dirs=[str(tmp_path)]) + + +class TestAttackVectors: + def test_rm_rf_root(self, guard): + """rm -rf / → deny (dangerous path)""" + r = guard.check('code_executor---shell_executor', {'command': 'rm -rf /'}) + assert r.action == 'deny' + + def test_timeout_rm_rf_root(self, guard): + """timeout 10 rm -rf / → deny (wrapper stripped, then dangerous path)""" + r = guard.check('code_executor---shell_executor', {'command': 'timeout 10 rm -rf /'}) + assert r.action == 'deny' + + def test_rm_double_dash_tricky(self, guard): + """rm -- -/../.claude/settings.json → deny (path outside allowed dirs)""" + r = guard.check( + 'code_executor---shell_executor', + {'command': 'rm -- -/../.claude/settings.json'}, + ) + assert r.action in ('deny', 'ask') + + def test_redirect_to_etc(self, guard): + """echo "x" > /etc/passwd → ask/deny (redirect to sensitive path)""" + r = guard.check( + 'code_executor---shell_executor', + {'command': 'echo "x" > /etc/passwd'}, + ) + assert r.action in ('deny', 'ask') + + def test_cd_plus_mv(self, guard, tmp_path): + """cd dir && mv a b → ask (cd + write compound)""" + r = guard.check( + 'code_executor---shell_executor', + {'command': f'cd {tmp_path} && mv {tmp_path}/a {tmp_path}/b'}, + ) + assert r.action == 'ask' + + def test_rm_dollar_home(self, guard): + """rm $HOME/.ssh/* → ask/deny (shell expansion in path)""" + r = guard.check( + 'code_executor---shell_executor', + {'command': 'rm $HOME/.ssh/*'}, + ) + assert r.action in ('deny', 'ask') + + def test_env_home_override(self, guard): + """env HOME=/tmp rm -rf ~ → HOME is unsafe, not stripped""" + r = guard.check( + 'code_executor---shell_executor', + {'command': 'env HOME=/tmp rm -rf ~'}, + ) + assert r.action == 'deny' + + def test_process_substitution(self, guard): + """echo secret > >(tee .git/config) → ask (process substitution)""" + r = guard.check( + 'code_executor---shell_executor', + {'command': 'echo secret > >(tee .git/config)'}, + ) + assert r.action == 'ask' + + def test_mv_target_directory(self, guard, tmp_path): + """mv --target-directory=/etc test.txt → ask (command validator)""" + r = guard.check( + 'code_executor---shell_executor', + {'command': f'mv --target-directory=/etc {tmp_path}/test.txt'}, + ) + assert r.action == 'ask' + + def test_sed_write_expression(self, guard, tmp_path): + """sed -e 's/x/y/w /etc/passwd' file → deny (sed expression safety)""" + r = guard.check( + 'code_executor---shell_executor', + {'command': f"sed -e 's/x/y/w /etc/passwd' {tmp_path}/file"}, + ) + assert r.action == 'deny' + + +class TestSensitivePathWrites: + def test_write_etc(self, guard): + r = guard.check('file_system---write_file', {'path': '/etc/hosts'}) + assert r.action == 'deny' + + def test_write_ssh(self, guard): + home = os.path.expanduser('~') + r = guard.check('file_system---write_file', {'path': f'{home}/.ssh/authorized_keys'}) + assert r.action == 'deny' + + def test_write_git_hooks(self, guard): + r = guard.check('file_system---write_file', {'path': '.git/hooks/pre-commit'}) + assert r.action == 'deny' + + def test_write_bashrc(self, guard): + home = os.path.expanduser('~') + r = guard.check('file_system---write_file', {'path': f'{home}/.bashrc'}) + assert r.action == 'deny' diff --git a/tests/permission/test_sed_validator.py b/tests/permission/test_sed_validator.py new file mode 100644 index 000000000..024cf3fe1 --- /dev/null +++ b/tests/permission/test_sed_validator.py @@ -0,0 +1,100 @@ +"""Tests for sed expression safety validator.""" + +import pytest + +from ms_agent.permission.sed_validator import ( + check_sed_expression_safety, + is_sed_read_only, +) + + +class TestIsSedReadOnly: + def test_print_only(self): + assert is_sed_read_only(['-n', 'p']) + + def test_address_print(self): + assert is_sed_read_only(['-n', '1,5p']) + + def test_no_n_flag(self): + assert not is_sed_read_only(['p']) + + def test_with_in_place(self): + assert not is_sed_read_only(['-n', '-i', 'p']) + + def test_substitution(self): + assert not is_sed_read_only(['-n', 's/a/b/']) + + def test_e_flag_print(self): + assert is_sed_read_only(['-n', '-e', 'p']) + + +class TestCheckSedExpressionSafety: + def test_safe_expression(self): + r = check_sed_expression_safety('s/foo/bar/') + assert r.safe + + def test_write_command(self): + r = check_sed_expression_safety('s/foo/bar/w /tmp/out') + assert not r.safe + assert 'w' in r.reason.lower() or 'Write' in r.reason + + def test_execute_command(self): + r = check_sed_expression_safety('s/foo/bar/e') + assert not r.safe + + def test_non_ascii(self): + r = check_sed_expression_safety('s/foö/bar/') + assert not r.safe + assert 'Non-ASCII' in r.reason + + def test_newline(self): + r = check_sed_expression_safety('s/foo/bar/\n') + assert not r.safe + + def test_curly_braces(self): + r = check_sed_expression_safety('{s/foo/bar/}') + assert not r.safe + + def test_negation(self): + r = check_sed_expression_safety('!d') + assert not r.safe + + def test_empty(self): + r = check_sed_expression_safety('') + assert r.safe + + +class TestArbitraryDelimiter: + """Substitution flag detection must work with any delimiter, not just '/'.""" + + def test_pipe_delimiter_write(self): + r = check_sed_expression_safety('s|foo|bar|w /tmp/out') + assert not r.safe + + def test_hash_delimiter_exec(self): + r = check_sed_expression_safety('s#foo#bar#e') + assert not r.safe + + def test_at_delimiter_gw(self): + r = check_sed_expression_safety('s@pat@rep@gw file') + assert not r.safe + + def test_pipe_delimiter_safe(self): + r = check_sed_expression_safety('s|foo|bar|g') + assert r.safe + + def test_escaped_delimiter_in_pattern(self): + r = check_sed_expression_safety('s/foo\\/bar/baz/w file') + assert not r.safe + + def test_escaped_delimiter_safe(self): + r = check_sed_expression_safety('s/foo\\/bar/baz/g') + assert r.safe + + def test_semicolon_chained(self): + r = check_sed_expression_safety('s|a|b|g;s|c|d|e') + assert not r.safe + + def test_semicolon_chained_safe(self): + r = check_sed_expression_safety('s|a|b|g;s|c|d|g') + assert r.safe diff --git a/tests/permission/test_shell_validator.py b/tests/permission/test_shell_validator.py new file mode 100644 index 000000000..2d775d0ca --- /dev/null +++ b/tests/permission/test_shell_validator.py @@ -0,0 +1,129 @@ +"""Tests for ShellPathValidator pipeline.""" + +import os +import tempfile + +import pytest + +from ms_agent.permission.shell_validator import ShellPathValidator + + +@pytest.fixture +def validator(tmp_path): + return ShellPathValidator(allowed_dirs=[str(tmp_path)]) + + +class TestBasicCommands: + def test_ls_allowed(self, validator, tmp_path): + r = validator.check(f'ls {tmp_path}') + assert r.action == 'allow' + + def test_cat_allowed(self, validator, tmp_path): + r = validator.check(f'cat {tmp_path}/test.txt') + assert r.action == 'allow' + + def test_empty_command(self, validator): + r = validator.check('') + assert r.action == 'deny' + + def test_long_command(self, validator): + r = validator.check('a' * 9000) + assert r.action == 'deny' + + +class TestDangerousCommands: + def test_rm_rf_root(self, validator): + r = validator.check('rm -rf /') + assert r.action == 'deny' + + def test_rm_star(self, validator): + r = validator.check('rm *') + assert r.action == 'deny' + + def test_rm_within_allowed(self, validator, tmp_path): + r = validator.check(f'rm {tmp_path}/test.txt') + assert r.action == 'allow' + + +class TestWrapperStripping: + def test_timeout_rm(self, validator): + r = validator.check('timeout 10 rm -rf /') + assert r.action == 'deny' + + def test_nice_rm(self, validator): + r = validator.check('nice -10 rm -rf /') + assert r.action == 'deny' + + def test_nohup_rm(self, validator): + r = validator.check('nohup rm -rf /') + assert r.action == 'deny' + + +class TestCompoundCommands: + def test_cd_plus_write(self, validator, tmp_path): + r = validator.check(f'cd {tmp_path} && rm {tmp_path}/test.txt') + assert r.action == 'ask' + + def test_multiple_safe(self, validator, tmp_path): + r = validator.check(f'ls {tmp_path} && cat {tmp_path}/f') + assert r.action == 'allow' + + +class TestRedirects: + def test_redirect_within_allowed(self, validator, tmp_path): + r = validator.check(f'echo hello > {tmp_path}/out.txt') + assert r.action == 'allow' + + def test_redirect_to_dev_null(self, validator): + r = validator.check('echo hello > /dev/null') + assert r.action == 'allow' + + def test_redirect_with_variable(self, validator): + r = validator.check('echo hello > $HOME/file') + assert r.action == 'deny' + + +class TestProcessSubstitution: + def test_output_substitution(self, validator): + r = validator.check('echo secret > >(tee .git/config)') + assert r.action == 'ask' + + def test_input_substitution(self, validator): + r = validator.check('diff <(cat a) <(cat b)') + assert r.action == 'ask' + + +class TestPathOutsideAllowed: + def test_write_outside(self, validator): + r = validator.check('touch /etc/test') + assert r.action in ('deny', 'ask') + + def test_read_outside(self, validator): + r = validator.check('cat /etc/passwd') + assert r.action == 'ask' + + +class TestShellExpansion: + def test_variable_in_path(self, validator): + r = validator.check('rm $HOME/.ssh/key') + assert r.action in ('deny', 'ask') + + def test_env_var_rm(self, validator): + r = validator.check('rm ${TMPDIR}/file') + assert r.action in ('deny', 'ask') + + +class TestMvCpValidator: + def test_mv_with_flags(self, validator, tmp_path): + r = validator.check(f'mv -t /dst {tmp_path}/file') + assert r.action == 'ask' + + def test_mv_simple(self, validator, tmp_path): + r = validator.check(f'mv {tmp_path}/a {tmp_path}/b') + assert r.action == 'allow' + + +class TestUnregisteredCommand: + def test_unknown_passthrough(self, validator): + r = validator.check('someunknowncommand arg1 arg2') + assert r.action == 'allow' diff --git a/tests/permission/test_suggestions.py b/tests/permission/test_suggestions.py new file mode 100644 index 000000000..5458cb7f3 --- /dev/null +++ b/tests/permission/test_suggestions.py @@ -0,0 +1,50 @@ +"""Tests for generate_suggestions.""" + +from ms_agent.permission.suggestions import generate_suggestions + + +class TestShellSuggestions: + def test_plain_command(self): + suggestions = generate_suggestions( + 'code_executor---shell_executor', + {'command': 'ls -la'}, + ) + assert suggestions[0] == 'code_executor---shell_executor:ls *' + assert 'code_executor---shell_executor' in suggestions + + def test_strips_timeout_wrapper(self): + suggestions = generate_suggestions( + 'code_executor---shell_executor', + {'command': 'timeout 10 ls -la'}, + ) + assert suggestions[0] == 'code_executor---shell_executor:ls *' + + def test_strips_nice_wrapper(self): + suggestions = generate_suggestions( + 'code_executor---shell_executor', + {'command': 'nice -n 10 pip install requests'}, + ) + assert suggestions[0] == 'code_executor---shell_executor:pip *' + + def test_empty_command(self): + suggestions = generate_suggestions( + 'code_executor---shell_executor', + {'command': ''}, + ) + assert suggestions == ['code_executor---shell_executor'] + + +class TestOtherTools: + def test_file_system(self): + suggestions = generate_suggestions( + 'file_system---read_file', + {'path': '/src/main.py'}, + ) + assert suggestions == ['file_system---read_file'] + + def test_web_search(self): + suggestions = generate_suggestions( + 'web_search---search', + {'query': 'test'}, + ) + assert suggestions[0] == 'web_search---*' diff --git a/tests/permission/test_wrapper_strip.py b/tests/permission/test_wrapper_strip.py new file mode 100644 index 000000000..f4e847167 --- /dev/null +++ b/tests/permission/test_wrapper_strip.py @@ -0,0 +1,69 @@ +"""Tests for safe wrapper stripping.""" + +import pytest + +from ms_agent.permission.wrapper_strip import strip_safe_wrappers + + +class TestStripSafeWrappers: + def test_timeout(self): + assert strip_safe_wrappers(['timeout', '10', 'ls', '-la']) == ['ls', '-la'] + + def test_timeout_with_flags(self): + assert strip_safe_wrappers(['timeout', '--foreground', '10', 'ls']) == ['ls'] + + def test_timeout_kill_after(self): + assert strip_safe_wrappers(['timeout', '-k', '5', '10', 'ls']) == ['ls'] + + def test_time(self): + assert strip_safe_wrappers(['time', 'ls', '-la']) == ['ls', '-la'] + + def test_nice_bare(self): + assert strip_safe_wrappers(['nice', 'ls']) == ['ls'] + + def test_nice_traditional(self): + assert strip_safe_wrappers(['nice', '-10', 'ls']) == ['ls'] + + def test_nice_posix(self): + assert strip_safe_wrappers(['nice', '-n', '10', 'ls']) == ['ls'] + + def test_nohup(self): + assert strip_safe_wrappers(['nohup', 'cat', 'file']) == ['cat', 'file'] + + def test_stdbuf(self): + assert strip_safe_wrappers(['stdbuf', '-o0', 'cat', 'file']) == ['cat', 'file'] + + def test_env_simple(self): + assert strip_safe_wrappers(['env', 'ls']) == ['ls'] + + def test_env_with_assignment(self): + assert strip_safe_wrappers(['env', 'FOO=bar', 'ls']) == ['ls'] + + def test_env_unsafe_flag_S(self): + result = strip_safe_wrappers(['env', '-S', 'something', 'ls']) + assert result == ['env', '-S', 'something', 'ls'] + + def test_env_unsafe_flag_C(self): + result = strip_safe_wrappers(['env', '-C', '/tmp', 'ls']) + assert result == ['env', '-C', '/tmp', 'ls'] + + def test_safe_env_var(self): + assert strip_safe_wrappers(['NODE_ENV=production', 'ls']) == ['ls'] + + def test_unsafe_env_var(self): + result = strip_safe_wrappers(['HOME=/tmp', 'rm', 'file']) + assert result == ['HOME=/tmp', 'rm', 'file'] + + def test_chained_wrappers(self): + result = strip_safe_wrappers(['timeout', '10', 'nice', '-5', 'ls']) + assert result == ['ls'] + + def test_env_var_then_wrapper(self): + result = strip_safe_wrappers(['NODE_ENV=test', 'timeout', '10', 'ls']) + assert result == ['ls'] + + def test_empty(self): + assert strip_safe_wrappers([]) == [] + + def test_no_wrapper(self): + assert strip_safe_wrappers(['ls', '-la']) == ['ls', '-la'] diff --git a/tests/plugins/test_installer_local.py b/tests/plugins/test_installer_local.py new file mode 100644 index 000000000..b6cd39077 --- /dev/null +++ b/tests/plugins/test_installer_local.py @@ -0,0 +1,218 @@ +import json +import os +import subprocess + +import pytest + +from ms_agent.plugins.config_manager import PluginConfigManager +from ms_agent.plugins.installer import PluginInstaller, normalize_install_source, resolve_marketplace_plugin_uri + + +def _sample_plugin(root): + (root / '.claude-plugin').mkdir(parents=True) + (root / '.claude-plugin' / 'plugin.json').write_text( + json.dumps({'name': 'local-demo', 'version': '0.1.0'}), + encoding='utf-8', + ) + skill = root / 'skills' / 'writer' + skill.mkdir(parents=True) + (skill / 'SKILL.md').write_text( + '---\nname: Writer\ndescription: Write better text.\n---\n', + encoding='utf-8', + ) + + +def test_normalize_marketplace_alias_to_github_uri(monkeypatch): + payload = { + 'plugins': [ + {'name': 'hookify', 'source': './plugins/hookify'}, + ], + } + + class FakeResponse: + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def read(self): + return json.dumps(payload).encode('utf-8') + + monkeypatch.setattr( + 'ms_agent.plugins.installer.urlopen', + lambda url, timeout=30: FakeResponse(), + ) + + assert normalize_install_source('hookify@claude-plugins-official') == ( + 'github://anthropics/claude-plugins-official@main#plugins/hookify' + ) + + +def test_resolve_marketplace_plugin_uri_uses_index(monkeypatch): + payload = { + 'plugins': [ + {'name': 'hookify', 'source': './plugins/hookify'}, + ], + } + + class FakeResponse: + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def read(self): + return json.dumps(payload).encode('utf-8') + + monkeypatch.setattr( + 'ms_agent.plugins.installer.urlopen', + lambda url, timeout=30: FakeResponse(), + ) + + uri = resolve_marketplace_plugin_uri('hookify', 'claude-plugins-official') + assert uri == 'github://anthropics/claude-plugins-official@main#plugins/hookify' + + +def test_install_local_copies_and_locks_manifest(tmp_path): + source = tmp_path / 'source-plugin' + _sample_plugin(source) + global_dir = tmp_path / '.ms_agent' + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + + manifest = installer.install(str(source), scope='global') + + record = manager.get('local-demo', scope='global') + assert manifest.plugin_id == 'local-demo' + assert record is not None + assert record.manifest_path == '.claude-plugin/plugin.json' + assert record.format == 'claude' + assert record.enabled is True + assert (global_dir / 'plugins' / 'local-demo' / 'skills' / 'writer' / 'SKILL.md').is_file() + + +def test_install_uses_manifest_default_enabled(tmp_path): + source = tmp_path / 'source-plugin' + _sample_plugin(source) + (source / '.claude-plugin' / 'plugin.json').write_text( + json.dumps({ + 'name': 'local-demo', + 'version': '0.1.0', + 'defaultEnabled': False, + }), + encoding='utf-8', + ) + global_dir = tmp_path / '.ms_agent' + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + + installer.install(str(source), scope='global') + + assert manager.get('local-demo', scope='global').enabled is False + + +def test_install_github_uri_uses_sparse_checkout(tmp_path, monkeypatch): + global_dir = tmp_path / '.ms_agent' + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + + def fake_run(cmd, check, capture_output=True, text=True): + if cmd[:3] == ['git', 'clone', '--depth']: + clone_root = cmd[-1] + plugin = tmp_path / clone_root / 'plugins' / 'local-demo' + _sample_plugin(plugin) + return subprocess.CompletedProcess(cmd, 0, stdout='abc123\n', stderr='') + + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(subprocess, 'run', fake_run) + + manifest = installer.install( + 'github://owner/repo@main#plugins/local-demo', + scope='global', + ) + + assert manifest.plugin_id == 'local-demo' + record = manager.get('local-demo', scope='global') + assert record.source.type == 'github' + assert record.source.uri == 'github://owner/repo@main#plugins/local-demo' + + +def test_install_github_uri_with_commit_sha(tmp_path, monkeypatch): + global_dir = tmp_path / '.ms_agent' + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + sha = 'a' * 40 + calls = [] + + def fake_run(cmd, check, capture_output=True, text=True): + calls.append(cmd) + if cmd[:3] == ['git', 'clone', '--depth']: + clone_root = cmd[-1] + plugin = tmp_path / clone_root / 'plugins' / 'local-demo' + _sample_plugin(plugin) + return subprocess.CompletedProcess(cmd, 0, stdout=f'{sha}\n', stderr='') + + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(subprocess, 'run', fake_run) + + manifest = installer.install( + f'github://owner/repo@{sha}#plugins/local-demo', + scope='global', + ) + + assert manifest.plugin_id == 'local-demo' + clone_dir = calls[0][-1] + assert '--branch' not in calls[0] + assert calls[1] == ['git', '-C', clone_dir, 'fetch', 'origin', sha] + assert calls[2] == ['git', '-C', clone_dir, 'checkout', sha] + + +def test_install_modelscope_uri_uses_snapshot_download(tmp_path, monkeypatch): + source = tmp_path / 'downloaded' + _sample_plugin(source) + global_dir = tmp_path / '.ms_agent' + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + + def fake_snapshot_download(model_id, revision=None): + assert model_id == 'org/pack' + assert revision == 'v1' + return str(source) + + import ms_agent.plugins.installer as installer_mod + monkeypatch.setattr(installer_mod, 'snapshot_download', fake_snapshot_download) + + manifest = installer.install('modelscope://org/pack@v1', scope='global') + + assert manifest.plugin_id == 'local-demo' + assert manager.get('local-demo', scope='global').source.type == 'modelscope' + + +def test_publish_staged_install_restores_broken_symlink_on_failure( + tmp_path, + monkeypatch, +): + target = tmp_path / '.ms_agent' / 'plugins' / 'local-demo' + staging_root = target.parent / '.staging' + staging_root.mkdir(parents=True) + target.parent.mkdir(parents=True, exist_ok=True) + target.symlink_to(tmp_path / 'missing-source', target_is_directory=True) + staged = staging_root / 'staged-local-demo' + staged.mkdir() + + original_rename = type(staged).rename + + def fail_staged_rename(self, target_path): + if self == staged: + raise RuntimeError('publish failed') + return original_rename(self, target_path) + + monkeypatch.setattr(type(staged), 'rename', fail_staged_rename) + + with pytest.raises(RuntimeError): + PluginInstaller._publish_staged_install(staged, target) + + assert target.is_symlink() + assert os.readlink(target) == str(tmp_path / 'missing-source') diff --git a/tests/plugins/test_p1_features.py b/tests/plugins/test_p1_features.py new file mode 100644 index 000000000..547b617a2 --- /dev/null +++ b/tests/plugins/test_p1_features.py @@ -0,0 +1,283 @@ +import json +import tarfile + +import pytest + +from ms_agent.hooks.executors.command import HookExecutionContext, build_hook_env +from ms_agent.plugins.agents import AgentDelegate, PluginAgentRegistry +from ms_agent.plugins.dependencies import PluginDependencyError, version_satisfies +from ms_agent.plugins.installer import ( + PluginInstaller, + normalize_install_source, + resolve_ms_agent_uri, +) +from ms_agent.plugins.loader import PluginLoadContext, PluginLoader +from ms_agent.plugins.manifest import PluginManifest +from ms_agent.plugins.registry import PluginRegistry +from ms_agent.plugins.runtime import PluginRuntime +from ms_agent.plugins.types import AgentDef +from ms_agent.plugins.user_config import save_user_config, validate_values +from ms_agent.skill.catalog import SkillCatalog +from ms_agent.skill.sources import SkillSource, SkillSourceType + + +def _basic_plugin(root, *, user_config=None, dependencies=None): + (root / '.claude-plugin').mkdir(parents=True) + manifest = { + 'name': 'p1-demo', + 'version': '1.0.0', + } + if user_config: + manifest['userConfig'] = user_config + if dependencies: + manifest['dependencies'] = dependencies + (root / '.claude-plugin' / 'plugin.json').write_text( + json.dumps(manifest), + encoding='utf-8', + ) + skill = root / 'skills' / 'writer' + skill.mkdir(parents=True) + (skill / 'SKILL.md').write_text( + '---\nname: Writer\ndescription: Write better text.\n---\n', + encoding='utf-8', + ) + + +def test_ms_agent_uri_resolves_inner_source(): + from urllib.parse import quote + inner = 'github://anthropics/claude-plugins-official@main#plugins/hookify' + uri = f'ms-agent://plugin/install?source={quote(inner, safe="")}' + assert resolve_ms_agent_uri(uri) == inner + assert normalize_install_source(uri) == inner + + +def test_tarball_install(tmp_path): + source = tmp_path / 'source-plugin' + _basic_plugin(source) + archive = tmp_path / 'plugin.tar.gz' + with tarfile.open(archive, 'w:gz') as tar: + tar.add(source, arcname='p1-demo') + + global_dir = tmp_path / '.ms_agent' + from ms_agent.plugins.config_manager import PluginConfigManager + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + manifest = installer.install(str(archive), scope='global') + assert manifest.plugin_id == 'p1-demo' + + +def test_dependencies_install_order(tmp_path): + base = tmp_path / 'base-plugin' + child = tmp_path / 'child-plugin' + _basic_plugin(base) + (base / '.claude-plugin' / 'plugin.json').write_text( + json.dumps({'name': 'base-plugin', 'version': '1.0.0'}), + encoding='utf-8', + ) + _basic_plugin(child) + (child / '.claude-plugin' / 'plugin.json').write_text( + json.dumps({ + 'name': 'child-plugin', + 'version': '1.0.0', + 'dependencies': [{ + 'name': 'base-plugin', + 'version': '~1.0.0', + 'source': str(base), + }], + }), + encoding='utf-8', + ) + + global_dir = tmp_path / '.ms_agent' + from ms_agent.plugins.config_manager import PluginConfigManager + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + installer.install(str(child), scope='global') + assert manager.get('base-plugin', scope='global') is not None + assert manager.get('child-plugin', scope='global') is not None + + +def test_missing_dependency_without_source_raises(tmp_path): + root = tmp_path / 'needs-dep' + _basic_plugin(root) + (root / '.claude-plugin' / 'plugin.json').write_text( + json.dumps({ + 'name': 'needs-dep', + 'version': '1.0.0', + 'dependencies': [{'name': 'missing-plugin', 'version': '1.0.0'}], + }), + encoding='utf-8', + ) + global_dir = tmp_path / '.ms_agent' + from ms_agent.plugins.config_manager import PluginConfigManager + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + with pytest.raises(PluginDependencyError): + installer.install(str(root), scope='global') + + +def test_version_satisfies_tilde(): + assert version_satisfies('1.0.5', '~1.0.0') + assert not version_satisfies('2.0.0', '~1.0.0') + + +def test_version_satisfies_v_prefix(): + assert version_satisfies('v1.0.5', '~v1.0.0') + assert version_satisfies('V1.0.5', 'v1.0.5') + assert not version_satisfies('v2.0.0', '~v1.0.0') + + +def test_commands_merge_into_skill_catalog(tmp_path): + root = tmp_path / 'cmd-plugin' + _basic_plugin(root) + commands = root / 'commands' + commands.mkdir() + (commands / 'deploy.md').write_text( + '---\nname: deploy\ndescription: Deploy the app.\n---\nRun deploy.', + encoding='utf-8', + ) + manifest = PluginManifest.parse(root) + result = PluginLoader.load( + manifest, + PluginLoadContext( + project_path=str(tmp_path), + session_id='s1', + enabled_executors=frozenset({'command'}), + plugin_data_root=tmp_path / 'data', + ), + ) + command_sources = [ + source for source in result.skill_sources + if source.capability == 'commands' + ] + assert len(command_sources) == 1 + catalog = SkillCatalog() + catalog.load_from_sources(command_sources) + assert 'p1-demo:deploy' in catalog.get_enabled_skills() + + +def test_agent_md_subdirectory(tmp_path): + root = tmp_path / 'agent-plugin' + _basic_plugin(root) + agent_dir = root / 'agents' / 'reviewer' + agent_dir.mkdir(parents=True) + (agent_dir / 'AGENT.md').write_text( + '---\nname: reviewer\ndescription: Review code.\n---\nYou review.', + encoding='utf-8', + ) + manifest = PluginManifest.parse(root) + result = PluginLoader.load( + manifest, + PluginLoadContext( + project_path=str(tmp_path), + session_id='s1', + enabled_executors=frozenset({'command'}), + plugin_data_root=tmp_path / 'data', + ), + ) + assert any(agent.name == 'reviewer' for agent in result.agent_defs) + + +def test_build_hook_env_includes_session_id(): + env = build_hook_env(HookExecutionContext( + session_id='session-123', + project_path='/tmp/project', + plugin_root='/tmp/plugin', + plugin_data_dir='/tmp/data', + )) + assert env['MS_AGENT_SESSION_ID'] == 'session-123' + + +def test_plugin_registry_managed_paths(tmp_path): + global_dir = tmp_path / '.ms_agent' + from ms_agent.plugins.config_manager import PluginConfigManager + from ms_agent.plugins.installer import PluginInstaller + source = tmp_path / 'source-plugin' + _basic_plugin(source) + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + installer.install(str(source), scope='global') + registry = PluginRegistry(manager) + assert 'p1-demo' in registry.managed_plugin_ids() + + +def test_resolve_task_entry_maps_builtin_single_agent(): + registry = PluginAgentRegistry() + registry.rebuild([ + AgentDef( + plugin_id='hookify', + name='conversation-analyzer', + path='/tmp/agents/conversation-analyzer.md', + description='Analyze conversation', + ), + ]) + entry = AgentDelegate.resolve_task_entry( + registry, + {'subagent_type': 'general-purpose', 'prompt': 'hi'}, + ) + assert entry is not None + assert entry.defn.name == 'conversation-analyzer' + + +def test_user_config_save_and_validate(tmp_path): + schema = { + 'mode': {'type': 'string', 'title': 'Mode'}, + 'strict': {'type': 'boolean', 'title': 'Strict'}, + } + errors = validate_values(schema, {'mode': 'safe', 'strict': True}) + assert errors == [] + saved = save_user_config(tmp_path / 'data', schema, {'mode': 'safe', 'strict': True}) + assert saved['mode'] == 'safe' + + +def test_example_plugin_mcp_fixture(tmp_path): + root = tmp_path / 'example-plugin' + _basic_plugin(root) + (root / '.mcp.json').write_text( + json.dumps({ + 'mcpServers': { + 'example': { + 'type': 'http', + 'url': 'http://127.0.0.1:9999/mcp', + }, + }, + }), + encoding='utf-8', + ) + (root / 'commands').mkdir() + (root / 'commands' / 'hello.md').write_text( + '---\nname: hello\ndescription: Say hello.\n---\nHello.', + encoding='utf-8', + ) + manifest = PluginManifest.parse(root) + result = PluginLoader.load( + manifest, + PluginLoadContext( + project_path=str(tmp_path), + session_id='s1', + enabled_executors=frozenset({'command'}), + plugin_data_root=tmp_path / 'data', + ), + ) + assert 'example' in result.mcp_servers + assert result.mcp_servers['example']['source'] == 'plugin' + + +def test_runtime_user_config_roundtrip(tmp_path): + root = tmp_path / 'cfg-plugin' + _basic_plugin( + root, + user_config={'mode': {'type': 'string', 'title': 'Mode'}}, + ) + global_dir = tmp_path / '.ms_agent' + from ms_agent.plugins.config_manager import PluginConfigManager + from ms_agent.plugins.installer import PluginInstaller + manager = PluginConfigManager(global_dir=global_dir) + installer = PluginInstaller(config_manager=manager, global_root=global_dir) + installer.install(str(root), scope='global') + runtime = PluginRuntime(config_manager=manager, global_root=global_dir) + runtime.start_sync(str(tmp_path), 'test') + saved = runtime.set_user_config('p1-demo', {'mode': 'strict'}) + assert saved['values']['mode'] == 'strict' + loaded = runtime.get_user_config('p1-demo') + assert loaded['values']['mode'] == 'strict' diff --git a/tests/test_hooks.py b/tests/test_hooks.py new file mode 100644 index 000000000..e6d879320 --- /dev/null +++ b/tests/test_hooks.py @@ -0,0 +1,264 @@ +"""Unit tests for the hooks system.""" + +from __future__ import annotations + +import asyncio +import os +import stat +from pathlib import Path + +import pytest + +from ms_agent.hooks.events import HookResult +from ms_agent.hooks.executor import HookExecutor +from ms_agent.hooks.executors.command import HookExecutionContext +from ms_agent.hooks.permission_resolve import resolve_hook_permission_decision +from ms_agent.hooks.registry import HookRegistry +from ms_agent.hooks.response_adapter import ResponseAdapter +from ms_agent.permission.config import PermissionConfig +from ms_agent.permission.enforcer import PermissionEnforcer +from ms_agent.utils.pattern_matcher import match_pattern + +FIXTURES = Path(__file__).parent / 'fixtures' / 'hooks' + + +class TestPatternMatcher: + def test_wildcard(self): + assert match_pattern('file_system---*', 'file_system---read_file') + assert match_pattern('read_file|write_file', 'read_file') + assert not match_pattern('a', 'b') + + def test_empty_pattern(self): + assert not match_pattern('', 'anything') + + +class TestHookRegistry: + def test_from_dict(self): + reg = HookRegistry.from_dict({ + 'PreToolUse': [{ + 'matcher': 'code_executor---*', + 'hooks': [{'type': 'command', 'command': './hook.sh'}], + }], + }) + handlers = reg.get_handlers('PreToolUse', 'code_executor---shell_executor') + assert len(handlers) == 1 + assert handlers[0].command == './hook.sh' + + def test_unknown_event_warning(self): + reg = HookRegistry.from_dict({'UnknownEvent': []}) + assert reg.is_empty + + def test_merge(self): + a = HookRegistry.from_dict({ + 'Stop': [{'hooks': [{'type': 'command', 'command': 'a.sh'}]}], + }) + b = HookRegistry.from_dict({ + 'Stop': [{'hooks': [{'type': 'command', 'command': 'b.sh'}]}], + }) + merged = a.merge(b) + handlers = merged.get_handlers('Stop') + assert [h.command for h in handlers] == ['a.sh', 'b.sh'] + + def test_non_tool_event_no_matcher(self): + reg = HookRegistry.from_dict({ + 'SessionStart': [{'matcher': 'ignored', 'hooks': [ + {'type': 'command', 'command': 'init.sh'}, + ]}], + }) + assert len(reg.get_handlers('SessionStart')) == 1 + + def test_tool_event_requires_tool_name_for_matcher(self): + reg = HookRegistry.from_dict({ + 'PreToolUse': [{ + 'matcher': 'code_executor---*', + 'hooks': [{'type': 'command', 'command': './hook.sh'}], + }], + }) + assert reg.get_handlers('PreToolUse', None) == [] + assert len(reg.get_handlers('PreToolUse', 'code_executor---shell_executor')) == 1 + + def test_skips_disabled_executor_types(self): + reg = HookRegistry.from_dict({ + 'PreToolUse': [{ + 'hooks': [{'type': 'http', 'url': 'https://example.com/hook'}], + }], + }, enabled_executors=frozenset({'command'})) + assert reg.is_empty + + +class TestResponseAdapter: + def test_canonical_deny(self): + r = ResponseAdapter().parse('{"decision": "deny", "reason": "no"}') + assert r.action == 'deny' + + def test_claude_permission_decision(self): + r = ResponseAdapter().parse( + '{"hookSpecificOutput": {"permissionDecision": "allow"}}') + assert r.action == 'allow' + + def test_updated_args_only_passthrough(self): + r = ResponseAdapter().parse('{"updatedArgs": {"command": "ls"}}') + assert r.action == 'pass' + assert r.updated_args == {'command': 'ls'} + + def test_cursor_permission_deny(self): + r = ResponseAdapter().parse( + '{"permission": "deny", "user_message": "nope"}') + assert r.action == 'deny' + assert r.reason == 'nope' + + +class TestHookExecutor: + @pytest.fixture + def executor(self, tmp_path): + return HookExecutor(working_dir=str(tmp_path)) + + def test_env_includes_plugin_data_aliases(self, tmp_path): + from ms_agent.hooks.executors.command import build_hook_env + plugin_root = tmp_path / 'plugin' + plugin_data = tmp_path / 'data' + ctx = HookExecutionContext( + session_id='s1', + project_path=str(tmp_path), + plugin_root=str(plugin_root), + plugin_data_dir=str(plugin_data), + ) + env = build_hook_env(ctx) + assert env['MS_AGENT_PLUGIN_ROOT'] == str(plugin_root) + assert env['CLAUDE_PLUGIN_ROOT'] == str(plugin_root) + assert env['MS_AGENT_PLUGIN_DATA'] == str(plugin_data) + assert env['CLAUDE_PLUGIN_DATA'] == str(plugin_data) + + @pytest.mark.asyncio + async def test_pass_script(self, executor, tmp_path): + script = FIXTURES / 'pass.py' + os.chmod(script, script.stat().st_mode | stat.S_IEXEC) + from ms_agent.hooks.registry import HookHandlerConfig + handler = HookHandlerConfig(type='command', command=f'python3 {script}') + ctx = HookExecutionContext(session_id='s1', project_path=str(tmp_path)) + result = await executor.execute( + handler, + {'event': 'PreToolUse', 'tool_name': 't', 'tool_args': {}}, + ctx, + ) + assert result.action == 'pass' + + @pytest.mark.asyncio + async def test_deny_script(self, executor, tmp_path): + script = FIXTURES / 'deny.py' + handler = __import__('ms_agent.hooks.registry', fromlist=['HookHandlerConfig']).HookHandlerConfig( + type='command', command=f'python3 {script}') + ctx = HookExecutionContext(session_id='s1', project_path=str(tmp_path)) + result = await executor.execute( + handler, + {'event': 'PreToolUse'}, + ctx, + ) + assert result.action == 'deny' + + @pytest.mark.asyncio + async def test_exit_2_block(self, executor, tmp_path): + script = FIXTURES / 'block.sh' + os.chmod(script, script.stat().st_mode | stat.S_IEXEC) + from ms_agent.hooks.registry import HookHandlerConfig + handler = HookHandlerConfig(type='command', command=f'bash {script}') + ctx = HookExecutionContext(session_id='s1', project_path=str(tmp_path)) + result = await executor.execute(handler, {'event': 'PreToolUse'}, ctx) + assert result.action == 'deny' + + @pytest.mark.asyncio + async def test_execute_all_deny_short_circuit(self, executor, tmp_path): + deny = FIXTURES / 'deny.py' + allow = FIXTURES / 'allow.py' + from ms_agent.hooks.registry import HookHandlerConfig + handlers = [ + HookHandlerConfig(type='command', command=f'python3 {deny}'), + HookHandlerConfig(type='command', command=f'python3 {allow}'), + ] + ctx = HookExecutionContext(session_id='s1', project_path=str(tmp_path)) + result = await executor.execute_all( + handlers, {'event': 'PreToolUse'}, blockable=True, ctx=ctx) + assert result.action == 'deny' + + +class TestResolveHookPermission: + @pytest.mark.asyncio + async def test_hook_deny(self): + out = await resolve_hook_permission_decision( + HookResult(action='deny', reason='no'), + 't', {}, + permission_enforcer=None, + permission_config=None, + ) + assert isinstance(out, str) + assert 'Blocked by hook' in out + + @pytest.mark.asyncio + async def test_hook_allow_with_blacklist(self): + config = PermissionConfig( + mode='interactive', + blacklist=('code_executor---shell_executor:curl *',), + ) + enforcer = PermissionEnforcer(config=config) + out = await resolve_hook_permission_decision( + HookResult(action='allow'), + 'code_executor---shell_executor', + {'command': 'curl http://evil.com'}, + permission_enforcer=enforcer, + permission_config=config, + ) + assert out.action == 'deny' + + @pytest.mark.asyncio + async def test_pass_goes_to_enforcer(self): + config = PermissionConfig(mode='interactive') + enforcer = PermissionEnforcer(config=config) + out = await resolve_hook_permission_decision( + HookResult(action='pass'), + 'file_system---read_file', + {'path': '/tmp/x'}, + permission_enforcer=enforcer, + permission_config=config, + ) + assert out.action == 'allow' + + @pytest.mark.asyncio + async def test_hook_allow_with_ask_rule(self): + from ms_agent.hooks.permission_resolve import check_rule_based_permissions + + config = PermissionConfig( + mode='interactive', + blacklist=(), + ask_rules=('file_system---read_file:/secret/*',), + ) + rule = await check_rule_based_permissions( + 'file_system---read_file', + {'path': '/secret/data.txt'}, + config, + ) + assert rule is not None + + +class TestPluginHookPayloadCompat: + def test_plugin_compat_payload_uses_claude_tool_name(self): + from ms_agent.hooks.executors.command import ( + HookExecutionContext, + plugin_compat_payload, + ) + + ctx = HookExecutionContext( + session_id='s1', + project_path='/tmp/project', + plugin_root='/tmp/plugins/hookify', + ) + payload = plugin_compat_payload( + { + 'event': 'PreToolUse', + 'tool_name': 'code_executor---shell_executor', + 'tool_name_claude': 'Bash', + 'tool_args': {'command': 'rm -rf /'}, + }, + ctx, + ) + assert payload['tool_name'] == 'Bash' + assert payload['hook_event_name'] == 'PreToolUse' diff --git a/tests/test_hooks_context.py b/tests/test_hooks_context.py new file mode 100644 index 000000000..d757e498b --- /dev/null +++ b/tests/test_hooks_context.py @@ -0,0 +1,53 @@ +"""Tests for hook context helpers.""" + +from ms_agent.hooks.context import ( + apply_hook_result_to_messages, + condense_hook_attachments_for_llm, + extract_latest_user_prompt, + HookAttachment, +) +from ms_agent.hooks.events import HookResult +from ms_agent.llm.utils import Message + + +class TestContext: + def test_extract_latest_user_prompt(self): + msgs = [ + Message(role='system', content='sys'), + Message(role='user', content='hello'), + ] + assert extract_latest_user_prompt(msgs) == 'hello' + + def test_condense_attachments(self): + att = HookAttachment( + type='hook_additional_context', + hook_event='PostToolUse', + tool_call_id='id1', + content='extra info', + ) + tool_msg = Message(role='tool', content='result', hook_attachments=[att]) + msgs = [tool_msg] + out = condense_hook_attachments_for_llm(msgs) + assert len(out) == 2 + assert '[hook:PostToolUse]' in out[1].content + assert out[0].content == 'result' + + def test_condense_stop_blocking_feedback(self): + from ms_agent.hooks.context import append_stop_blocking_feedback + + assistant = Message(role='assistant', content='done') + msgs = [assistant] + append_stop_blocking_feedback(msgs, 'not finished yet') + out = condense_hook_attachments_for_llm(msgs) + assert len(out) == 2 + assert 'Stop hook feedback:' in out[1].content + assert 'not finished yet' in out[1].content + + def test_apply_deny_user_prompt(self): + msgs = [Message(role='user', content='bad')] + ok = apply_hook_result_to_messages( + msgs, + HookResult(action='deny', reason='no'), + hook_event='UserPromptSubmit', + ) + assert ok is False diff --git a/tests/test_hooks_loaders.py b/tests/test_hooks_loaders.py new file mode 100644 index 000000000..c728f08f7 --- /dev/null +++ b/tests/test_hooks_loaders.py @@ -0,0 +1,91 @@ +"""Tests for multi-platform hook loaders.""" + +from __future__ import annotations + +import json +from pathlib import Path + +from ms_agent.hooks.loaders.claude import ClaudeSettingsLoader +from ms_agent.hooks.loaders.cursor import CursorHooksLoader +from ms_agent.hooks.loaders.hermes import HermesShellLoader +from ms_agent.hooks.loaders.native import NativeJsonLoader, NativeYamlLoader + + +class TestClaudeLoader: + def test_pre_tool_use(self, tmp_path): + settings = { + 'hooks': { + 'PreToolUse': [{ + 'matcher': 'Bash', + 'hooks': [{ + 'type': 'command', + 'command': './hooks/check.sh', + }], + }], + }, + } + path = tmp_path / 'settings.json' + path.write_text(json.dumps(settings)) + reg = ClaudeSettingsLoader.load_file(path, str(tmp_path)) + handlers = reg.get_handlers( + 'PreToolUse', 'code_executor---shell_executor') + assert len(handlers) == 1 + assert handlers[0].command == './hooks/check.sh' + + +class TestCursorLoader: + def test_pre_tool_use(self, tmp_path): + data = { + 'hooks': { + 'preToolUse': [{ + 'command': './cursor-hook.sh', + 'matcher': 'Shell', + }], + }, + } + path = tmp_path / 'hooks.json' + path.write_text(json.dumps(data)) + reg = CursorHooksLoader.load_file(path, str(tmp_path)) + handlers = reg.get_handlers( + 'PreToolUse', 'code_executor---shell_executor') + assert len(handlers) == 1 + + +class TestHermesLoader: + def test_pre_tool_call(self, tmp_path): + import yaml + data = { + 'hooks': { + 'pre_tool_call': [{ + 'command': './hermes-hook.sh', + 'matcher': 'terminal', + }], + }, + } + path = tmp_path / 'config.yaml' + path.write_text(yaml.dump(data)) + reg = HermesShellLoader.load_file(path, str(tmp_path)) + handlers = reg.get_handlers( + 'PreToolUse', 'code_executor---shell_executor') + assert len(handlers) == 1 + + +class TestNativeLoader: + def test_yaml(self, tmp_path): + import yaml + data = {'hooks': {'Stop': [{'hooks': [ + {'type': 'command', 'command': 'cleanup.sh'}, + ]}]}} + path = tmp_path / 'hooks.yaml' + path.write_text(yaml.dump(data)) + reg = NativeYamlLoader.load_file(path) + assert len(reg.get_handlers('Stop')) == 1 + + def test_json(self, tmp_path): + data = {'hooks': {'SessionStart': [{'hooks': [ + {'type': 'command', 'command': 'init.sh'}, + ]}]}} + path = tmp_path / 'hooks.json' + path.write_text(json.dumps(data)) + reg = NativeJsonLoader.load_file(path) + assert len(reg.get_handlers('SessionStart')) == 1 diff --git a/tests/utils/test_workspace_context.py b/tests/utils/test_workspace_context.py new file mode 100644 index 000000000..af6139e30 --- /dev/null +++ b/tests/utils/test_workspace_context.py @@ -0,0 +1,75 @@ +"""Tests for workspace root resolution.""" + +import os +from types import SimpleNamespace + +import pytest + +from ms_agent.utils.workspace_context import WorkspaceContext, resolve_workspace_root + + +class TestResolveWorkspaceRoot: + def test_defaults_to_cwd_when_missing(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + config = SimpleNamespace() + assert resolve_workspace_root(config) == tmp_path.resolve() + + def test_defaults_to_cwd_when_empty(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + config = SimpleNamespace(output_dir='') + assert resolve_workspace_root(config) == tmp_path.resolve() + + def test_expands_explicit_output_dir(self, tmp_path): + custom = tmp_path / 'artifacts' + config = SimpleNamespace(output_dir=str(custom)) + assert resolve_workspace_root(config) == custom.resolve() + + def test_workspace_context_uses_same_root(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + config = SimpleNamespace(tools=SimpleNamespace(workspace_policy=None)) + ctx = WorkspaceContext.from_config(config) + assert ctx.root == tmp_path.resolve() + + def test_permission_allowed_dirs_aligns_with_workspace_root(self, tmp_path, monkeypatch): + """allowed_dirs[0] and SafetyGuard workspace_root must match resolve_workspace_root.""" + monkeypatch.chdir(tmp_path) + from ms_agent.permission.config import PermissionConfig + from ms_agent.permission.safety import SafetyGuard + from ms_agent.utils.workspace_context import resolve_workspace_root + + config = SimpleNamespace(permission=None, tools=SimpleNamespace(workspace_policy=None)) + workspace_root = str(resolve_workspace_root(config)) + perm_config = PermissionConfig.from_dict({}, project_root=workspace_root) + allowed_dirs = [workspace_root] + guard = SafetyGuard( + config=perm_config.safety, + allowed_dirs=allowed_dirs, + workspace_root=workspace_root, + ) + assert allowed_dirs[0] == workspace_root + assert guard._workspace_root == workspace_root + assert guard._allowed_dirs[0] == workspace_root + + +class TestShellValidatorWorkspaceRoot: + def test_relative_path_resolves_against_workspace_root(self, tmp_path, monkeypatch): + workspace = tmp_path / 'workspace' + workspace.mkdir() + other = tmp_path / 'other' + other.mkdir() + (workspace / 'file.txt').write_text('hello', encoding='utf-8') + + monkeypatch.chdir(other) + + from ms_agent.permission.shell_validator import PathSafetyConfig, ShellPathValidator + + validator = ShellPathValidator( + allowed_dirs=[str(workspace)], + safety_config=PathSafetyConfig(workspace_root=str(workspace)), + ) + result = validator.check('cat file.txt') + assert result.action == 'allow' + + validator_without_root = ShellPathValidator(allowed_dirs=[str(workspace)]) + result_other = validator_without_root.check('cat file.txt') + assert result_other.action in ('deny', 'ask') diff --git a/tests/utils/test_workspace_policy.py b/tests/utils/test_workspace_policy.py deleted file mode 100644 index 012888533..000000000 --- a/tests/utils/test_workspace_policy.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Tests for WorkspacePolicyKernel.""" - -import tempfile -from pathlib import Path - -import pytest - -from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel - - -def test_default_root_is_output_dir(): - with tempfile.TemporaryDirectory() as td: - out = Path(td) / 'out' - out.mkdir() - k = WorkspacePolicyKernel(out) - p = k.resolve_under_roots('foo/bar') - assert p == (out / 'foo' / 'bar').resolve() - - -def test_rejects_escape(): - with tempfile.TemporaryDirectory() as td: - out = Path(td) / 'out' - out.mkdir() - k = WorkspacePolicyKernel(out) - with pytest.raises(WorkspacePolicyError): - k.resolve_under_roots('../../etc/passwd') - - -def test_extra_allow_root(): - with tempfile.TemporaryDirectory() as td: - out = Path(td) / 'out' - other = Path(td) / 'other' - out.mkdir() - other.mkdir() - k = WorkspacePolicyKernel(out, extra_allow_roots=[str(other)]) - assert k.resolve_under_roots(str(other / 'x')) == (other / 'x').resolve() - - -def test_read_only_blocks_redirect(): - with tempfile.TemporaryDirectory() as td: - out = Path(td) / 'out' - out.mkdir() - k = WorkspacePolicyKernel( - out, - shell_default_mode='read_only', - ) - with pytest.raises(WorkspacePolicyError): - k.assert_shell_command_allowed('echo x > file.txt') - - -def test_workspace_write_allows_redirect_but_blocks_network(): - with tempfile.TemporaryDirectory() as td: - out = Path(td) / 'out' - out.mkdir() - k = WorkspacePolicyKernel( - out, - shell_default_mode='workspace_write', - shell_network_enabled=False, - ) - k.assert_shell_command_allowed('echo x > file.txt') - with pytest.raises(WorkspacePolicyError): - k.assert_shell_command_allowed('curl https://example.com') - - -def test_artifact_manager_spill(tmp_path): - from ms_agent.utils.artifact_manager import ArtifactManager - - am = ArtifactManager(tmp_path, max_combined_bytes=32) - big = 'a' * 100 - packed = am.pack_text_result( - tool_name='t', - call_id='c1', - stdout=big, - stderr='', - ) - assert packed.get('truncated') is True - assert 'artifact_path' in packed