From 6051806aea155e8ea8c8ad90ebac858796210dee Mon Sep 17 00:00:00 2001 From: Suntion <149924916+SunYanbox@users.noreply.github.com> Date: Mon, 18 May 2026 15:46:57 +0800 Subject: [PATCH 1/3] =?UTF-8?q?chore:=20bump=20line=20length=20limit=20to?= =?UTF-8?q?=20240=20for=20ruff=20and=20markdownlint=20(#165)=20chore:=20?= =?UTF-8?q?=E5=B0=86=20ruff=20=E5=92=8C=20markdownlint=20=E8=A1=8C?= =?UTF-8?q?=E9=95=BF=E5=BA=A6=E9=99=90=E5=88=B6=E4=BB=8E=20120=20=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E4=B8=BA=20240?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .markdownlint.yml | 54 +++++++++++------------ ruff.toml | 106 +++++++++++++++++++++++----------------------- 2 files changed, 80 insertions(+), 80 deletions(-) diff --git a/.markdownlint.yml b/.markdownlint.yml index 6e21c26..de09eae 100644 --- a/.markdownlint.yml +++ b/.markdownlint.yml @@ -1,27 +1,27 @@ -# See: https://github.com/DavidAnson/markdownlint - -# Unordered list style -MD004: - style: dash - -# Disable line length for tables -MD013: - line_length: 120 - tables: false - -# Ordered list item prefix -MD029: - style: ordered - -# Spaces after list markers -MD030: - ul_single: 1 - ol_single: 1 - ul_multi: 1 - ol_multi: 1 - -MD033: false - -# Code block style -MD046: - style: fenced +# See: https://github.com/DavidAnson/markdownlint + +# Unordered list style +MD004: + style: dash + +# Disable line length for tables +MD013: + line_length: 240 + tables: false + +# Ordered list item prefix +MD029: + style: ordered + +# Spaces after list markers +MD030: + ul_single: 1 + ol_single: 1 + ul_multi: 1 + ol_multi: 1 + +MD033: false + +# Code block style +MD046: + style: fenced diff --git a/ruff.toml b/ruff.toml index cfaf29d..6dd7afc 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,53 +1,53 @@ -target-version = "py314" -line-length = 120 -src = ["src", "tests", "scripts", "tools"] - -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".git-rewrite", - ".hg", - ".ipynb_checkpoints", - ".mypy_cache", - ".nox", - ".pants.d", - ".pyenv", - ".pytest_cache", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - ".vscode", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "site-packages", - "venv" -] - -[lint] -select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - "I", # isort - "UP", # pyupgrade - "B", # flake8-bugbear - "SIM", # flake8-simplify - "RUF", # ruff-specific - "RUF001", # 歧义字符(例如全角字母、数字) - "RUF002", # 歧义标点符号(例如全角逗号、句号) - "RUF003", # 歧义空格(例如全角空格) -] - -ignore = ["UP006"] - -isort = { known-first-party = ["src"] } - -per-file-ignores = { "scripts/fix_fullwidth.py" = ["RUF001"] } +target-version = "py314" +line-length = 240 +src = ["src", "tests", "scripts", "tools"] + +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv" +] + +[lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "RUF", # ruff-specific + "RUF001", # 歧义字符(例如全角字母、数字) + "RUF002", # 歧义标点符号(例如全角逗号、句号) + "RUF003", # 歧义空格(例如全角空格) +] + +ignore = ["UP006"] + +isort = { known-first-party = ["src"] } + +per-file-ignores = { "scripts/fix_fullwidth.py" = ["RUF001"] } From bf3127e8963d8798ece56feb24a87fad83cbbfa1 Mon Sep 17 00:00:00 2001 From: Suntion <149924916+SunYanbox@users.noreply.github.com> Date: Mon, 18 May 2026 15:48:57 +0800 Subject: [PATCH 2/3] =?UTF-8?q?chore:=20ignore=20claude=20and=20trace=20di?= =?UTF-8?q?rectories=20=20(#166)=20chore:=20=E5=B0=86=20Claude=20=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E8=B7=AF=E5=BE=84=E5=8A=A0=E5=85=A5=20.gitignore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 629 +++++++++++++++++++++++++++-------------------------- 1 file changed, 316 insertions(+), 313 deletions(-) diff --git a/.gitignore b/.gitignore index 0c0f7ff..105e060 100644 --- a/.gitignore +++ b/.gitignore @@ -1,313 +1,316 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[codz] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py.cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# UV -# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -#uv.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock -#poetry.toml - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. -# https://pdm-project.org/en/latest/usage/project/#working-with-version-control -#pdm.lock -#pdm.toml -.pdm-python -.pdm-build/ - -# pixi -# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. -#pixi.lock -# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one -# in the .venv directory. It is recommended not to include this directory in version control. -.pixi - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.envrc -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -.idea/ - -# Abstra -# Abstra is an AI-powered process automation framework. -# Ignore directories containing user credentials, local state, and settings. -# Learn more at https://abstra.io/docs -.abstra/ - -# Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore -# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, -# you could uncomment the following to ignore the entire vscode folder -.vscode/ - -# Ruff stuff: -.ruff_cache/ - -# PyPI configuration file -.pypirc - -# Cursor -# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to -# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data -# refer to https://docs.cursor.com/context/ignore-files -.cursorignore -.cursorindexingignore - -# Marimo -marimo/_static/ -marimo/_lsp/ -__marimo__/ - -# Self Workspace -.ManualAid/ - -# Format Tools -node_modules/ - -# Rest pulled from https://github.com/github/gitignore/blob/master/Node.gitignore -# Logs -logs -*.log -npm-debug.log* -yarn-debug.log* -yarn-error.log* -lerna-debug.log* - -# Diagnostic reports (https://nodejs.org/api/report.html) -report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json - -# Runtime data -pids -*.pid -*.seed -*.pid.lock - -# Directory for instrumented libs generated by jscoverage/JSCover -lib-cov - -# Coverage directory used by tools like istanbul -coverage -*.lcov - -# nyc test coverage -.nyc_output - -# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) -.grunt - -# Bower dependency directory (https://bower.io/) -bower_components - -# node-waf configuration -.lock-wscript - -# Compiled binary addons (https://nodejs.org/api/addons.html) -build/Release - -# Dependency directories -jspm_packages/ - -# TypeScript v1 declaration files -typings/ - -# TypeScript cache -*.tsbuildinfo - -# Optional npm cache directory -.npm - -# Optional eslint cache -.eslintcache - -# Optional REPL history -.node_repl_history - -# Output of 'npm pack' -*.tgz - -# Yarn Integrity file -.yarn-integrity - -# dotenv environment variables file -.env -.env.test - -# parcel-bundler cache (https://parceljs.org/) -.cache - -# next.js build output -.next - -# nuxt.js build output -.nuxt - -# vuepress build output -.vuepress/dist - -# Serverless directories -.serverless/ - -# FuseBox cache -.fusebox/ - -# DynamoDB Local files -.dynamodb/ - -# OS metadata -.DS_Store -Thumbs.db - -# Ignore built ts files -__tests__/runner/* - -# IDE files -.idea -*.code-workspace +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +.vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Self Workspace +.ManualAid/ +.claude/ +# claude-tap +.traces/ + +# Format Tools +node_modules/ + +# Rest pulled from https://github.com/github/gitignore/blob/master/Node.gitignore +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +jspm_packages/ + +# TypeScript v1 declaration files +typings/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env +.env.test + +# parcel-bundler cache (https://parceljs.org/) +.cache + +# next.js build output +.next + +# nuxt.js build output +.nuxt + +# vuepress build output +.vuepress/dist + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# OS metadata +.DS_Store +Thumbs.db + +# Ignore built ts files +__tests__/runner/* + +# IDE files +.idea +*.code-workspace From fcad9f7cc678d76a425b464e018c7e1f73dbb1dc Mon Sep 17 00:00:00 2001 From: Suntion <149924916+SunYanbox@users.noreply.github.com> Date: Mon, 18 May 2026 15:55:00 +0800 Subject: [PATCH 3/3] =?UTF-8?q?style:=20apply=20new=20240-character=20line?= =?UTF-8?q?=20length=20formatting=20across=20codebase=20style:=20=E5=BA=94?= =?UTF-8?q?=E7=94=A8=E6=96=B0=E7=9A=84=20240=20=E5=AD=97=E7=AC=A6=E8=A1=8C?= =?UTF-8?q?=E9=95=BF=E5=BA=A6=E6=A0=BC=E5=BC=8F=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/console/commands/results/copy_cmd.py | 6 +- src/console/commands/workspaces/agent_cmd.py | 349 +++-- src/console/handlers/tool_handler.py | 241 ++- src/console/result_manager.py | 4 +- src/console/ui/formatters.py | 5 +- src/console/ui/widgets/shell_result_tab.py | 345 +++-- src/console/ui/widgets/skill_config_tab.py | 538 ++++--- src/console/ui/widgets/stats_tab.py | 899 ++++++----- src/core/database_manager.py | 1403 +++++++++--------- src/core/launcher.py | 5 +- src/core/paste_window.py | 452 +++--- src/models/skill.py | 251 ++-- src/models/tools/tool_result.py | 279 ++-- src/workspace/path_validator.py | 366 +++-- src/workspace/tools/base_tool.py | 568 ++++--- src/workspace/tools/edit_tool.py | 327 ++-- src/workspace/tools/exact_search_tool.py | 345 +++-- src/workspace/tools/git_tool.py | 397 +++-- src/workspace/tools/glob_tool.py | 68 +- src/workspace/tools/ls_tool.py | 62 +- src/workspace/tools/read_tool.py | 206 ++- src/workspace/tools/regex_search_tool.py | 406 +++-- src/workspace/tools/stat_tool.py | 280 ++-- src/workspace/tools/symbol_ref_tool.py | 835 ++++++----- src/workspace/tools/write_tool.py | 142 +- src/workspace/workspace.py | 586 ++++---- tests/core/test_audit_committer.py | 342 +++-- tests/core/test_database_manager.py | 1012 +++++++------ tests/core/test_tool_call_summaries.py | 296 ++-- tests/workspace/tools/test_edit_tool.py | 416 +++--- tests/workspace/tools/test_git_tool.py | 388 +++-- tests/workspace/tools/test_write_tool.py | 304 ++-- 32 files changed, 5944 insertions(+), 6179 deletions(-) diff --git a/src/console/commands/results/copy_cmd.py b/src/console/commands/results/copy_cmd.py index cf6b22b..39869e5 100644 --- a/src/console/commands/results/copy_cmd.py +++ b/src/console/commands/results/copy_cmd.py @@ -41,11 +41,7 @@ def execute(self, context: CommandContext) -> CommandResult: if tool_obj: copy_to_clipboard(tool_obj.to_func_call()) return CommandResult(success=False, message=f"已成功复制工具{tool_name}的标准调用格式") - warn = ( - f"使用{self.__class__.__name__}时, 传入的参数tool({tool})" - f"不存在于工具{context.tool_registry.list_tools()}中" - f"或对应的已注册工具意外被删除" - ) + warn = f"使用{self.__class__.__name__}时, 传入的参数tool({tool})不存在于工具{context.tool_registry.list_tools()}中或对应的已注册工具意外被删除" context.console.print(f"[yellow]{warn}[/yellow]") warnings.warn(warn, stacklevel=2) diff --git a/src/console/commands/workspaces/agent_cmd.py b/src/console/commands/workspaces/agent_cmd.py index 6e695b3..3909f5b 100644 --- a/src/console/commands/workspaces/agent_cmd.py +++ b/src/console/commands/workspaces/agent_cmd.py @@ -1,177 +1,172 @@ -"""Agent management command (/agent).""" - -from __future__ import annotations - -from argparse import ArgumentParser - -from src.core.agent_manager import AgentManager -from src.core.copy2clip import copy_to_clipboard -from src.models.agent import AgentConfig -from src.models.commands import Command, CommandContext, CommandResult - - -def _reset_default(mgr: AgentManager, context: CommandContext) -> CommandResult: - if mgr.reset_default(): - context.console.print("[green]default.md 已根据内置Default Agent重写完成[/green]") - else: - context.console.print("[red]重置失败: 工作区根路径未初始化[/red]") - return CommandResult(success=True) - - -def _show_current(mgr: AgentManager, context: CommandContext) -> CommandResult: - agent = mgr.get_current() - context.console.print( - f"[bold]Current Agent:[/bold] {agent.name}\n" - f"[dim]{agent.description}[/dim]\n" - f"Whitelist: {agent.tool_permissions.whitelist or '(all)'}\n" - f"Blacklist: {agent.tool_permissions.blacklist or '(none)'}" - ) - return CommandResult(success=True) - - -def _list_all(mgr: AgentManager, context: CommandContext) -> CommandResult: - agents = mgr.list_agents() - if not agents: - context.console.print("[yellow]No agents found in .ManualAid/agents/[/yellow]") - return CommandResult(success=True) - - lines = ["[bold]Available Agents:[/bold]"] - for a in agents: - marker = ">" if a.name == mgr.current_agent_name else " " - lines.append(f" {marker} {a.name} — {a.description}") - context.console.print("\n".join(lines)) - return CommandResult(success=True) - - -class AgentCommand(Command): - """Manage Agent configuration""" - - def __init__(self): - super().__init__() - self.name = "agent" - self.aliases = ["/agent"] - self.description = "管理 Agent 配置 (列表、切换、复制、重置)" - self.usage = ( - "/agent — 显示当前 Agent\n" - "/agent list — 列出所有 Agent\n" - "/agent — 按名称或唯一前缀切换 Agent\n" - "/agent default — 切换到默认 Agent\n" - "/agent copy — 复制当前 Agent 的角色+工作流到剪贴板\n" - "/agent copy — 复制指定 Agent 的角色+工作流到剪贴板\n" - "/agent reset — 根据 prompts.py 重写 default.md" - ) - self.argparse = ArgumentParser("agent") - self.argparse.add_argument( - "subcommand", - nargs="?", - default=None, - help="子命令: list, default, copy, reset, 或 Agent 名称", - ) - for usage in self.usage.split("\n"): - self.argparse.add_argument( - "Usage", - nargs="?", - default=None, - help=usage, - ) - - def execute(self, context: CommandContext) -> CommandResult: - # Show help on -h / --help - if "-h" in context.parsed_input.source or "--help" in context.parsed_input.source: - context.console.print(self.argparse.format_help()) - return CommandResult(success=True) - - mgr = AgentManager() - # Parse args from source: "/agent list" -> "list" - parts = context.parsed_input.source.split() - args = " ".join(parts[1:]) if len(parts) > 1 else "" - - if not args: - return _show_current(mgr, context) - if args == "list": - return _list_all(mgr, context) - if args.startswith("copy"): - rest = args[4:].strip() - return self._copy_agent(mgr, context, rest or None) - if args == "default": - return self._switch(mgr, "default", context) - if args == "reset": - return _reset_default(mgr, context) - - # Treat as agent name (supports unique prefix matching) - return self._switch(mgr, args, context) - - def _switch(self, mgr: AgentManager, name: str, context: CommandContext) -> CommandResult: - # Try exact match first - if mgr.switch_agent(name): - agent = mgr.get_current() - context.console.print(f"[green]Switched to agent:[/green] {agent.name}") - # Update TUI dropdown if available - self._sync_tui(context, mgr.current_agent_name) - return CommandResult(success=True) - - # Try unique prefix match - matches = [n for n in mgr.agent_names() if n.startswith(name)] - if len(matches) == 1: - mgr.switch_agent(matches[0]) - context.console.print(f"[green]Switched to agent:[/green] {matches[0]}") - self._sync_tui(context, mgr.current_agent_name) - return CommandResult(success=True) - - if len(matches) > 1: - context.console.print(f"[red]Ambiguous prefix '{name}' matches: {', '.join(matches)}[/red]") - else: - context.console.print(f"[red]Agent '{name}' not found.[/red]") - context.console.print("Use [bold]/agent list[/bold] to see available agents.") - return CommandResult(success=True) - - def _copy_agent(self, mgr: AgentManager, context: CommandContext, name: str | None) -> CommandResult: - if name: - agent = mgr.get(name) - if agent is None: - matches = [n for n in mgr.agent_names() if n.startswith(name)] - if len(matches) == 1: - agent = mgr.get(matches[0]) - elif len(matches) > 1: - context.console.print(f"[red]Ambiguous prefix '{name}' matches: {', '.join(matches)}[/red]") - return CommandResult(success=False) - else: - context.console.print(f"[red]Agent '{name}' not found.[/red]") - return CommandResult(success=False) - else: - agent = mgr.get_current() - - text = self._format_agent_copy(agent) - if copy_to_clipboard(text): - context.console.print(f"[green]Agent '{agent.name}' settings copied to clipboard.[/green]") - else: - context.console.print(text) - context.console.print("[yellow](Clipboard unavailable — printed above instead)[/yellow]") - return CommandResult(success=True) - - @staticmethod - def _format_agent_copy(agent: AgentConfig) -> str: - """Format agent body (role + workflow) for external pasting.""" - parts = [f"--- Agent: {agent.name} ---", ""] - if agent.body_role: - parts.append(agent.body_role) - parts.append("") - if agent.body_workflow: - parts.append(agent.body_workflow) - parts.append("") - return "\n".join(parts).strip() - - @staticmethod - def _sync_tui(context: CommandContext, agent_name: str) -> None: - """Update TUI dropdown and title bar after agent switch.""" - app = context.app - if app is None: - return - try: - from textual.widgets import Select - - select = app.query_one("#agent-select", Select) - if select: - select.value = agent_name - except Exception: - pass +"""Agent management command (/agent).""" + +from __future__ import annotations + +from argparse import ArgumentParser + +from src.core.agent_manager import AgentManager +from src.core.copy2clip import copy_to_clipboard +from src.models.agent import AgentConfig +from src.models.commands import Command, CommandContext, CommandResult + + +def _reset_default(mgr: AgentManager, context: CommandContext) -> CommandResult: + if mgr.reset_default(): + context.console.print("[green]default.md 已根据内置Default Agent重写完成[/green]") + else: + context.console.print("[red]重置失败: 工作区根路径未初始化[/red]") + return CommandResult(success=True) + + +def _show_current(mgr: AgentManager, context: CommandContext) -> CommandResult: + agent = mgr.get_current() + context.console.print(f"[bold]Current Agent:[/bold] {agent.name}\n[dim]{agent.description}[/dim]\nWhitelist: {agent.tool_permissions.whitelist or '(all)'}\nBlacklist: {agent.tool_permissions.blacklist or '(none)'}") + return CommandResult(success=True) + + +def _list_all(mgr: AgentManager, context: CommandContext) -> CommandResult: + agents = mgr.list_agents() + if not agents: + context.console.print("[yellow]No agents found in .ManualAid/agents/[/yellow]") + return CommandResult(success=True) + + lines = ["[bold]Available Agents:[/bold]"] + for a in agents: + marker = ">" if a.name == mgr.current_agent_name else " " + lines.append(f" {marker} {a.name} — {a.description}") + context.console.print("\n".join(lines)) + return CommandResult(success=True) + + +class AgentCommand(Command): + """Manage Agent configuration""" + + def __init__(self): + super().__init__() + self.name = "agent" + self.aliases = ["/agent"] + self.description = "管理 Agent 配置 (列表、切换、复制、重置)" + self.usage = ( + "/agent — 显示当前 Agent\n" + "/agent list — 列出所有 Agent\n" + "/agent — 按名称或唯一前缀切换 Agent\n" + "/agent default — 切换到默认 Agent\n" + "/agent copy — 复制当前 Agent 的角色+工作流到剪贴板\n" + "/agent copy — 复制指定 Agent 的角色+工作流到剪贴板\n" + "/agent reset — 根据 prompts.py 重写 default.md" + ) + self.argparse = ArgumentParser("agent") + self.argparse.add_argument( + "subcommand", + nargs="?", + default=None, + help="子命令: list, default, copy, reset, 或 Agent 名称", + ) + for usage in self.usage.split("\n"): + self.argparse.add_argument( + "Usage", + nargs="?", + default=None, + help=usage, + ) + + def execute(self, context: CommandContext) -> CommandResult: + # Show help on -h / --help + if "-h" in context.parsed_input.source or "--help" in context.parsed_input.source: + context.console.print(self.argparse.format_help()) + return CommandResult(success=True) + + mgr = AgentManager() + # Parse args from source: "/agent list" -> "list" + parts = context.parsed_input.source.split() + args = " ".join(parts[1:]) if len(parts) > 1 else "" + + if not args: + return _show_current(mgr, context) + if args == "list": + return _list_all(mgr, context) + if args.startswith("copy"): + rest = args[4:].strip() + return self._copy_agent(mgr, context, rest or None) + if args == "default": + return self._switch(mgr, "default", context) + if args == "reset": + return _reset_default(mgr, context) + + # Treat as agent name (supports unique prefix matching) + return self._switch(mgr, args, context) + + def _switch(self, mgr: AgentManager, name: str, context: CommandContext) -> CommandResult: + # Try exact match first + if mgr.switch_agent(name): + agent = mgr.get_current() + context.console.print(f"[green]Switched to agent:[/green] {agent.name}") + # Update TUI dropdown if available + self._sync_tui(context, mgr.current_agent_name) + return CommandResult(success=True) + + # Try unique prefix match + matches = [n for n in mgr.agent_names() if n.startswith(name)] + if len(matches) == 1: + mgr.switch_agent(matches[0]) + context.console.print(f"[green]Switched to agent:[/green] {matches[0]}") + self._sync_tui(context, mgr.current_agent_name) + return CommandResult(success=True) + + if len(matches) > 1: + context.console.print(f"[red]Ambiguous prefix '{name}' matches: {', '.join(matches)}[/red]") + else: + context.console.print(f"[red]Agent '{name}' not found.[/red]") + context.console.print("Use [bold]/agent list[/bold] to see available agents.") + return CommandResult(success=True) + + def _copy_agent(self, mgr: AgentManager, context: CommandContext, name: str | None) -> CommandResult: + if name: + agent = mgr.get(name) + if agent is None: + matches = [n for n in mgr.agent_names() if n.startswith(name)] + if len(matches) == 1: + agent = mgr.get(matches[0]) + elif len(matches) > 1: + context.console.print(f"[red]Ambiguous prefix '{name}' matches: {', '.join(matches)}[/red]") + return CommandResult(success=False) + else: + context.console.print(f"[red]Agent '{name}' not found.[/red]") + return CommandResult(success=False) + else: + agent = mgr.get_current() + + text = self._format_agent_copy(agent) + if copy_to_clipboard(text): + context.console.print(f"[green]Agent '{agent.name}' settings copied to clipboard.[/green]") + else: + context.console.print(text) + context.console.print("[yellow](Clipboard unavailable — printed above instead)[/yellow]") + return CommandResult(success=True) + + @staticmethod + def _format_agent_copy(agent: AgentConfig) -> str: + """Format agent body (role + workflow) for external pasting.""" + parts = [f"--- Agent: {agent.name} ---", ""] + if agent.body_role: + parts.append(agent.body_role) + parts.append("") + if agent.body_workflow: + parts.append(agent.body_workflow) + parts.append("") + return "\n".join(parts).strip() + + @staticmethod + def _sync_tui(context: CommandContext, agent_name: str) -> None: + """Update TUI dropdown and title bar after agent switch.""" + app = context.app + if app is None: + return + try: + from textual.widgets import Select + + select = app.query_one("#agent-select", Select) + if select: + select.value = agent_name + except Exception: + pass diff --git a/src/console/handlers/tool_handler.py b/src/console/handlers/tool_handler.py index 0faaf67..c6d86a3 100644 --- a/src/console/handlers/tool_handler.py +++ b/src/console/handlers/tool_handler.py @@ -1,123 +1,118 @@ -from __future__ import annotations - -import time -from typing import TYPE_CHECKING - -from src.console.ui.widgets.tools_result_widget import ToolsResultWidget -from src.constants.files import EXTENSION_TO_LANGUAGE -from src.models.commands import CommandParseResult -from src.models.tools.tool_result_collection import ToolResultCollection -from src.utils.string_snapshot import truncate_for_display, truncate_params_string, truncate_single_string - -if TYPE_CHECKING: - from src.console.result_manager import ResultManager - from src.core.tool_registry import ToolRegistry - - -def _detect_language(func_name: str, func_kwargs: dict) -> str | None: - """Detect language for syntax highlighting""" - if "read" not in func_name: - return None - - file_path = "" - if "file_path" in func_kwargs: - file_path = func_kwargs["file_path"] - - ext_map = EXTENSION_TO_LANGUAGE - - for ext, lang in ext_map.items(): - if file_path.lower().endswith(ext): - return lang - - return "text" - - -def _format_tool_params(kwargs: dict) -> str: - """Format tool parameters as concise string""" - parts = [] - - # Keyword arguments - for key, value in kwargs.items(): - if isinstance(value, str): - parts.append(f'{key}="{truncate_single_string(value)}"') - else: - parts.append(f"{key}={value}") - - if not parts: - return "no parameters" - - params_str = truncate_params_string(", ".join(parts)) - - return params_str - - -def _create_result_title(index: int, func_name: str, kwargs: dict, lines_count: int) -> str: - """Create result title with Rich markup""" - params_str = _format_tool_params(kwargs) - return ( - f"[bold cyan]##{index}[/bold cyan] [bold green]{func_name}[/bold green]([dim]{params_str}[/dim])" - + f" [yellow]({lines_count} lines)[/yellow]" - ) - - -class ToolHandler: - """Handler for processing tool calls""" - - def __init__( - self, - tool_registry: ToolRegistry, - result_manager: ResultManager, - console, - ): - self.tool_registry = tool_registry - self.result_manager = result_manager - self.console = console - - def handle(self, parsed_input: CommandParseResult) -> bool: - """Handle a parsed tool call input - - Args: - parsed_input: CommandParseResult from input_parser - - Returns: - True if handled successfully, False otherwise - """ - if parsed_input.is_command: - return False - - collection: ToolResultCollection = ToolResultCollection() - - for func_name, func_kwargs in parsed_input.funcs: - parms: str = f"{{{func_kwargs}" - - # 避免多参数工具的返回值过于占上下文 - if len(parms) > 120: - parms = parms[:117] + "..." - - parms += "}" - - start = time.perf_counter() - - # 执行 - response = self.tool_registry.execute(func_name, **func_kwargs) - - collection.add(func_name, time.perf_counter() - start, kwargs=func_kwargs, result=response.response) - - result = "" - - for results in collection.results.values(): - for _result in results: - result += _result[1] - - tool_names = truncate_for_display(",".join(collection.tools())) - - self.result_manager.add(tool_names, result) - - tool_result_widget = ToolsResultWidget() - tool_result_widget.set_collection(collection) - - self.console.print_collapsible_with_widget( - truncate_single_string(f"调用工具结果 | {tool_names}") + f" | {time.ctime(time.time())}", tool_result_widget - ) - - return True +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +from src.console.ui.widgets.tools_result_widget import ToolsResultWidget +from src.constants.files import EXTENSION_TO_LANGUAGE +from src.models.commands import CommandParseResult +from src.models.tools.tool_result_collection import ToolResultCollection +from src.utils.string_snapshot import truncate_for_display, truncate_params_string, truncate_single_string + +if TYPE_CHECKING: + from src.console.result_manager import ResultManager + from src.core.tool_registry import ToolRegistry + + +def _detect_language(func_name: str, func_kwargs: dict) -> str | None: + """Detect language for syntax highlighting""" + if "read" not in func_name: + return None + + file_path = "" + if "file_path" in func_kwargs: + file_path = func_kwargs["file_path"] + + ext_map = EXTENSION_TO_LANGUAGE + + for ext, lang in ext_map.items(): + if file_path.lower().endswith(ext): + return lang + + return "text" + + +def _format_tool_params(kwargs: dict) -> str: + """Format tool parameters as concise string""" + parts = [] + + # Keyword arguments + for key, value in kwargs.items(): + if isinstance(value, str): + parts.append(f'{key}="{truncate_single_string(value)}"') + else: + parts.append(f"{key}={value}") + + if not parts: + return "no parameters" + + params_str = truncate_params_string(", ".join(parts)) + + return params_str + + +def _create_result_title(index: int, func_name: str, kwargs: dict, lines_count: int) -> str: + """Create result title with Rich markup""" + params_str = _format_tool_params(kwargs) + return f"[bold cyan]##{index}[/bold cyan] [bold green]{func_name}[/bold green]([dim]{params_str}[/dim])" + f" [yellow]({lines_count} lines)[/yellow]" + + +class ToolHandler: + """Handler for processing tool calls""" + + def __init__( + self, + tool_registry: ToolRegistry, + result_manager: ResultManager, + console, + ): + self.tool_registry = tool_registry + self.result_manager = result_manager + self.console = console + + def handle(self, parsed_input: CommandParseResult) -> bool: + """Handle a parsed tool call input + + Args: + parsed_input: CommandParseResult from input_parser + + Returns: + True if handled successfully, False otherwise + """ + if parsed_input.is_command: + return False + + collection: ToolResultCollection = ToolResultCollection() + + for func_name, func_kwargs in parsed_input.funcs: + parms: str = f"{{{func_kwargs}" + + # 避免多参数工具的返回值过于占上下文 + if len(parms) > 120: + parms = parms[:117] + "..." + + parms += "}" + + start = time.perf_counter() + + # 执行 + response = self.tool_registry.execute(func_name, **func_kwargs) + + collection.add(func_name, time.perf_counter() - start, kwargs=func_kwargs, result=response.response) + + result = "" + + for results in collection.results.values(): + for _result in results: + result += _result[1] + + tool_names = truncate_for_display(",".join(collection.tools())) + + self.result_manager.add(tool_names, result) + + tool_result_widget = ToolsResultWidget() + tool_result_widget.set_collection(collection) + + self.console.print_collapsible_with_widget(truncate_single_string(f"调用工具结果 | {tool_names}") + f" | {time.ctime(time.time())}", tool_result_widget) + + return True diff --git a/src/console/result_manager.py b/src/console/result_manager.py index 0b5de12..bab3151 100644 --- a/src/console/result_manager.py +++ b/src/console/result_manager.py @@ -79,6 +79,4 @@ def _cleanup_expired(self) -> None: """Clean up expired entries""" now = time.time() expire_seconds = self.CLEANUP_MINUTES * 60 - self._history = [ - entry for entry in self._history if not (entry.copied and now - entry.timestamp > expire_seconds) - ] + self._history = [entry for entry in self._history if not (entry.copied and now - entry.timestamp > expire_seconds)] diff --git a/src/console/ui/formatters.py b/src/console/ui/formatters.py index 8d1ab89..457e6ee 100644 --- a/src/console/ui/formatters.py +++ b/src/console/ui/formatters.py @@ -62,10 +62,7 @@ def create_result_title( Rich markup formatted title """ params_str = OutputFormatter.format_tool_params(args, kwargs) - return ( - f"[bold cyan]##{index}[/bold cyan] [bold green]{func_name}[/bold green]([dim]{params_str}[/dim])" - + f" [yellow]({lines_count} lines)[/yellow]" - ) + return f"[bold cyan]##{index}[/bold cyan] [bold green]{func_name}[/bold green]([dim]{params_str}[/dim])" + f" [yellow]({lines_count} lines)[/yellow]" @staticmethod def detect_language(file_path: str) -> str: diff --git a/src/console/ui/widgets/shell_result_tab.py b/src/console/ui/widgets/shell_result_tab.py index 5986d37..772d06e 100644 --- a/src/console/ui/widgets/shell_result_tab.py +++ b/src/console/ui/widgets/shell_result_tab.py @@ -1,173 +1,172 @@ -"""Shell 命令结果标签页 — 查看/复制已执行的 Shell 命令输出.""" - -from __future__ import annotations - -import datetime -from typing import ClassVar - -from textual.containers import Horizontal, Vertical -from textual.widgets import Button, Collapsible, Label, Static - -from src.core.copy2clip import copy_to_clipboard - - -class ShellResultTab(Vertical): - """Shell 命令执行结果标签页. - - 展示所有已完成(已批准/已拒绝)的 Shell 命令及其输出, - 支持展开查看详细输出并复制. - """ - - DEFAULT_CSS: ClassVar[str] = """ - ShellResultTab { - height: 1fr; - width: 1fr; - padding: 0 1; - overflow-y: auto; - } - - #shell-result-placeholder { - height: 100%; - content-align: center middle; - color: $text-muted; - } - - #shell-result-empty { - height: 100%; - content-align: center middle; - color: $text-muted; - } - - #shell-result-header { - height: auto; - padding: 1 0; - text-style: bold; - color: $text; - } - - .shell-collapsible { - height: auto; - margin-bottom: 1; - } - - .shell-output-container { - max-height: 20; - overflow-y: auto; - padding: 1; - background: $surface; - border: solid $primary; - margin-bottom: 1; - } - - .shell-button-row { - height: auto; - align: left middle; - margin-bottom: 1; - } - """ - - def __init__(self) -> None: - super().__init__() - self._db = None - - def compose(self): - yield Label("正在加载...", id="shell-result-placeholder") - - def set_database(self, db) -> None: - """设置数据库引用并刷新.""" - self._db = db - self.set_timer(0.0, self._refresh) - - def on_mount(self) -> None: - if self._db is not None: - self.set_timer(0.1, self._refresh) - - async def _refresh(self) -> None: - """查询已完成的 Shell 命令并重建 UI.""" - if self._db is None: - return - - await self.remove_children() - - shells = self._db.get_shell_completed() - - if not shells: - await self.mount(Label("暂无已执行的 Shell 命令.", id="shell-result-empty")) - return - - header = Label(f"Shell 命令执行记录 ({len(shells)} 项)", id="shell-result-header") - await self.mount(header) - - for i, shell in enumerate(shells): - ( - shell_id, - command, - description, - _ts, - _sid, - audit_status, - output, - exit_code, - executed_at, - ) = shell - - is_approved = audit_status == "APPROVED" - status_icon = "✓" if is_approved else "✗" - status_color = "green" if is_approved else "red" - - # Build detailed content - lines: list[str] = [ - f"[bold]Command:[/bold] $ {command}", - f"[bold]Status:[/bold] [{status_color}]{status_icon} {audit_status}[/{status_color}]", - ] - if description: - lines.append(f"[bold]Description:[/bold] {description}") - if exit_code is not None: - lines.append(f"[bold]Exit Code:[/bold] {exit_code}") - if executed_at: - dt_str = datetime.datetime.fromtimestamp(executed_at).strftime("%Y-%m-%d %H:%M:%S") - lines.append(f"[bold]Executed At:[/bold] {dt_str}") - if output: - lines.append(f"\n[bold]Output:[/bold]\n{output}") - - content = "\n".join(lines) - - output_text = Static(content, markup=True) - output_container = Vertical(output_text, classes="shell-output-container") - copy_btn = Button("复制输出", id=f"shell_copy-{shell_id}") - btn_row = Horizontal(copy_btn, classes="shell-button-row") - - # First items expanded by default, rest collapsed - collapsed = i > 3 - collapsible = Collapsible( - Vertical(output_container, btn_row), - title=f"[{status_color}]{status_icon}[/{status_color}] " - f"Shell #{shell_id}: {command.strip()[:60]}{'...' if len(command.strip()) > 60 else ''}", - classes="shell-collapsible", - collapsed=collapsed, - ) - await self.mount(collapsible) - - async def on_button_pressed(self, event: Button.Pressed) -> None: - """处理复制按钮点击.""" - button_id = event.button.id or "" - if not button_id.startswith("shell_copy-"): - return - - try: - shell_id = int(button_id.split("-", 1)[1]) - except ValueError, IndexError: - return - - if self._db is None: - return - - shells = self._db.get_shell_completed() - for shell in shells: - if shell[0] == shell_id: - output = shell[6] or "(空输出)" - copy_to_clipboard(output) - self.notify("输出已复制到剪贴板", timeout=3) - return - - self.notify("未找到对应记录", severity="error", timeout=3) +"""Shell 命令结果标签页 — 查看/复制已执行的 Shell 命令输出.""" + +from __future__ import annotations + +import datetime +from typing import ClassVar + +from textual.containers import Horizontal, Vertical +from textual.widgets import Button, Collapsible, Label, Static + +from src.core.copy2clip import copy_to_clipboard + + +class ShellResultTab(Vertical): + """Shell 命令执行结果标签页. + + 展示所有已完成(已批准/已拒绝)的 Shell 命令及其输出, + 支持展开查看详细输出并复制. + """ + + DEFAULT_CSS: ClassVar[str] = """ + ShellResultTab { + height: 1fr; + width: 1fr; + padding: 0 1; + overflow-y: auto; + } + + #shell-result-placeholder { + height: 100%; + content-align: center middle; + color: $text-muted; + } + + #shell-result-empty { + height: 100%; + content-align: center middle; + color: $text-muted; + } + + #shell-result-header { + height: auto; + padding: 1 0; + text-style: bold; + color: $text; + } + + .shell-collapsible { + height: auto; + margin-bottom: 1; + } + + .shell-output-container { + max-height: 20; + overflow-y: auto; + padding: 1; + background: $surface; + border: solid $primary; + margin-bottom: 1; + } + + .shell-button-row { + height: auto; + align: left middle; + margin-bottom: 1; + } + """ + + def __init__(self) -> None: + super().__init__() + self._db = None + + def compose(self): + yield Label("正在加载...", id="shell-result-placeholder") + + def set_database(self, db) -> None: + """设置数据库引用并刷新.""" + self._db = db + self.set_timer(0.0, self._refresh) + + def on_mount(self) -> None: + if self._db is not None: + self.set_timer(0.1, self._refresh) + + async def _refresh(self) -> None: + """查询已完成的 Shell 命令并重建 UI.""" + if self._db is None: + return + + await self.remove_children() + + shells = self._db.get_shell_completed() + + if not shells: + await self.mount(Label("暂无已执行的 Shell 命令.", id="shell-result-empty")) + return + + header = Label(f"Shell 命令执行记录 ({len(shells)} 项)", id="shell-result-header") + await self.mount(header) + + for i, shell in enumerate(shells): + ( + shell_id, + command, + description, + _ts, + _sid, + audit_status, + output, + exit_code, + executed_at, + ) = shell + + is_approved = audit_status == "APPROVED" + status_icon = "✓" if is_approved else "✗" + status_color = "green" if is_approved else "red" + + # Build detailed content + lines: list[str] = [ + f"[bold]Command:[/bold] $ {command}", + f"[bold]Status:[/bold] [{status_color}]{status_icon} {audit_status}[/{status_color}]", + ] + if description: + lines.append(f"[bold]Description:[/bold] {description}") + if exit_code is not None: + lines.append(f"[bold]Exit Code:[/bold] {exit_code}") + if executed_at: + dt_str = datetime.datetime.fromtimestamp(executed_at).strftime("%Y-%m-%d %H:%M:%S") + lines.append(f"[bold]Executed At:[/bold] {dt_str}") + if output: + lines.append(f"\n[bold]Output:[/bold]\n{output}") + + content = "\n".join(lines) + + output_text = Static(content, markup=True) + output_container = Vertical(output_text, classes="shell-output-container") + copy_btn = Button("复制输出", id=f"shell_copy-{shell_id}") + btn_row = Horizontal(copy_btn, classes="shell-button-row") + + # First items expanded by default, rest collapsed + collapsed = i > 3 + collapsible = Collapsible( + Vertical(output_container, btn_row), + title=f"[{status_color}]{status_icon}[/{status_color}] Shell #{shell_id}: {command.strip()[:60]}{'...' if len(command.strip()) > 60 else ''}", + classes="shell-collapsible", + collapsed=collapsed, + ) + await self.mount(collapsible) + + async def on_button_pressed(self, event: Button.Pressed) -> None: + """处理复制按钮点击.""" + button_id = event.button.id or "" + if not button_id.startswith("shell_copy-"): + return + + try: + shell_id = int(button_id.split("-", 1)[1]) + except ValueError, IndexError: + return + + if self._db is None: + return + + shells = self._db.get_shell_completed() + for shell in shells: + if shell[0] == shell_id: + output = shell[6] or "(空输出)" + copy_to_clipboard(output) + self.notify("输出已复制到剪贴板", timeout=3) + return + + self.notify("未找到对应记录", severity="error", timeout=3) diff --git a/src/console/ui/widgets/skill_config_tab.py b/src/console/ui/widgets/skill_config_tab.py index 1e16296..cb3458d 100644 --- a/src/console/ui/widgets/skill_config_tab.py +++ b/src/console/ui/widgets/skill_config_tab.py @@ -1,272 +1,266 @@ -from __future__ import annotations - -from typing import ClassVar - -from textual.containers import Horizontal, Vertical -from textual.widgets import Button, DataTable, Label, Static - - -class SkillConfigTab(Vertical): - """Skill 配置标签页. - - 显示所有发现的 Skill,支持: - - 查看全局和项目级 Skill - - 启用/禁用 Skill - - 查看 Skill 详情 - """ - - DEFAULT_CSS: ClassVar[str] = """ - SkillConfigTab { - height: 1fr; - width: 1fr; - padding: 0 1; - overflow-y: auto; - } - - #skill-header { - height: auto; - padding: 1 0; - text-style: bold; - color: $text; - border-bottom: solid $primary; - } - - #skill-toolbar { - height: auto; - padding: 1 0; - align: left middle; - } - - #skill-toolbar Button { - margin-right: 1; - } - - #skill-table { - height: 1fr; - } - - #skill-detail { - height: auto; - max-height: 10; - padding: 1; - margin-top: 1; - background: $surface; - border: solid $primary; - overflow-y: auto; - } - - #skill-empty { - height: 100%; - content-align: center middle; - color: $text-muted; - } - - .skill-row-enabled { - color: $text; - } - - .skill-row-disabled { - color: $text-muted; - } - """ - - def __init__(self) -> None: - super().__init__() - self._skill_manager = None - self._workspace_root = None - self._skills_data: dict = {} - self._columns_initialized = False # 标记列是否已初始化 - - def compose(self): - yield Label("Skill 配置", id="skill-header") - with Horizontal(id="skill-toolbar"): - yield Button("刷新", id="skill-refresh-btn", variant="default") - yield Button("启用选中", id="skill-enable-btn", variant="success") - yield Button("禁用选中", id="skill-disable-btn", variant="warning") - yield Button("启用全部", id="skill-enable-all-btn", variant="success") - yield Button("禁用全部", id="skill-disable-all-btn", variant="warning") - yield DataTable(id="skill-table") - yield Static("选择 Skill 查看详情", id="skill-detail") - - def set_managers(self, skill_manager, workspace_root) -> None: - """设置 Skill 管理器和工作区根目录.""" - self._skill_manager = skill_manager - self._workspace_root = workspace_root - # 发现 skills - if workspace_root: - from pathlib import Path - - skill_manager.discover(Path(workspace_root)) - self._refresh() - - def on_mount(self) -> None: - table = self.query_one("#skill-table", DataTable) - table.add_columns("启用", "名称", "类型", "描述") - table.cursor_type = "row" - self._columns_initialized = True - # 如果已经有数据,刷新显示 - if self._skills_data: - self._refresh() - - def _refresh(self) -> None: - """刷新 Skill 列表.""" - if self._skill_manager is None: - return - - # 重新获取所有技能(会从数据库加载禁用状态) - self._skills_data = self._skill_manager.get_all() - - # 如果列还没初始化,不刷新(等 on_mount 后自动刷新) - if not self._columns_initialized: - return - - table = self.query_one("#skill-table", DataTable) - table.clear() - - # 获取当前禁用状态用于显示 - disabled_set = self._skill_manager.get_disabled() - - for name, skill in self._skills_data.items(): - is_global = skill.metadata.get("is_global", True) - skill_type = "全局" if is_global else "项目" - # 使用禁用集合判断状态,确保与持久化数据一致 - is_enabled = name not in disabled_set - status = "✓" if is_enabled else "✗" - description = skill.description[:50] + "..." if len(skill.description) > 50 else skill.description - table.add_row(status, name, skill_type, description) - - def _update_detail(self, row_index: int) -> None: - """更新详情显示.""" - if self._skill_manager is None: - return - - table = self.query_one("#skill-table", DataTable) - if row_index is None or row_index < 0: - return - - row_data = table.get_row_at(row_index) - if not row_data: - return - - name = row_data[1] - skill = self._skills_data.get(name) - if not skill: - return - - detail_text = ( - f"[bold]{skill.name}[/bold]\n" - f"位置: {skill.location}\n" - f"类型: {'全局' if skill.metadata.get('is_global', True) else '项目'}\n" - f"状态: {'启用' if skill.enabled else '禁用'}\n\n" - f"描述: {skill.description}" - ) - - detail = self.query_one("#skill-detail", Static) - detail.update(detail_text) - - async def on_data_table_row_highlighted(self, event: DataTable.RowHighlighted) -> None: - """行高亮时更新详情.""" - if event.data_table.id == "skill-table": - self._update_detail(event.cursor_row) - - async def on_button_pressed(self, event: Button.Pressed) -> None: - if self._skill_manager is None: - return - - button_id = event.button.id or "" - - if button_id == "skill-refresh-btn": - if self._workspace_root: - from pathlib import Path - - self._skill_manager.discover(Path(self._workspace_root)) - self._refresh() - self.notify("已刷新 Skill 列表") - - elif button_id == "skill-enable-btn": - # 启用选中的 Skill - table = self.query_one("#skill-table", DataTable) - row_index = table.cursor_row - if row_index is None or row_index < 0: - self.notify("请先选择一个 Skill", severity="warning") - return - - row_data = table.get_row_at(row_index) - if not row_data: - return - - name = row_data[1] - disabled = self._skill_manager.get_disabled() - if name in disabled: - disabled.discard(name) - self._skill_manager.set_disabled(disabled, persist=True) - self._refresh() - self.notify(f"已启用: {name}") - else: - self.notify(f"{name} 已经是启用状态", severity="information") - - elif button_id == "skill-disable-btn": - # 禁用选中的 Skill - table = self.query_one("#skill-table", DataTable) - row_index = table.cursor_row - if row_index is None or row_index < 0: - self.notify("请先选择一个 Skill", severity="warning") - return - - row_data = table.get_row_at(row_index) - if not row_data: - return - - name = row_data[1] - disabled = self._skill_manager.get_disabled() - if name not in disabled: - disabled.add(name) - self._skill_manager.set_disabled(disabled, persist=True) - self._refresh() - self.notify(f"已禁用: {name}") - else: - self.notify(f"{name} 已经是禁用状态", severity="information") - - elif button_id == "skill-enable-all-btn": - self._skill_manager.set_disabled(set(), persist=True) - self._refresh() - self.notify("已启用所有 Skill") - - elif button_id == "skill-disable-all-btn": - all_names = set(self._skills_data.keys()) - self._skill_manager.set_disabled(all_names, persist=True) - self._refresh() - self.notify("已禁用所有 Skill") - - async def on_data_table_cell_selected(self, event: DataTable.CellSelected) -> None: - """单元格选中时切换启用状态(点击启用列).""" - if event.data_table.id != "skill-table": - return - - if event.column_key != 0: # 只在"启用"列点击时切换 - return - - if self._skill_manager is None: - return - - row_data = event.data_table.get_row_at(event.cursor_row) - if not row_data: - return - - name = row_data[1] - skill = self._skills_data.get(name) - if not skill: - return - - # 切换状态 - disabled = self._skill_manager.get_disabled() - if name in disabled: - disabled.discard(name) - else: - disabled.add(name) - - self._skill_manager.set_disabled(disabled, persist=True) - self._refresh() - - status = "启用" if name not in disabled else "禁用" - self.notify(f"已{status}: {name}") +from __future__ import annotations + +from typing import ClassVar + +from textual.containers import Horizontal, Vertical +from textual.widgets import Button, DataTable, Label, Static + + +class SkillConfigTab(Vertical): + """Skill 配置标签页. + + 显示所有发现的 Skill,支持: + - 查看全局和项目级 Skill + - 启用/禁用 Skill + - 查看 Skill 详情 + """ + + DEFAULT_CSS: ClassVar[str] = """ + SkillConfigTab { + height: 1fr; + width: 1fr; + padding: 0 1; + overflow-y: auto; + } + + #skill-header { + height: auto; + padding: 1 0; + text-style: bold; + color: $text; + border-bottom: solid $primary; + } + + #skill-toolbar { + height: auto; + padding: 1 0; + align: left middle; + } + + #skill-toolbar Button { + margin-right: 1; + } + + #skill-table { + height: 1fr; + } + + #skill-detail { + height: auto; + max-height: 10; + padding: 1; + margin-top: 1; + background: $surface; + border: solid $primary; + overflow-y: auto; + } + + #skill-empty { + height: 100%; + content-align: center middle; + color: $text-muted; + } + + .skill-row-enabled { + color: $text; + } + + .skill-row-disabled { + color: $text-muted; + } + """ + + def __init__(self) -> None: + super().__init__() + self._skill_manager = None + self._workspace_root = None + self._skills_data: dict = {} + self._columns_initialized = False # 标记列是否已初始化 + + def compose(self): + yield Label("Skill 配置", id="skill-header") + with Horizontal(id="skill-toolbar"): + yield Button("刷新", id="skill-refresh-btn", variant="default") + yield Button("启用选中", id="skill-enable-btn", variant="success") + yield Button("禁用选中", id="skill-disable-btn", variant="warning") + yield Button("启用全部", id="skill-enable-all-btn", variant="success") + yield Button("禁用全部", id="skill-disable-all-btn", variant="warning") + yield DataTable(id="skill-table") + yield Static("选择 Skill 查看详情", id="skill-detail") + + def set_managers(self, skill_manager, workspace_root) -> None: + """设置 Skill 管理器和工作区根目录.""" + self._skill_manager = skill_manager + self._workspace_root = workspace_root + # 发现 skills + if workspace_root: + from pathlib import Path + + skill_manager.discover(Path(workspace_root)) + self._refresh() + + def on_mount(self) -> None: + table = self.query_one("#skill-table", DataTable) + table.add_columns("启用", "名称", "类型", "描述") + table.cursor_type = "row" + self._columns_initialized = True + # 如果已经有数据,刷新显示 + if self._skills_data: + self._refresh() + + def _refresh(self) -> None: + """刷新 Skill 列表.""" + if self._skill_manager is None: + return + + # 重新获取所有技能(会从数据库加载禁用状态) + self._skills_data = self._skill_manager.get_all() + + # 如果列还没初始化,不刷新(等 on_mount 后自动刷新) + if not self._columns_initialized: + return + + table = self.query_one("#skill-table", DataTable) + table.clear() + + # 获取当前禁用状态用于显示 + disabled_set = self._skill_manager.get_disabled() + + for name, skill in self._skills_data.items(): + is_global = skill.metadata.get("is_global", True) + skill_type = "全局" if is_global else "项目" + # 使用禁用集合判断状态,确保与持久化数据一致 + is_enabled = name not in disabled_set + status = "✓" if is_enabled else "✗" + description = skill.description[:50] + "..." if len(skill.description) > 50 else skill.description + table.add_row(status, name, skill_type, description) + + def _update_detail(self, row_index: int) -> None: + """更新详情显示.""" + if self._skill_manager is None: + return + + table = self.query_one("#skill-table", DataTable) + if row_index is None or row_index < 0: + return + + row_data = table.get_row_at(row_index) + if not row_data: + return + + name = row_data[1] + skill = self._skills_data.get(name) + if not skill: + return + + detail_text = f"[bold]{skill.name}[/bold]\n位置: {skill.location}\n类型: {'全局' if skill.metadata.get('is_global', True) else '项目'}\n状态: {'启用' if skill.enabled else '禁用'}\n\n描述: {skill.description}" + + detail = self.query_one("#skill-detail", Static) + detail.update(detail_text) + + async def on_data_table_row_highlighted(self, event: DataTable.RowHighlighted) -> None: + """行高亮时更新详情.""" + if event.data_table.id == "skill-table": + self._update_detail(event.cursor_row) + + async def on_button_pressed(self, event: Button.Pressed) -> None: + if self._skill_manager is None: + return + + button_id = event.button.id or "" + + if button_id == "skill-refresh-btn": + if self._workspace_root: + from pathlib import Path + + self._skill_manager.discover(Path(self._workspace_root)) + self._refresh() + self.notify("已刷新 Skill 列表") + + elif button_id == "skill-enable-btn": + # 启用选中的 Skill + table = self.query_one("#skill-table", DataTable) + row_index = table.cursor_row + if row_index is None or row_index < 0: + self.notify("请先选择一个 Skill", severity="warning") + return + + row_data = table.get_row_at(row_index) + if not row_data: + return + + name = row_data[1] + disabled = self._skill_manager.get_disabled() + if name in disabled: + disabled.discard(name) + self._skill_manager.set_disabled(disabled, persist=True) + self._refresh() + self.notify(f"已启用: {name}") + else: + self.notify(f"{name} 已经是启用状态", severity="information") + + elif button_id == "skill-disable-btn": + # 禁用选中的 Skill + table = self.query_one("#skill-table", DataTable) + row_index = table.cursor_row + if row_index is None or row_index < 0: + self.notify("请先选择一个 Skill", severity="warning") + return + + row_data = table.get_row_at(row_index) + if not row_data: + return + + name = row_data[1] + disabled = self._skill_manager.get_disabled() + if name not in disabled: + disabled.add(name) + self._skill_manager.set_disabled(disabled, persist=True) + self._refresh() + self.notify(f"已禁用: {name}") + else: + self.notify(f"{name} 已经是禁用状态", severity="information") + + elif button_id == "skill-enable-all-btn": + self._skill_manager.set_disabled(set(), persist=True) + self._refresh() + self.notify("已启用所有 Skill") + + elif button_id == "skill-disable-all-btn": + all_names = set(self._skills_data.keys()) + self._skill_manager.set_disabled(all_names, persist=True) + self._refresh() + self.notify("已禁用所有 Skill") + + async def on_data_table_cell_selected(self, event: DataTable.CellSelected) -> None: + """单元格选中时切换启用状态(点击启用列).""" + if event.data_table.id != "skill-table": + return + + if event.column_key != 0: # 只在"启用"列点击时切换 + return + + if self._skill_manager is None: + return + + row_data = event.data_table.get_row_at(event.cursor_row) + if not row_data: + return + + name = row_data[1] + skill = self._skills_data.get(name) + if not skill: + return + + # 切换状态 + disabled = self._skill_manager.get_disabled() + if name in disabled: + disabled.discard(name) + else: + disabled.add(name) + + self._skill_manager.set_disabled(disabled, persist=True) + self._refresh() + + status = "启用" if name not in disabled else "禁用" + self.notify(f"已{status}: {name}") diff --git a/src/console/ui/widgets/stats_tab.py b/src/console/ui/widgets/stats_tab.py index 3cef052..b30559c 100644 --- a/src/console/ui/widgets/stats_tab.py +++ b/src/console/ui/widgets/stats_tab.py @@ -1,454 +1,445 @@ -"""Statistics tab — session stats, tool ranking, session management.""" - -from __future__ import annotations - -from typing import ClassVar - -from textual.containers import Horizontal, Vertical -from textual.screen import ModalScreen -from textual.widgets import Button, DataTable, Input, Label, Static - - -class RenameDialog(ModalScreen[str | None]): - """Modal dialog for renaming a session.""" - - DEFAULT_CSS = """ - RenameDialog { - align: center middle; - } - - #rename-dialog { - width: 40; - height: auto; - padding: 2; - border: thick $primary; - background: $surface; - } - - #rename-dialog > Label { - text-style: bold; - margin-bottom: 1; - } - - #rename-input { - margin-bottom: 1; - } - - #rename-buttons { - height: auto; - align: right middle; - } - - #rename-buttons Button { - margin-left: 1; - } - """ - - def __init__(self, session_id: int, current_name: str) -> None: - super().__init__() - self._session_id = session_id - self._current_name = current_name - - def compose(self): - with Vertical(id="rename-dialog"): - yield Label("Rename Session") - yield Input(value=self._current_name, id="rename-input") - with Horizontal(id="rename-buttons"): - yield Button("Cancel", id="cancel-btn", variant="default") - yield Button("OK", id="ok-btn", variant="primary") - - def on_button_pressed(self, event: Button.Pressed) -> None: - if event.button.id == "ok-btn": - new_name = self.query_one("#rename-input", Input).value - self.dismiss(new_name) - elif event.button.id == "cancel-btn": - self.dismiss(None) - - -class QuestionDialog(ModalScreen[bool]): - """Simple yes/no confirmation dialog.""" - - DEFAULT_CSS = """ - QuestionDialog { - align: center middle; - } - - #question-dialog { - width: 50; - height: auto; - padding: 2; - border: thick $primary; - background: $surface; - } - - #question-dialog > Label { - text-style: bold; - margin-bottom: 1; - } - - #question-message { - margin-bottom: 1; - } - - #question-buttons { - height: auto; - align: right middle; - } - - #question-buttons Button { - margin-left: 1; - } - """ - - def __init__(self, title: str, message: str) -> None: - super().__init__() - self._title = title - self._message = message - - def compose(self): - with Vertical(id="question-dialog"): - yield Label(self._title) - yield Label(self._message, id="question-message") - with Horizontal(id="question-buttons"): - yield Button("Cancel", id="cancel-btn", variant="default") - yield Button("OK", id="ok-btn", variant="primary") - - def on_button_pressed(self, event: Button.Pressed) -> None: - if event.button.id == "ok-btn": - self.dismiss(True) - elif event.button.id == "cancel-btn": - self.dismiss(False) - - -class StatsTab(Vertical): - """Statistics & session management tab. - - Displays: - - Overview (total sessions, calls, success rate) - - Current session stats (DataTable) - - Tool usage ranking (DataTable, top 10) - - Session list with rename/delete buttons. - """ - - DEFAULT_CSS: ClassVar[str] = """ - StatsTab { - height: 1fr; - width: 1fr; - padding: 0 1; - overflow-y: auto; - } - - #stats-placeholder { - height: 100%; - content-align: center middle; - color: $text-muted; - } - - #stats-empty { - height: 100%; - content-align: center middle; - color: $text-muted; - } - - .stats-header { - height: auto; - padding: 1 0; - text-style: bold; - color: $text; - border-bottom: solid $primary; - } - - #stats-overview { - height: auto; - padding: 1; - background: $surface; - border: solid $primary; - margin-bottom: 1; - } - - StatsTab DataTable { - height: auto; - max-height: 12; - margin-bottom: 1; - } - - .stats-session-row { - height: auto; - padding: 0 0; - margin-bottom: 0; - align: left middle; - } - - .stats-session-name { - width: 1fr; - height: auto; - padding: 0 1; - } - - .stats-session-row Button { - margin-left: 1; - } - - #stats-pagination { - height: auto; - padding: 0 0; - margin-bottom: 1; - align: center middle; - } - - #stats-pagination Label { - margin: 0 1; - } - - #stats-pagination Button { - margin: 0 0; - } - """ - - def __init__(self) -> None: - super().__init__() - self._db = None - self._current_session_id: int | None = None - self._session_page: int = 0 - self._sessions_per_page: int = 15 - - def compose(self): - yield Label("Loading statistics...", id="stats-placeholder") - - def set_database(self, db, current_session_id: int | None) -> None: - """Set database reference and refresh.""" - self._db = db - self._current_session_id = current_session_id - self.set_timer(0.0, self._refresh) - - def on_mount(self) -> None: - if self._db is not None: - self.set_timer(0.1, self._refresh) - self.set_interval(1.0, self._update_live_duration) - - def _update_live_duration(self) -> None: - """Update current session duration display every second.""" - if self._db is None or self._current_session_id is None: - return - try: - dt = self.query_one("#stats-current-session", DataTable) - except Exception: - return - - row = self._db.fetchone( - "SELECT created_at FROM sessions WHERE id = ?", - (self._current_session_id,), - ) - if not row: - return - - import time - - duration = time.time() - row[0] - duration_str = self._format_duration(duration) - dt.update_cell_at((0, 1), duration_str) - - def _format_duration(self, seconds: float) -> str: - """Format seconds to human-readable string.""" - if seconds < 60: - return f"{seconds:.0f}s" - minutes = int(seconds // 60) - secs = int(seconds % 60) - if minutes < 60: - return f"{minutes}m {secs}s" - hours = minutes // 60 - minutes = minutes % 60 - return f"{hours}h {minutes}m {secs}s" - - async def _refresh(self) -> None: - """Rebuild the entire stats UI.""" - if self._db is None: - return - - await self.remove_children() - self._build_content() - - def _build_content(self) -> None: - """Mount all content widgets.""" - if self._db is None: - return - - sessions = self._db.get_all_sessions() - total_sessions = len(sessions) - - # Compute global aggregates - total_calls = 0 - total_success = 0 - for s in sessions: - sid = s[0] - summary = self._db.get_session_summary(sid) - total_calls += summary["total_calls"] - total_success += summary["success_count"] - - success_rate = (total_success / total_calls * 100) if total_calls else 0.0 - - # Current session name - current_name = "" - if self._current_session_id is not None: - row = self._db.fetchone( - "SELECT name FROM sessions WHERE id = ?", - (self._current_session_id,), - ) - if row: - current_name = row[0] - - # --- Section 1: Overview --- - overview_text = ( - f"[bold]Overview[/bold]\n" - f"Total Sessions: {total_sessions}\n" - f"Total Tool Calls: {total_calls}\n" - f"Overall Success Rate: {success_rate:.1f}%\n" - f"Active Session: {current_name or 'N/A'}" - ) - self.mount(Static(overview_text, id="stats-overview")) - - # --- Section 2: Current session stats --- - if self._current_session_id is not None: - summary = self._db.get_session_summary(self._current_session_id) - if summary: - duration_str = self._format_duration(summary["duration"]) - self.mount(Label("Current Session", classes="stats-header")) - dt = DataTable(id="stats-current-session") - dt.add_columns("Metric", "Value") - dt.add_row("Duration", duration_str) - dt.add_row("Total Calls", str(summary["total_calls"])) - dt.add_row("Successful", str(summary["success_count"])) - dt.add_row("Failed", str(summary["fail_count"])) - dt.add_row("Success Rate", f"{summary['success_rate']:.1f}%") - self.mount(dt) - - # --- Section 3: Tool usage ranking --- - ranking = self._db.get_tool_usage_ranking(self._current_session_id) - if ranking: - self.mount(Label("Top Tools", classes="stats-header")) - dt = DataTable(id="stats-tool-ranking") - dt.add_columns("#", "Tool", "Calls", "Avg Time", "Total Time") - for i, (func_name, count, avg_dur, total_dur) in enumerate(ranking, 1): - avg_str = f"{avg_dur:.1f}ms" if avg_dur is not None else "N/A" - total_str = f"{total_dur:.1f}ms" if total_dur is not None else "N/A" - dt.add_row(str(i), func_name, str(count), avg_str, total_str) - self.mount(dt) - else: - self.mount(Label("No tool calls recorded yet.", id="stats-empty")) - - # --- Section 4: Session list --- - if sessions: - total_pages = (len(sessions) + self._sessions_per_page - 1) // self._sessions_per_page - if self._session_page >= total_pages: - self._session_page = total_pages - 1 - if self._session_page < 0: - self._session_page = 0 - - start_idx = self._session_page * self._sessions_per_page - end_idx = start_idx + self._sessions_per_page - page_sessions = sessions[start_idx:end_idx] - - self.mount(Label("Sessions", classes="stats-header")) - - if total_pages > 1: - nav = Horizontal( - Button("<< Prev", id="page-prev", variant="default", disabled=(self._session_page == 0)), - Label(f" Page {self._session_page + 1}/{total_pages} "), - Button( - "Next >>", id="page-next", variant="default", disabled=(self._session_page >= total_pages - 1) - ), - id="stats-pagination", - classes="stats-pagination", - ) - self.mount(nav) - - for s in page_sessions: - sid, name, _created_at, duration = s - is_active = sid == self._current_session_id - name_display = name or "Unnamed" - if is_active: - name_display += " (active)" - - duration_str = self._format_duration(duration) if duration else "in progress" - - # Get tool call count for this session - summary = self._db.get_session_summary(sid) - total_calls = summary.get("total_calls", 0) if summary else 0 - - text = f"{name_display} [{duration_str}] ({total_calls} calls)" - - row = Horizontal( - Static(text, classes="stats-session-name"), - Button("Rename", id=f"rename-{sid}", variant="default"), - Button("Delete", id=f"delete-{sid}", variant="error"), - classes="stats-session-row", - ) - self.mount(row) - if is_active: - row.query_one(f"#delete-{sid}", Button).disabled = True - - async def on_button_pressed(self, event: Button.Pressed) -> None: - """Handle rename/delete button clicks.""" - if self._db is None: - return - - button_id = event.button.id or "" - - # Pagination buttons - if button_id == "page-prev": - self._session_page -= 1 - await self._refresh() - return - elif button_id == "page-next": - self._session_page += 1 - await self._refresh() - return - - parts = button_id.split("-", 1) - if len(parts) != 2: - return - - action, sid_str = parts - try: - session_id = int(sid_str) - except ValueError: - return - - if action == "rename": - # Get current name - row = self._db.fetchone( - "SELECT name FROM sessions WHERE id = ?", - (session_id,), - ) - current_name = row[0] if row else "" - - async def on_rename(result: str | None) -> None: - if result is not None and result.strip(): - self._db.rename_session(session_id, result.strip()) - await self._refresh() - elif result is not None: - self.notify("Name cannot be empty.", severity="warning") - - self.app.push_screen(RenameDialog(session_id, current_name), on_rename) - - elif action == "delete": - if session_id == self._current_session_id: - self.notify("Cannot delete the active session.", severity="error") - return - - async def on_confirm(result: bool | None) -> None: - if result: - self._db.delete_session_async(session_id) - self.notify( - f"Session '{session_id}' scheduled for deletion.", - ) - await self._refresh() - - self.app.push_screen( - QuestionDialog( - "Delete Session", - f"Are you sure you want to delete session '{session_id}'?\n" - "All tool calls and snapshots for this session will also be deleted.", - ), - on_confirm, - ) +"""Statistics tab — session stats, tool ranking, session management.""" + +from __future__ import annotations + +from typing import ClassVar + +from textual.containers import Horizontal, Vertical +from textual.screen import ModalScreen +from textual.widgets import Button, DataTable, Input, Label, Static + + +class RenameDialog(ModalScreen[str | None]): + """Modal dialog for renaming a session.""" + + DEFAULT_CSS = """ + RenameDialog { + align: center middle; + } + + #rename-dialog { + width: 40; + height: auto; + padding: 2; + border: thick $primary; + background: $surface; + } + + #rename-dialog > Label { + text-style: bold; + margin-bottom: 1; + } + + #rename-input { + margin-bottom: 1; + } + + #rename-buttons { + height: auto; + align: right middle; + } + + #rename-buttons Button { + margin-left: 1; + } + """ + + def __init__(self, session_id: int, current_name: str) -> None: + super().__init__() + self._session_id = session_id + self._current_name = current_name + + def compose(self): + with Vertical(id="rename-dialog"): + yield Label("Rename Session") + yield Input(value=self._current_name, id="rename-input") + with Horizontal(id="rename-buttons"): + yield Button("Cancel", id="cancel-btn", variant="default") + yield Button("OK", id="ok-btn", variant="primary") + + def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "ok-btn": + new_name = self.query_one("#rename-input", Input).value + self.dismiss(new_name) + elif event.button.id == "cancel-btn": + self.dismiss(None) + + +class QuestionDialog(ModalScreen[bool]): + """Simple yes/no confirmation dialog.""" + + DEFAULT_CSS = """ + QuestionDialog { + align: center middle; + } + + #question-dialog { + width: 50; + height: auto; + padding: 2; + border: thick $primary; + background: $surface; + } + + #question-dialog > Label { + text-style: bold; + margin-bottom: 1; + } + + #question-message { + margin-bottom: 1; + } + + #question-buttons { + height: auto; + align: right middle; + } + + #question-buttons Button { + margin-left: 1; + } + """ + + def __init__(self, title: str, message: str) -> None: + super().__init__() + self._title = title + self._message = message + + def compose(self): + with Vertical(id="question-dialog"): + yield Label(self._title) + yield Label(self._message, id="question-message") + with Horizontal(id="question-buttons"): + yield Button("Cancel", id="cancel-btn", variant="default") + yield Button("OK", id="ok-btn", variant="primary") + + def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "ok-btn": + self.dismiss(True) + elif event.button.id == "cancel-btn": + self.dismiss(False) + + +class StatsTab(Vertical): + """Statistics & session management tab. + + Displays: + - Overview (total sessions, calls, success rate) + - Current session stats (DataTable) + - Tool usage ranking (DataTable, top 10) + - Session list with rename/delete buttons. + """ + + DEFAULT_CSS: ClassVar[str] = """ + StatsTab { + height: 1fr; + width: 1fr; + padding: 0 1; + overflow-y: auto; + } + + #stats-placeholder { + height: 100%; + content-align: center middle; + color: $text-muted; + } + + #stats-empty { + height: 100%; + content-align: center middle; + color: $text-muted; + } + + .stats-header { + height: auto; + padding: 1 0; + text-style: bold; + color: $text; + border-bottom: solid $primary; + } + + #stats-overview { + height: auto; + padding: 1; + background: $surface; + border: solid $primary; + margin-bottom: 1; + } + + StatsTab DataTable { + height: auto; + max-height: 12; + margin-bottom: 1; + } + + .stats-session-row { + height: auto; + padding: 0 0; + margin-bottom: 0; + align: left middle; + } + + .stats-session-name { + width: 1fr; + height: auto; + padding: 0 1; + } + + .stats-session-row Button { + margin-left: 1; + } + + #stats-pagination { + height: auto; + padding: 0 0; + margin-bottom: 1; + align: center middle; + } + + #stats-pagination Label { + margin: 0 1; + } + + #stats-pagination Button { + margin: 0 0; + } + """ + + def __init__(self) -> None: + super().__init__() + self._db = None + self._current_session_id: int | None = None + self._session_page: int = 0 + self._sessions_per_page: int = 15 + + def compose(self): + yield Label("Loading statistics...", id="stats-placeholder") + + def set_database(self, db, current_session_id: int | None) -> None: + """Set database reference and refresh.""" + self._db = db + self._current_session_id = current_session_id + self.set_timer(0.0, self._refresh) + + def on_mount(self) -> None: + if self._db is not None: + self.set_timer(0.1, self._refresh) + self.set_interval(1.0, self._update_live_duration) + + def _update_live_duration(self) -> None: + """Update current session duration display every second.""" + if self._db is None or self._current_session_id is None: + return + try: + dt = self.query_one("#stats-current-session", DataTable) + except Exception: + return + + row = self._db.fetchone( + "SELECT created_at FROM sessions WHERE id = ?", + (self._current_session_id,), + ) + if not row: + return + + import time + + duration = time.time() - row[0] + duration_str = self._format_duration(duration) + dt.update_cell_at((0, 1), duration_str) + + def _format_duration(self, seconds: float) -> str: + """Format seconds to human-readable string.""" + if seconds < 60: + return f"{seconds:.0f}s" + minutes = int(seconds // 60) + secs = int(seconds % 60) + if minutes < 60: + return f"{minutes}m {secs}s" + hours = minutes // 60 + minutes = minutes % 60 + return f"{hours}h {minutes}m {secs}s" + + async def _refresh(self) -> None: + """Rebuild the entire stats UI.""" + if self._db is None: + return + + await self.remove_children() + self._build_content() + + def _build_content(self) -> None: + """Mount all content widgets.""" + if self._db is None: + return + + sessions = self._db.get_all_sessions() + total_sessions = len(sessions) + + # Compute global aggregates + total_calls = 0 + total_success = 0 + for s in sessions: + sid = s[0] + summary = self._db.get_session_summary(sid) + total_calls += summary["total_calls"] + total_success += summary["success_count"] + + success_rate = (total_success / total_calls * 100) if total_calls else 0.0 + + # Current session name + current_name = "" + if self._current_session_id is not None: + row = self._db.fetchone( + "SELECT name FROM sessions WHERE id = ?", + (self._current_session_id,), + ) + if row: + current_name = row[0] + + # --- Section 1: Overview --- + overview_text = f"[bold]Overview[/bold]\nTotal Sessions: {total_sessions}\nTotal Tool Calls: {total_calls}\nOverall Success Rate: {success_rate:.1f}%\nActive Session: {current_name or 'N/A'}" + self.mount(Static(overview_text, id="stats-overview")) + + # --- Section 2: Current session stats --- + if self._current_session_id is not None: + summary = self._db.get_session_summary(self._current_session_id) + if summary: + duration_str = self._format_duration(summary["duration"]) + self.mount(Label("Current Session", classes="stats-header")) + dt = DataTable(id="stats-current-session") + dt.add_columns("Metric", "Value") + dt.add_row("Duration", duration_str) + dt.add_row("Total Calls", str(summary["total_calls"])) + dt.add_row("Successful", str(summary["success_count"])) + dt.add_row("Failed", str(summary["fail_count"])) + dt.add_row("Success Rate", f"{summary['success_rate']:.1f}%") + self.mount(dt) + + # --- Section 3: Tool usage ranking --- + ranking = self._db.get_tool_usage_ranking(self._current_session_id) + if ranking: + self.mount(Label("Top Tools", classes="stats-header")) + dt = DataTable(id="stats-tool-ranking") + dt.add_columns("#", "Tool", "Calls", "Avg Time", "Total Time") + for i, (func_name, count, avg_dur, total_dur) in enumerate(ranking, 1): + avg_str = f"{avg_dur:.1f}ms" if avg_dur is not None else "N/A" + total_str = f"{total_dur:.1f}ms" if total_dur is not None else "N/A" + dt.add_row(str(i), func_name, str(count), avg_str, total_str) + self.mount(dt) + else: + self.mount(Label("No tool calls recorded yet.", id="stats-empty")) + + # --- Section 4: Session list --- + if sessions: + total_pages = (len(sessions) + self._sessions_per_page - 1) // self._sessions_per_page + if self._session_page >= total_pages: + self._session_page = total_pages - 1 + if self._session_page < 0: + self._session_page = 0 + + start_idx = self._session_page * self._sessions_per_page + end_idx = start_idx + self._sessions_per_page + page_sessions = sessions[start_idx:end_idx] + + self.mount(Label("Sessions", classes="stats-header")) + + if total_pages > 1: + nav = Horizontal( + Button("<< Prev", id="page-prev", variant="default", disabled=(self._session_page == 0)), + Label(f" Page {self._session_page + 1}/{total_pages} "), + Button("Next >>", id="page-next", variant="default", disabled=(self._session_page >= total_pages - 1)), + id="stats-pagination", + classes="stats-pagination", + ) + self.mount(nav) + + for s in page_sessions: + sid, name, _created_at, duration = s + is_active = sid == self._current_session_id + name_display = name or "Unnamed" + if is_active: + name_display += " (active)" + + duration_str = self._format_duration(duration) if duration else "in progress" + + # Get tool call count for this session + summary = self._db.get_session_summary(sid) + total_calls = summary.get("total_calls", 0) if summary else 0 + + text = f"{name_display} [{duration_str}] ({total_calls} calls)" + + row = Horizontal( + Static(text, classes="stats-session-name"), + Button("Rename", id=f"rename-{sid}", variant="default"), + Button("Delete", id=f"delete-{sid}", variant="error"), + classes="stats-session-row", + ) + self.mount(row) + if is_active: + row.query_one(f"#delete-{sid}", Button).disabled = True + + async def on_button_pressed(self, event: Button.Pressed) -> None: + """Handle rename/delete button clicks.""" + if self._db is None: + return + + button_id = event.button.id or "" + + # Pagination buttons + if button_id == "page-prev": + self._session_page -= 1 + await self._refresh() + return + elif button_id == "page-next": + self._session_page += 1 + await self._refresh() + return + + parts = button_id.split("-", 1) + if len(parts) != 2: + return + + action, sid_str = parts + try: + session_id = int(sid_str) + except ValueError: + return + + if action == "rename": + # Get current name + row = self._db.fetchone( + "SELECT name FROM sessions WHERE id = ?", + (session_id,), + ) + current_name = row[0] if row else "" + + async def on_rename(result: str | None) -> None: + if result is not None and result.strip(): + self._db.rename_session(session_id, result.strip()) + await self._refresh() + elif result is not None: + self.notify("Name cannot be empty.", severity="warning") + + self.app.push_screen(RenameDialog(session_id, current_name), on_rename) + + elif action == "delete": + if session_id == self._current_session_id: + self.notify("Cannot delete the active session.", severity="error") + return + + async def on_confirm(result: bool | None) -> None: + if result: + self._db.delete_session_async(session_id) + self.notify( + f"Session '{session_id}' scheduled for deletion.", + ) + await self._refresh() + + self.app.push_screen( + QuestionDialog( + "Delete Session", + f"Are you sure you want to delete session '{session_id}'?\nAll tool calls and snapshots for this session will also be deleted.", + ), + on_confirm, + ) diff --git a/src/core/database_manager.py b/src/core/database_manager.py index 64b2fe6..12af8f2 100644 --- a/src/core/database_manager.py +++ b/src/core/database_manager.py @@ -1,716 +1,687 @@ -import contextlib -import json -import sqlite3 -import threading -import time -from pathlib import Path -from typing import ClassVar - -from src.constants.manual_aid import DB_FILE, MANUALAID_DIR - - -class DatabaseManager: - """Thread-safe SQLite3 database manager (singleton per workspace path).""" - - _instances: ClassVar[dict[str, DatabaseManager]] = {} - _instance_lock: ClassVar[threading.Lock] = threading.Lock() - - def __new__(cls, workspace_root: str) -> DatabaseManager: - with cls._instance_lock: - if workspace_root not in cls._instances: - instance = super().__new__(cls) - instance._initialized = False - cls._instances[workspace_root] = instance - return cls._instances[workspace_root] - - def __init__(self, workspace_root: str) -> None: - if self._initialized: - return - - self._workspace_root = workspace_root - self._db_dir = Path(workspace_root) / MANUALAID_DIR - self._db_path = self._db_dir / DB_FILE - self._lock = threading.RLock() - self._conn: sqlite3.Connection | None = None - - self._ensure_directory() - self._init_tables() - self._initialized = True - - @property - def db_path(self) -> Path: - return self._db_path - - def _ensure_directory(self) -> None: - self._db_dir.mkdir(parents=True, exist_ok=True) - - def _get_connection(self) -> sqlite3.Connection: - if self._conn is None: - conn = sqlite3.connect( - str(self._db_path), - isolation_level=None, - check_same_thread=False, - ) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA foreign_keys=ON") - conn.execute("PRAGMA busy_timeout=5000") - self._conn = conn - return self._conn - - def _init_tables(self) -> None: - conn = self._get_connection() - conn.executescript( - """ - CREATE TABLE IF NOT EXISTS sessions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL DEFAULT '', - created_at REAL NOT NULL, - duration REAL NOT NULL DEFAULT 0.0, - deleted INTEGER NOT NULL DEFAULT 0 - ); - - CREATE TABLE IF NOT EXISTS tool_calls ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id INTEGER NOT NULL, - func_name TEXT NOT NULL, - kwargs TEXT NOT NULL DEFAULT '', - timestamp REAL NOT NULL, - duration_ms REAL NOT NULL DEFAULT 0.0, - status TEXT NOT NULL DEFAULT 'success', - audit_status TEXT NOT NULL DEFAULT 'none', - FOREIGN KEY (session_id) REFERENCES sessions(id) - ); - - CREATE TABLE IF NOT EXISTS file_read_records ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id INTEGER NOT NULL, - file_path TEXT NOT NULL, - mtime REAL NOT NULL, - size INTEGER NOT NULL DEFAULT 0, - checksum TEXT NOT NULL DEFAULT '', - last_read_at REAL NOT NULL, - read_count INTEGER NOT NULL DEFAULT 1, - FOREIGN KEY (session_id) REFERENCES sessions(id), - UNIQUE(session_id, file_path) - ); - - CREATE TABLE IF NOT EXISTS file_snapshots ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - file_path TEXT NOT NULL, - old_hash TEXT, - new_hash TEXT NOT NULL, - diff_content TEXT NOT NULL DEFAULT '', - timestamp REAL NOT NULL, - session_id INTEGER, - audit_status TEXT NOT NULL DEFAULT 'PENDING_AUDIT', - FOREIGN KEY (session_id) REFERENCES sessions(id) - ); - - CREATE TABLE IF NOT EXISTS tool_call_summaries ( - session_id INTEGER NOT NULL, - func_name TEXT NOT NULL, - kwargs_json TEXT NOT NULL, - result TEXT NOT NULL, - timestamp REAL NOT NULL, - PRIMARY KEY (session_id, func_name, kwargs_json), - FOREIGN KEY (session_id) REFERENCES sessions(id) - ); - - CREATE TABLE IF NOT EXISTS config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL, - category TEXT NOT NULL DEFAULT 'general', - updated_at REAL NOT NULL - ); - - CREATE TABLE IF NOT EXISTS shell_audit ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - command TEXT NOT NULL, - description TEXT NOT NULL DEFAULT '', - timestamp REAL NOT NULL, - session_id INTEGER, - audit_status TEXT NOT NULL DEFAULT 'PENDING_AUDIT', - output TEXT NOT NULL DEFAULT '', - exit_code INTEGER, - executed_at REAL, - FOREIGN KEY (session_id) REFERENCES sessions(id) - ); - """ - ) - - # Phase 2 migration: add pending_content column to file_snapshots - self._migrate_add_pending_content(conn) - - # Phase 3 migration: rename args_hash to kwargs and truncate old data - if any(row[1] == "args_hash" for row in conn.execute("PRAGMA table_info(tool_calls)")): - self._migrate_args_hash_to_kwargs(conn) - - # Phase 4 migration: add session_id to file_read_records - if not any(row[1] == "session_id" for row in conn.execute("PRAGMA table_info(file_read_records)")): - self._migrate_file_read_records_add_session(conn) - - # Phase 5 migration: add deleted column to sessions - if not any(row[1] == "deleted" for row in conn.execute("PRAGMA table_info(sessions)")): - conn.execute("ALTER TABLE sessions ADD COLUMN deleted INTEGER NOT NULL DEFAULT 0") - - # Phase 6 migration: add config table - if not any(row[1] == "key" for row in conn.execute("PRAGMA table_info(config)")): - conn.execute( - """ - CREATE TABLE IF NOT EXISTS config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL, - category TEXT NOT NULL DEFAULT 'general', - updated_at REAL NOT NULL - ) - """ - ) - - # Create all indexes after migrations so they apply to both - # fresh databases and those upgraded from older schemas. - conn.execute("CREATE INDEX IF NOT EXISTS idx_tool_calls_session ON tool_calls(session_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_tool_calls_func ON tool_calls(func_name)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_file_snapshots_audit ON file_snapshots(audit_status)") - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_file_read_records_session_path ON file_read_records(session_id, file_path)" - ) - conn.execute("CREATE INDEX IF NOT EXISTS idx_tool_call_summaries_session ON tool_call_summaries(session_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_shell_audit_status ON shell_audit(audit_status)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_shell_audit_session ON shell_audit(session_id)") - - @staticmethod - def _migrate_add_pending_content(conn: sqlite3.Connection) -> None: - with contextlib.suppress(sqlite3.OperationalError): - conn.execute("ALTER TABLE file_snapshots ADD COLUMN pending_content TEXT NOT NULL DEFAULT ''") - - @staticmethod - def _migrate_args_hash_to_kwargs(conn: sqlite3.Connection) -> None: - conn.execute("DELETE FROM tool_calls") - conn.execute("ALTER TABLE tool_calls RENAME COLUMN args_hash TO kwargs") - - @staticmethod - def _migrate_file_read_records_add_session(conn: sqlite3.Connection) -> None: - conn.executescript( - """ - DROP TABLE IF EXISTS file_read_records; - - CREATE TABLE file_read_records ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id INTEGER NOT NULL, - file_path TEXT NOT NULL, - mtime REAL NOT NULL, - size INTEGER NOT NULL DEFAULT 0, - checksum TEXT NOT NULL DEFAULT '', - last_read_at REAL NOT NULL, - read_count INTEGER NOT NULL DEFAULT 1, - FOREIGN KEY (session_id) REFERENCES sessions(id), - UNIQUE(session_id, file_path) - ); - - CREATE INDEX IF NOT EXISTS idx_file_read_records_session_path ON file_read_records(session_id, file_path); - """ - ) - - def close(self) -> None: - with self._lock: - if self._conn is not None: - self._conn.close() - self._conn = None - - # -- Unified query interface -- - - def execute(self, sql: str, params: tuple = ()) -> sqlite3.Cursor: - with self._lock: - return self._get_connection().execute(sql, params) - - def fetchone(self, sql: str, params: tuple = ()) -> tuple | None: - with self._lock: - return self._get_connection().execute(sql, params).fetchone() - - def fetchall(self, sql: str, params: tuple = ()) -> list[tuple]: - with self._lock: - return self._get_connection().execute(sql, params).fetchall() - - # -- Session lifecycle -- - - def create_session(self, name: str = "") -> int: - cursor = self.execute( - "INSERT INTO sessions (name, created_at) VALUES (?, ?)", - (name, time.time()), - ) - return cursor.lastrowid - - def close_session(self, session_id: int) -> None: - row = self.fetchone("SELECT created_at FROM sessions WHERE id = ?", (session_id,)) - if row: - duration = time.time() - row[0] - self.execute( - "UPDATE sessions SET duration = ? WHERE id = ?", - (duration, session_id), - ) - # Mark empty sessions as deleted so they can be cleaned up later - if self.is_session_orphaned(session_id): - self.mark_session_deleted(session_id) - - def update_session_duration(self, session_id: int) -> None: - """Persist current elapsed duration without closing the session. - - Used by the periodic heartbeat so that abnormal termination (window - close, Ctrl+C, SIGKILL) loses at most the heartbeat interval's worth - of session duration data. - """ - row = self.fetchone("SELECT created_at FROM sessions WHERE id = ?", (session_id,)) - if row: - duration = time.time() - row[0] - self.execute( - "UPDATE sessions SET duration = ? WHERE id = ?", - (duration, session_id), - ) - - def mark_session_deleted(self, session_id: int) -> None: - """Set the deleted flag on a session. - - 设置会话上的删除标志""" - self.execute("UPDATE sessions SET deleted = 1 WHERE id = ?", (session_id,)) - - def restore_session_deleted_flag(self, session_id: int) -> None: - """Restore the deleted flag (set it back to 0) for an active session. - - 恢复删除标志(将其设回 0)为一个活跃会话 - """ - self.execute("UPDATE sessions SET deleted = 0 WHERE id = ?", (session_id,)) - - def get_sessions_with_deleted_flag(self) -> list[int]: - """Return IDs of all sessions with the deleted flag set. - - 返回所有设置了删除标志的会话的 ID - """ - rows = self.fetchall("SELECT id FROM sessions WHERE deleted = 1") - return [r[0] for r in rows] - - def is_session_orphaned(self, session_id: int) -> bool: - """Check if a session has no associated data in any related table. - - 检查一个会话是否在任何相关表中都没有关联数据 - """ - tables = ["tool_calls", "file_read_records", "file_snapshots", "tool_call_summaries", "shell_audit"] - for table in tables: - row = self.fetchone( - f"SELECT COUNT(*) FROM {table} WHERE session_id = ?", - (session_id,), - ) - if row and row[0] > 0: - return False - return True - - def delete_session_async(self, session_id: int) -> None: - """异步轮询删除会话 - - 设置删除标志,然后每 10 秒轮询一次(共重试 3 次) - 如果标志被心跳恢复,则取消删除操作 - 否则在第三次轮询后执行实际删除 - """ - self.mark_session_deleted(session_id) - - def _poll() -> None: - for _ in range(3): - time.sleep(10) - row = self.fetchone("SELECT deleted FROM sessions WHERE id = ?", (session_id,)) - if row and row[0] == 0: - return # Flag was restored, cancel deletion / 标志已恢复,取消删除 - # 三次轮询已过,标志仍被设置——执行删除 - # Three polls passed, flag still set -- execute deletion - self.delete_session(session_id) - - thread = threading.Thread(target=_poll, daemon=True) - thread.start() - - # -- Tool call logging -- - - def log_tool_call( - self, - session_id: int, - func_name: str, - kwargs: str, - duration_ms: float = 0.0, - status: str = "success", - audit_status: str = "none", - ) -> int: - cursor = self.execute( - "INSERT INTO tool_calls (session_id, func_name, kwargs, timestamp, duration_ms, status, audit_status) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (session_id, func_name, kwargs, time.time(), duration_ms, status, audit_status), - ) - return cursor.lastrowid - - def update_tool_call_status(self, call_id: int, status: str, audit_status: str) -> None: - self.execute( - "UPDATE tool_calls SET status = ?, audit_status = ? WHERE id = ?", - (status, audit_status, call_id), - ) - - # -- File read records -- - - def record_file_read(self, session_id: int, file_path: str, mtime: float, size: int, checksum: str) -> None: - with self._lock: - self._get_connection().execute( - "INSERT INTO file_read_records " - "(session_id, file_path, mtime, size, checksum, last_read_at, read_count) " - "VALUES (?, ?, ?, ?, ?, ?, 1) " - "ON CONFLICT(session_id, file_path) DO UPDATE SET " - "mtime = excluded.mtime, " - "size = excluded.size, " - "checksum = excluded.checksum, " - "last_read_at = excluded.last_read_at, " - "read_count = read_count + 1", - (session_id, file_path, mtime, size, checksum, time.time()), - ) - - def get_file_read_record(self, session_id: int, file_path: str) -> tuple | None: - return self.fetchone( - "SELECT id, session_id, file_path, mtime, size, checksum, last_read_at, read_count " - "FROM file_read_records WHERE session_id = ? AND file_path = ?", - (session_id, file_path), - ) - - # -- File snapshots -- - - def record_file_snapshot( - self, - file_path: str, - old_hash: str | None, - new_hash: str, - diff_content: str, - audit_status: str = "PENDING_AUDIT", - session_id: int | None = None, - pending_content: str = "", - ) -> int: - cursor = self.execute( - "INSERT INTO file_snapshots " - "(file_path, old_hash, new_hash, diff_content, timestamp, session_id, audit_status, pending_content) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - (file_path, old_hash, new_hash, diff_content, time.time(), session_id, audit_status, pending_content), - ) - return cursor.lastrowid - - def update_snapshot_audit(self, snapshot_id: int, audit_status: str) -> None: - self.execute( - "UPDATE file_snapshots SET audit_status = ? WHERE id = ?", - (audit_status, snapshot_id), - ) - - def get_pending_audits(self) -> list[tuple]: - return self.fetchall( - "SELECT id, file_path, old_hash, new_hash, diff_content, timestamp, session_id, audit_status" - " FROM file_snapshots WHERE audit_status = 'PENDING_AUDIT'" - ) - - def get_snapshot_by_id(self, snapshot_id: int) -> tuple | None: - return self.fetchone( - "SELECT id, file_path, old_hash, new_hash, diff_content, timestamp, session_id, audit_status," - " pending_content FROM file_snapshots WHERE id = ?", - (snapshot_id,), - ) - - def get_snapshots_by_audit_status(self, status: str) -> list[tuple]: - return self.fetchall( - "SELECT id, file_path, old_hash, new_hash, diff_content, timestamp, session_id, audit_status," - " pending_content FROM file_snapshots WHERE audit_status = ?", - (status,), - ) - - # -- Shell command audit -- - - def record_shell_command( - self, - command: str, - description: str = "", - session_id: int | None = None, - ) -> int: - """记录一条待审核的 Shell 命令. - - Args: - command: Shell 命令内容 - description: 命令描述 - session_id: 会话 ID - - Returns: - 新记录的 ID - """ - cursor = self.execute( - "INSERT INTO shell_audit (command, description, timestamp, session_id, audit_status) " - "VALUES (?, ?, ?, ?, 'PENDING_AUDIT')", - (command, description, time.time(), session_id), - ) - return cursor.lastrowid - - def update_shell_audit( - self, - shell_id: int, - audit_status: str, - output: str = "", - exit_code: int | None = None, - ) -> None: - """更新 Shell 命令审核状态及执行结果. - - Args: - shell_id: 记录 ID - audit_status: 审核状态 (APPROVED/REJECTED) - output: 命令执行输出 - exit_code: 命令退出码 - """ - if output or exit_code is not None: - self.execute( - "UPDATE shell_audit SET audit_status = ?, output = ?, exit_code = ?, executed_at = ? WHERE id = ?", - (audit_status, output, exit_code, time.time(), shell_id), - ) - else: - self.execute( - "UPDATE shell_audit SET audit_status = ? WHERE id = ?", - (audit_status, shell_id), - ) - - def get_shell_pending_audits(self) -> list[tuple]: - """获取所有待审核的 Shell 命令. - - Returns: - 待审核记录列表 (id, command, description, timestamp, session_id, audit_status) - """ - return self.fetchall( - "SELECT id, command, description, timestamp, session_id, audit_status" - " FROM shell_audit WHERE audit_status = 'PENDING_AUDIT'" - ) - - def get_shell_by_id(self, shell_id: int) -> tuple | None: - """根据 ID 获取 Shell 命令审核记录. - - Args: - shell_id: 记录 ID - - Returns: - 记录元组或 None - """ - return self.fetchone( - "SELECT id, command, description, timestamp, session_id," - " audit_status, output, exit_code, executed_at" - " FROM shell_audit WHERE id = ?", - (shell_id,), - ) - - def get_shell_completed(self, limit: int = 200) -> list[tuple]: - """获取所有已完成的 Shell 命令(已批准/已拒绝),按执行时间倒序. - - Args: - limit: 最大返回条数 - - Returns: - 已完成记录列表, 每条含 (id, command, description, timestamp, - session_id, audit_status, output, exit_code, executed_at) - """ - return self.fetchall( - "SELECT id, command, description, timestamp, session_id," - " audit_status, output, exit_code, executed_at" - " FROM shell_audit WHERE audit_status != 'PENDING_AUDIT'" - " ORDER BY COALESCE(executed_at, timestamp) DESC LIMIT ?", - (limit,), - ) - - # -- Session statistics and management -- - - def get_session_summary(self, session_id: int) -> dict: - """Aggregated stats for a single session.""" - session = self.fetchone( - "SELECT id, name, created_at, duration FROM sessions WHERE id = ?", - (session_id,), - ) - if not session: - return {} - - total = self.fetchone( - "SELECT COUNT(*), SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END), " - "SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) " - "FROM tool_calls WHERE session_id = ?", - (session_id,), - ) - total_calls, success_count, fail_count = total or (0, 0, 0) - - return { - "id": session[0], - "name": session[1], - "created_at": session[2], - "duration": session[3], - "total_calls": total_calls or 0, - "success_count": success_count or 0, - "fail_count": fail_count or 0, - "success_rate": (success_count / total_calls * 100) if total_calls else 0.0, - } - - def get_all_sessions(self) -> list[tuple]: - """All sessions ordered by created_at descending.""" - return self.fetchall("SELECT id, name, created_at, duration FROM sessions ORDER BY created_at DESC") - - def rename_session(self, session_id: int, name: str) -> None: - self.execute("UPDATE sessions SET name = ? WHERE id = ?", (name, session_id)) - - def delete_session(self, session_id: int) -> None: - with self._lock: - conn = self._get_connection() - conn.execute("BEGIN IMMEDIATE") - try: - conn.execute("DELETE FROM tool_calls WHERE session_id = ?", (session_id,)) - conn.execute("DELETE FROM file_snapshots WHERE session_id = ?", (session_id,)) - conn.execute("DELETE FROM file_read_records WHERE session_id = ?", (session_id,)) - conn.execute("DELETE FROM shell_audit WHERE session_id = ?", (session_id,)) - conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) - conn.execute("COMMIT") - except Exception: - conn.execute("ROLLBACK") - raise - - def get_tool_usage_ranking(self, session_id: int | None = None, limit: int = 10) -> list[tuple]: - """Returns list of (func_name, call_count, avg_duration_ms, total_duration_ms) ordered by count DESC.""" - if session_id is not None: - return self.fetchall( - "SELECT func_name, COUNT(*) as cnt, AVG(duration_ms) as avg_dur, SUM(duration_ms) as total_dur " - "FROM tool_calls WHERE session_id = ? " - "GROUP BY func_name ORDER BY cnt DESC LIMIT ?", - (session_id, limit), - ) - return self.fetchall( - "SELECT func_name, COUNT(*) as cnt, AVG(duration_ms) as avg_dur, SUM(duration_ms) as total_dur " - "FROM tool_calls " - "GROUP BY func_name ORDER BY cnt DESC LIMIT ?", - (limit,), - ) - - # -- Class-level cleanup for testing -- - - @classmethod - def reset_instances(cls) -> None: - with cls._instance_lock: - for instance in cls._instances.values(): - instance.close() - cls._instances.clear() - - # -- Tool call summaries -- - - def record_tool_call_summary( - self, - session_id: int, - func_name: str, - kwargs_json: str, - result: str, - ) -> None: - with self._lock: - self._get_connection().execute( - "INSERT INTO tool_call_summaries " - "(session_id, func_name, kwargs_json, result, timestamp) " - "VALUES (?, ?, ?, ?, ?) " - "ON CONFLICT(session_id, func_name, kwargs_json) DO UPDATE SET " - "result = excluded.result, " - "timestamp = excluded.timestamp", - (session_id, func_name, kwargs_json, result, time.time()), - ) - - def get_tool_call_summaries(self, session_id: int) -> list[tuple]: - """Get all tool call summaries for a session ordered by timestamp DESC.""" - return self.fetchall( - "SELECT session_id, func_name, kwargs_json, result, timestamp " - "FROM tool_call_summaries WHERE session_id = ? ORDER BY timestamp DESC", - (session_id,), - ) - - # -- Configuration management -- - - def get_config(self, key: str, default: str | None = None) -> str | None: - """Get a configuration value by key. - - Args: - key: Configuration key - default: Default value if key not found - - Returns: - Configuration value or default - """ - row = self.fetchone("SELECT value FROM config WHERE key = ?", (key,)) - return row[0] if row else default - - def set_config(self, key: str, value: str, category: str = "general") -> None: - """Set a configuration value. - - Args: - key: Configuration key - value: Configuration value - category: Configuration category (general, skill, env, etc.) - """ - self.execute( - "INSERT INTO config (key, value, category, updated_at) VALUES (?, ?, ?, ?) " - "ON CONFLICT(key) DO UPDATE SET value = excluded.value, category = excluded.category, " - "updated_at = excluded.updated_at", - (key, value, category, time.time()), - ) - - def delete_config(self, key: str) -> None: - """Delete a configuration value. - - Args: - key: Configuration key - """ - self.execute("DELETE FROM config WHERE key = ?", (key,)) - - def get_all_config(self, category: str | None = None) -> list[tuple]: - """Get all configuration values, optionally filtered by category. - - Args: - category: Optional category filter - - Returns: - List of (key, value, category, updated_at) tuples - """ - if category: - return self.fetchall( - "SELECT key, value, category, updated_at FROM config WHERE category = ? ORDER BY key", - (category,), - ) - return self.fetchall("SELECT key, value, category, updated_at FROM config ORDER BY category, key") - - def get_config_by_prefix(self, prefix: str) -> dict[str, str]: - """Get all configuration values with a given key prefix. - - Args: - prefix: Key prefix to filter by - - Returns: - Dictionary of key-value pairs - """ - rows = self.fetchall( - "SELECT key, value FROM config WHERE key LIKE ? ORDER BY key", - (f"{prefix}%",), - ) - return {row[0]: row[1] for row in rows} - - # -- Skill configuration shortcuts -- - - def get_disabled_skills(self) -> set[str]: - """Get the set of disabled skill names. - - Returns: - Set of disabled skill names - """ - value = self.get_config("skills.disabled") - if not value: - return set() - import json - - try: - return set(json.loads(value)) - except json.JSONDecodeError, TypeError: - return set() - - def set_disabled_skills(self, names) -> None: - """Set the disabled skill names. - - Args: - names: Collection of skill names to disable (set, list, or tuple) - """ - self.set_config("skills.disabled", json.dumps(sorted(set(names))), category="skill") +import contextlib +import json +import sqlite3 +import threading +import time +from pathlib import Path +from typing import ClassVar + +from src.constants.manual_aid import DB_FILE, MANUALAID_DIR + + +class DatabaseManager: + """Thread-safe SQLite3 database manager (singleton per workspace path).""" + + _instances: ClassVar[dict[str, DatabaseManager]] = {} + _instance_lock: ClassVar[threading.Lock] = threading.Lock() + + def __new__(cls, workspace_root: str) -> DatabaseManager: + with cls._instance_lock: + if workspace_root not in cls._instances: + instance = super().__new__(cls) + instance._initialized = False + cls._instances[workspace_root] = instance + return cls._instances[workspace_root] + + def __init__(self, workspace_root: str) -> None: + if self._initialized: + return + + self._workspace_root = workspace_root + self._db_dir = Path(workspace_root) / MANUALAID_DIR + self._db_path = self._db_dir / DB_FILE + self._lock = threading.RLock() + self._conn: sqlite3.Connection | None = None + + self._ensure_directory() + self._init_tables() + self._initialized = True + + @property + def db_path(self) -> Path: + return self._db_path + + def _ensure_directory(self) -> None: + self._db_dir.mkdir(parents=True, exist_ok=True) + + def _get_connection(self) -> sqlite3.Connection: + if self._conn is None: + conn = sqlite3.connect( + str(self._db_path), + isolation_level=None, + check_same_thread=False, + ) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + conn.execute("PRAGMA busy_timeout=5000") + self._conn = conn + return self._conn + + def _init_tables(self) -> None: + conn = self._get_connection() + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL DEFAULT '', + created_at REAL NOT NULL, + duration REAL NOT NULL DEFAULT 0.0, + deleted INTEGER NOT NULL DEFAULT 0 + ); + + CREATE TABLE IF NOT EXISTS tool_calls ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + func_name TEXT NOT NULL, + kwargs TEXT NOT NULL DEFAULT '', + timestamp REAL NOT NULL, + duration_ms REAL NOT NULL DEFAULT 0.0, + status TEXT NOT NULL DEFAULT 'success', + audit_status TEXT NOT NULL DEFAULT 'none', + FOREIGN KEY (session_id) REFERENCES sessions(id) + ); + + CREATE TABLE IF NOT EXISTS file_read_records ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + file_path TEXT NOT NULL, + mtime REAL NOT NULL, + size INTEGER NOT NULL DEFAULT 0, + checksum TEXT NOT NULL DEFAULT '', + last_read_at REAL NOT NULL, + read_count INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY (session_id) REFERENCES sessions(id), + UNIQUE(session_id, file_path) + ); + + CREATE TABLE IF NOT EXISTS file_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_path TEXT NOT NULL, + old_hash TEXT, + new_hash TEXT NOT NULL, + diff_content TEXT NOT NULL DEFAULT '', + timestamp REAL NOT NULL, + session_id INTEGER, + audit_status TEXT NOT NULL DEFAULT 'PENDING_AUDIT', + FOREIGN KEY (session_id) REFERENCES sessions(id) + ); + + CREATE TABLE IF NOT EXISTS tool_call_summaries ( + session_id INTEGER NOT NULL, + func_name TEXT NOT NULL, + kwargs_json TEXT NOT NULL, + result TEXT NOT NULL, + timestamp REAL NOT NULL, + PRIMARY KEY (session_id, func_name, kwargs_json), + FOREIGN KEY (session_id) REFERENCES sessions(id) + ); + + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'general', + updated_at REAL NOT NULL + ); + + CREATE TABLE IF NOT EXISTS shell_audit ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + command TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + timestamp REAL NOT NULL, + session_id INTEGER, + audit_status TEXT NOT NULL DEFAULT 'PENDING_AUDIT', + output TEXT NOT NULL DEFAULT '', + exit_code INTEGER, + executed_at REAL, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ); + """ + ) + + # Phase 2 migration: add pending_content column to file_snapshots + self._migrate_add_pending_content(conn) + + # Phase 3 migration: rename args_hash to kwargs and truncate old data + if any(row[1] == "args_hash" for row in conn.execute("PRAGMA table_info(tool_calls)")): + self._migrate_args_hash_to_kwargs(conn) + + # Phase 4 migration: add session_id to file_read_records + if not any(row[1] == "session_id" for row in conn.execute("PRAGMA table_info(file_read_records)")): + self._migrate_file_read_records_add_session(conn) + + # Phase 5 migration: add deleted column to sessions + if not any(row[1] == "deleted" for row in conn.execute("PRAGMA table_info(sessions)")): + conn.execute("ALTER TABLE sessions ADD COLUMN deleted INTEGER NOT NULL DEFAULT 0") + + # Phase 6 migration: add config table + if not any(row[1] == "key" for row in conn.execute("PRAGMA table_info(config)")): + conn.execute( + """ + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'general', + updated_at REAL NOT NULL + ) + """ + ) + + # Create all indexes after migrations so they apply to both + # fresh databases and those upgraded from older schemas. + conn.execute("CREATE INDEX IF NOT EXISTS idx_tool_calls_session ON tool_calls(session_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_tool_calls_func ON tool_calls(func_name)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_file_snapshots_audit ON file_snapshots(audit_status)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_file_read_records_session_path ON file_read_records(session_id, file_path)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_tool_call_summaries_session ON tool_call_summaries(session_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_shell_audit_status ON shell_audit(audit_status)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_shell_audit_session ON shell_audit(session_id)") + + @staticmethod + def _migrate_add_pending_content(conn: sqlite3.Connection) -> None: + with contextlib.suppress(sqlite3.OperationalError): + conn.execute("ALTER TABLE file_snapshots ADD COLUMN pending_content TEXT NOT NULL DEFAULT ''") + + @staticmethod + def _migrate_args_hash_to_kwargs(conn: sqlite3.Connection) -> None: + conn.execute("DELETE FROM tool_calls") + conn.execute("ALTER TABLE tool_calls RENAME COLUMN args_hash TO kwargs") + + @staticmethod + def _migrate_file_read_records_add_session(conn: sqlite3.Connection) -> None: + conn.executescript( + """ + DROP TABLE IF EXISTS file_read_records; + + CREATE TABLE file_read_records ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + file_path TEXT NOT NULL, + mtime REAL NOT NULL, + size INTEGER NOT NULL DEFAULT 0, + checksum TEXT NOT NULL DEFAULT '', + last_read_at REAL NOT NULL, + read_count INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY (session_id) REFERENCES sessions(id), + UNIQUE(session_id, file_path) + ); + + CREATE INDEX IF NOT EXISTS idx_file_read_records_session_path ON file_read_records(session_id, file_path); + """ + ) + + def close(self) -> None: + with self._lock: + if self._conn is not None: + self._conn.close() + self._conn = None + + # -- Unified query interface -- + + def execute(self, sql: str, params: tuple = ()) -> sqlite3.Cursor: + with self._lock: + return self._get_connection().execute(sql, params) + + def fetchone(self, sql: str, params: tuple = ()) -> tuple | None: + with self._lock: + return self._get_connection().execute(sql, params).fetchone() + + def fetchall(self, sql: str, params: tuple = ()) -> list[tuple]: + with self._lock: + return self._get_connection().execute(sql, params).fetchall() + + # -- Session lifecycle -- + + def create_session(self, name: str = "") -> int: + cursor = self.execute( + "INSERT INTO sessions (name, created_at) VALUES (?, ?)", + (name, time.time()), + ) + return cursor.lastrowid + + def close_session(self, session_id: int) -> None: + row = self.fetchone("SELECT created_at FROM sessions WHERE id = ?", (session_id,)) + if row: + duration = time.time() - row[0] + self.execute( + "UPDATE sessions SET duration = ? WHERE id = ?", + (duration, session_id), + ) + # Mark empty sessions as deleted so they can be cleaned up later + if self.is_session_orphaned(session_id): + self.mark_session_deleted(session_id) + + def update_session_duration(self, session_id: int) -> None: + """Persist current elapsed duration without closing the session. + + Used by the periodic heartbeat so that abnormal termination (window + close, Ctrl+C, SIGKILL) loses at most the heartbeat interval's worth + of session duration data. + """ + row = self.fetchone("SELECT created_at FROM sessions WHERE id = ?", (session_id,)) + if row: + duration = time.time() - row[0] + self.execute( + "UPDATE sessions SET duration = ? WHERE id = ?", + (duration, session_id), + ) + + def mark_session_deleted(self, session_id: int) -> None: + """Set the deleted flag on a session. + + 设置会话上的删除标志""" + self.execute("UPDATE sessions SET deleted = 1 WHERE id = ?", (session_id,)) + + def restore_session_deleted_flag(self, session_id: int) -> None: + """Restore the deleted flag (set it back to 0) for an active session. + + 恢复删除标志(将其设回 0)为一个活跃会话 + """ + self.execute("UPDATE sessions SET deleted = 0 WHERE id = ?", (session_id,)) + + def get_sessions_with_deleted_flag(self) -> list[int]: + """Return IDs of all sessions with the deleted flag set. + + 返回所有设置了删除标志的会话的 ID + """ + rows = self.fetchall("SELECT id FROM sessions WHERE deleted = 1") + return [r[0] for r in rows] + + def is_session_orphaned(self, session_id: int) -> bool: + """Check if a session has no associated data in any related table. + + 检查一个会话是否在任何相关表中都没有关联数据 + """ + tables = ["tool_calls", "file_read_records", "file_snapshots", "tool_call_summaries", "shell_audit"] + for table in tables: + row = self.fetchone( + f"SELECT COUNT(*) FROM {table} WHERE session_id = ?", + (session_id,), + ) + if row and row[0] > 0: + return False + return True + + def delete_session_async(self, session_id: int) -> None: + """异步轮询删除会话 + + 设置删除标志,然后每 10 秒轮询一次(共重试 3 次) + 如果标志被心跳恢复,则取消删除操作 + 否则在第三次轮询后执行实际删除 + """ + self.mark_session_deleted(session_id) + + def _poll() -> None: + for _ in range(3): + time.sleep(10) + row = self.fetchone("SELECT deleted FROM sessions WHERE id = ?", (session_id,)) + if row and row[0] == 0: + return # Flag was restored, cancel deletion / 标志已恢复,取消删除 + # 三次轮询已过,标志仍被设置——执行删除 + # Three polls passed, flag still set -- execute deletion + self.delete_session(session_id) + + thread = threading.Thread(target=_poll, daemon=True) + thread.start() + + # -- Tool call logging -- + + def log_tool_call( + self, + session_id: int, + func_name: str, + kwargs: str, + duration_ms: float = 0.0, + status: str = "success", + audit_status: str = "none", + ) -> int: + cursor = self.execute( + "INSERT INTO tool_calls (session_id, func_name, kwargs, timestamp, duration_ms, status, audit_status) VALUES (?, ?, ?, ?, ?, ?, ?)", + (session_id, func_name, kwargs, time.time(), duration_ms, status, audit_status), + ) + return cursor.lastrowid + + def update_tool_call_status(self, call_id: int, status: str, audit_status: str) -> None: + self.execute( + "UPDATE tool_calls SET status = ?, audit_status = ? WHERE id = ?", + (status, audit_status, call_id), + ) + + # -- File read records -- + + def record_file_read(self, session_id: int, file_path: str, mtime: float, size: int, checksum: str) -> None: + with self._lock: + self._get_connection().execute( + "INSERT INTO file_read_records " + "(session_id, file_path, mtime, size, checksum, last_read_at, read_count) " + "VALUES (?, ?, ?, ?, ?, ?, 1) " + "ON CONFLICT(session_id, file_path) DO UPDATE SET " + "mtime = excluded.mtime, " + "size = excluded.size, " + "checksum = excluded.checksum, " + "last_read_at = excluded.last_read_at, " + "read_count = read_count + 1", + (session_id, file_path, mtime, size, checksum, time.time()), + ) + + def get_file_read_record(self, session_id: int, file_path: str) -> tuple | None: + return self.fetchone( + "SELECT id, session_id, file_path, mtime, size, checksum, last_read_at, read_count FROM file_read_records WHERE session_id = ? AND file_path = ?", + (session_id, file_path), + ) + + # -- File snapshots -- + + def record_file_snapshot( + self, + file_path: str, + old_hash: str | None, + new_hash: str, + diff_content: str, + audit_status: str = "PENDING_AUDIT", + session_id: int | None = None, + pending_content: str = "", + ) -> int: + cursor = self.execute( + "INSERT INTO file_snapshots (file_path, old_hash, new_hash, diff_content, timestamp, session_id, audit_status, pending_content) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (file_path, old_hash, new_hash, diff_content, time.time(), session_id, audit_status, pending_content), + ) + return cursor.lastrowid + + def update_snapshot_audit(self, snapshot_id: int, audit_status: str) -> None: + self.execute( + "UPDATE file_snapshots SET audit_status = ? WHERE id = ?", + (audit_status, snapshot_id), + ) + + def get_pending_audits(self) -> list[tuple]: + return self.fetchall("SELECT id, file_path, old_hash, new_hash, diff_content, timestamp, session_id, audit_status FROM file_snapshots WHERE audit_status = 'PENDING_AUDIT'") + + def get_snapshot_by_id(self, snapshot_id: int) -> tuple | None: + return self.fetchone( + "SELECT id, file_path, old_hash, new_hash, diff_content, timestamp, session_id, audit_status, pending_content FROM file_snapshots WHERE id = ?", + (snapshot_id,), + ) + + def get_snapshots_by_audit_status(self, status: str) -> list[tuple]: + return self.fetchall( + "SELECT id, file_path, old_hash, new_hash, diff_content, timestamp, session_id, audit_status, pending_content FROM file_snapshots WHERE audit_status = ?", + (status,), + ) + + # -- Shell command audit -- + + def record_shell_command( + self, + command: str, + description: str = "", + session_id: int | None = None, + ) -> int: + """记录一条待审核的 Shell 命令. + + Args: + command: Shell 命令内容 + description: 命令描述 + session_id: 会话 ID + + Returns: + 新记录的 ID + """ + cursor = self.execute( + "INSERT INTO shell_audit (command, description, timestamp, session_id, audit_status) VALUES (?, ?, ?, ?, 'PENDING_AUDIT')", + (command, description, time.time(), session_id), + ) + return cursor.lastrowid + + def update_shell_audit( + self, + shell_id: int, + audit_status: str, + output: str = "", + exit_code: int | None = None, + ) -> None: + """更新 Shell 命令审核状态及执行结果. + + Args: + shell_id: 记录 ID + audit_status: 审核状态 (APPROVED/REJECTED) + output: 命令执行输出 + exit_code: 命令退出码 + """ + if output or exit_code is not None: + self.execute( + "UPDATE shell_audit SET audit_status = ?, output = ?, exit_code = ?, executed_at = ? WHERE id = ?", + (audit_status, output, exit_code, time.time(), shell_id), + ) + else: + self.execute( + "UPDATE shell_audit SET audit_status = ? WHERE id = ?", + (audit_status, shell_id), + ) + + def get_shell_pending_audits(self) -> list[tuple]: + """获取所有待审核的 Shell 命令. + + Returns: + 待审核记录列表 (id, command, description, timestamp, session_id, audit_status) + """ + return self.fetchall("SELECT id, command, description, timestamp, session_id, audit_status FROM shell_audit WHERE audit_status = 'PENDING_AUDIT'") + + def get_shell_by_id(self, shell_id: int) -> tuple | None: + """根据 ID 获取 Shell 命令审核记录. + + Args: + shell_id: 记录 ID + + Returns: + 记录元组或 None + """ + return self.fetchone( + "SELECT id, command, description, timestamp, session_id, audit_status, output, exit_code, executed_at FROM shell_audit WHERE id = ?", + (shell_id,), + ) + + def get_shell_completed(self, limit: int = 200) -> list[tuple]: + """获取所有已完成的 Shell 命令(已批准/已拒绝),按执行时间倒序. + + Args: + limit: 最大返回条数 + + Returns: + 已完成记录列表, 每条含 (id, command, description, timestamp, + session_id, audit_status, output, exit_code, executed_at) + """ + return self.fetchall( + "SELECT id, command, description, timestamp, session_id, audit_status, output, exit_code, executed_at FROM shell_audit WHERE audit_status != 'PENDING_AUDIT' ORDER BY COALESCE(executed_at, timestamp) DESC LIMIT ?", + (limit,), + ) + + # -- Session statistics and management -- + + def get_session_summary(self, session_id: int) -> dict: + """Aggregated stats for a single session.""" + session = self.fetchone( + "SELECT id, name, created_at, duration FROM sessions WHERE id = ?", + (session_id,), + ) + if not session: + return {} + + total = self.fetchone( + "SELECT COUNT(*), SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END), SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) FROM tool_calls WHERE session_id = ?", + (session_id,), + ) + total_calls, success_count, fail_count = total or (0, 0, 0) + + return { + "id": session[0], + "name": session[1], + "created_at": session[2], + "duration": session[3], + "total_calls": total_calls or 0, + "success_count": success_count or 0, + "fail_count": fail_count or 0, + "success_rate": (success_count / total_calls * 100) if total_calls else 0.0, + } + + def get_all_sessions(self) -> list[tuple]: + """All sessions ordered by created_at descending.""" + return self.fetchall("SELECT id, name, created_at, duration FROM sessions ORDER BY created_at DESC") + + def rename_session(self, session_id: int, name: str) -> None: + self.execute("UPDATE sessions SET name = ? WHERE id = ?", (name, session_id)) + + def delete_session(self, session_id: int) -> None: + with self._lock: + conn = self._get_connection() + conn.execute("BEGIN IMMEDIATE") + try: + conn.execute("DELETE FROM tool_calls WHERE session_id = ?", (session_id,)) + conn.execute("DELETE FROM file_snapshots WHERE session_id = ?", (session_id,)) + conn.execute("DELETE FROM file_read_records WHERE session_id = ?", (session_id,)) + conn.execute("DELETE FROM shell_audit WHERE session_id = ?", (session_id,)) + conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) + conn.execute("COMMIT") + except Exception: + conn.execute("ROLLBACK") + raise + + def get_tool_usage_ranking(self, session_id: int | None = None, limit: int = 10) -> list[tuple]: + """Returns list of (func_name, call_count, avg_duration_ms, total_duration_ms) ordered by count DESC.""" + if session_id is not None: + return self.fetchall( + "SELECT func_name, COUNT(*) as cnt, AVG(duration_ms) as avg_dur, SUM(duration_ms) as total_dur FROM tool_calls WHERE session_id = ? GROUP BY func_name ORDER BY cnt DESC LIMIT ?", + (session_id, limit), + ) + return self.fetchall( + "SELECT func_name, COUNT(*) as cnt, AVG(duration_ms) as avg_dur, SUM(duration_ms) as total_dur FROM tool_calls GROUP BY func_name ORDER BY cnt DESC LIMIT ?", + (limit,), + ) + + # -- Class-level cleanup for testing -- + + @classmethod + def reset_instances(cls) -> None: + with cls._instance_lock: + for instance in cls._instances.values(): + instance.close() + cls._instances.clear() + + # -- Tool call summaries -- + + def record_tool_call_summary( + self, + session_id: int, + func_name: str, + kwargs_json: str, + result: str, + ) -> None: + with self._lock: + self._get_connection().execute( + "INSERT INTO tool_call_summaries " + "(session_id, func_name, kwargs_json, result, timestamp) " + "VALUES (?, ?, ?, ?, ?) " + "ON CONFLICT(session_id, func_name, kwargs_json) DO UPDATE SET " + "result = excluded.result, " + "timestamp = excluded.timestamp", + (session_id, func_name, kwargs_json, result, time.time()), + ) + + def get_tool_call_summaries(self, session_id: int) -> list[tuple]: + """Get all tool call summaries for a session ordered by timestamp DESC.""" + return self.fetchall( + "SELECT session_id, func_name, kwargs_json, result, timestamp FROM tool_call_summaries WHERE session_id = ? ORDER BY timestamp DESC", + (session_id,), + ) + + # -- Configuration management -- + + def get_config(self, key: str, default: str | None = None) -> str | None: + """Get a configuration value by key. + + Args: + key: Configuration key + default: Default value if key not found + + Returns: + Configuration value or default + """ + row = self.fetchone("SELECT value FROM config WHERE key = ?", (key,)) + return row[0] if row else default + + def set_config(self, key: str, value: str, category: str = "general") -> None: + """Set a configuration value. + + Args: + key: Configuration key + value: Configuration value + category: Configuration category (general, skill, env, etc.) + """ + self.execute( + "INSERT INTO config (key, value, category, updated_at) VALUES (?, ?, ?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value, category = excluded.category, updated_at = excluded.updated_at", + (key, value, category, time.time()), + ) + + def delete_config(self, key: str) -> None: + """Delete a configuration value. + + Args: + key: Configuration key + """ + self.execute("DELETE FROM config WHERE key = ?", (key,)) + + def get_all_config(self, category: str | None = None) -> list[tuple]: + """Get all configuration values, optionally filtered by category. + + Args: + category: Optional category filter + + Returns: + List of (key, value, category, updated_at) tuples + """ + if category: + return self.fetchall( + "SELECT key, value, category, updated_at FROM config WHERE category = ? ORDER BY key", + (category,), + ) + return self.fetchall("SELECT key, value, category, updated_at FROM config ORDER BY category, key") + + def get_config_by_prefix(self, prefix: str) -> dict[str, str]: + """Get all configuration values with a given key prefix. + + Args: + prefix: Key prefix to filter by + + Returns: + Dictionary of key-value pairs + """ + rows = self.fetchall( + "SELECT key, value FROM config WHERE key LIKE ? ORDER BY key", + (f"{prefix}%",), + ) + return {row[0]: row[1] for row in rows} + + # -- Skill configuration shortcuts -- + + def get_disabled_skills(self) -> set[str]: + """Get the set of disabled skill names. + + Returns: + Set of disabled skill names + """ + value = self.get_config("skills.disabled") + if not value: + return set() + import json + + try: + return set(json.loads(value)) + except json.JSONDecodeError, TypeError: + return set() + + def set_disabled_skills(self, names) -> None: + """Set the disabled skill names. + + Args: + names: Collection of skill names to disable (set, list, or tuple) + """ + self.set_config("skills.disabled", json.dumps(sorted(set(names))), category="skill") diff --git a/src/core/launcher.py b/src/core/launcher.py index 319f8e2..e9193b3 100644 --- a/src/core/launcher.py +++ b/src/core/launcher.py @@ -122,10 +122,7 @@ def launch(path: str, console: Any) -> bool: console.print(f"[bold red]{error_msg}[/bold red]") # 显示手动启动帮助 - if _is_frozen(): - manual_cmd = f'"{_get_executable_path()}" -p "{workspace_path}"' - else: - manual_cmd = f'python main.py -p "{workspace_path}"' + manual_cmd = f'"{_get_executable_path()}" -p "{workspace_path}"' if _is_frozen() else f'python main.py -p "{workspace_path}"' info_panel = Panel( f"[yellow]手动启动命令:[/yellow]\n{manual_cmd}\n\n[dim]请在新终端中手动运行以上命令[/dim]", diff --git a/src/core/paste_window.py b/src/core/paste_window.py index 952ba58..fe271d5 100644 --- a/src/core/paste_window.py +++ b/src/core/paste_window.py @@ -1,229 +1,223 @@ -import threading -import tkinter as tk -from collections.abc import Callable -from tkinter import scrolledtext - - -class PasteWindow: - """超大文本粘贴窗口,单例模式,每次关闭后完全销毁线程和Tk实例""" - - _instance: PasteWindow | None = None - _current_thread: threading.Thread | None = None - _current_root: tk.Tk | None = None - - def __init__(self): - self._callback: Callable[[str], None] | None = None - self._window: tk.Toplevel | None = None - self._text_widget: scrolledtext.ScrolledText | None = None - self._stats_label: tk.Label | None = None - - @classmethod - def get_instance(cls) -> PasteWindow: - """获取单例实例""" - if cls._instance is None: - cls._instance = cls() - return cls._instance - - def _run_tkinter(self, initial_text: str, title: str) -> None: - """在独立线程中运行 tkinter 事件循环""" - # 创建新的 Tk 实例 - root = tk.Tk() - root.withdraw() # 隐藏主窗口 - PasteWindow._current_root = root - - # 创建粘贴窗口 - window = tk.Toplevel(root) - window.title(title) - window.geometry("800x600") - window.minsize(400, 300) - self._window = window - - # 创建文本区域 - frame = tk.Frame(window) - frame.pack(fill="both", expand=True, padx=10, pady=10) - - text_widget = scrolledtext.ScrolledText( - frame, wrap=tk.WORD, font=("Consolas", 11), undo=True, autoseparators=True, maxundo=100 - ) - text_widget.pack(fill="both", expand=True) - self._text_widget = text_widget - - if initial_text: - text_widget.insert(1.0, initial_text) - - # 按钮框架 - button_frame = tk.Frame(window) - button_frame.pack(fill="x", padx=10, pady=(0, 10)) - - # 统计信息标签 - stats_label = tk.Label(button_frame, text="字符数: 0", fg="gray") - stats_label.pack(side="left", padx=5) - self._stats_label = stats_label - - # 按钮 - copy_btn = tk.Button(button_frame, text="复制文本", command=self._copy_to_clipboard) - copy_btn.pack(side="right", padx=5) - - clear_btn = tk.Button(button_frame, text="清空", command=self._clear_text) - clear_btn.pack(side="right", padx=5) - - cancel_btn = tk.Button(button_frame, text="取消", command=self._on_cancel) - cancel_btn.pack(side="right", padx=5) - - confirm_btn = tk.Button(button_frame, text="确认粘贴", command=self._on_confirm, bg="#4CAF50", fg="white") - confirm_btn.pack(side="right", padx=5) - - # 绑定快捷键 - window.bind("", lambda e: self._select_all()) - window.bind("", lambda e: self._on_cancel()) - window.bind("", lambda e: self._on_confirm()) - - # 窗口关闭事件 - window.protocol("WM_DELETE_WINDOW", self._on_cancel) - - # 更新统计信息的绑定 - text_widget.bind("", lambda e: self._update_stats()) - - # 初始更新统计 - self._update_stats() - - # 设置焦点 - text_widget.focus_set() - - # 启动 tkinter 事件循环 - try: - root.mainloop() - except tk.TclError: - pass - finally: - # 清理资源 - try: - if root.winfo_exists(): - root.destroy() - except Exception: - pass - PasteWindow._current_root = None - self._window = None - self._text_widget = None - self._stats_label = None - PasteWindow._current_thread = None - - def _copy_to_clipboard(self) -> None: - """复制到剪贴板""" - content = self.get_text() - if content and self._window: - try: - self._window.clipboard_clear() - self._window.clipboard_append(content) - self._show_temporary_message("已复制") - except Exception: - pass - - def _show_temporary_message(self, message: str) -> None: - """显示临时消息""" - if self._stats_label and self._stats_label.winfo_exists(): - original = self._stats_label.cget("text") - self._stats_label.config(text=f"✓ {message}", fg="green") - if self._window: - self._window.after(1500, lambda: self._stats_label.config(text=original, fg="gray")) - - def _on_confirm(self) -> None: - """确认""" - content = self.get_text() - if self._callback: - self._callback(content) - self.hide() - - def _on_cancel(self) -> None: - """取消""" - self.hide() - - def _update_stats(self) -> None: - """更新统计""" - if self._text_widget and self._text_widget.winfo_exists(): - try: - content = self._text_widget.get(1.0, tk.END) - char_count = len(content) - 1 - line_count = int(self._text_widget.index(tk.END).split(".")[0]) - 1 - kb_size = char_count / 1024 - size_text = f"{kb_size:.1f}KB" if kb_size >= 1 else f"{char_count}B" - self._stats_label.config(text=f"字符数: {char_count:,} ({size_text}) | 行数: {line_count}") - except Exception: - pass - - def _clear_text(self) -> None: - """清空""" - if self._text_widget and self._text_widget.winfo_exists(): - self._text_widget.delete(1.0, tk.END) - self._update_stats() - - def _select_all(self) -> str | None: - """全选""" - if self._text_widget and self._text_widget.winfo_exists(): - self._text_widget.tag_add(tk.SEL, "1.0", tk.END) - self._text_widget.mark_set(tk.INSERT, "1.0") - return "break" - return None - - def get_text(self) -> str: - """获取文本""" - if self._text_widget and self._text_widget.winfo_exists(): - content = self._text_widget.get(1.0, tk.END) - if content.endswith("\n"): - content = content[:-1] - return content - return "" - - def show( - self, callback: Callable[[str], None] | None = None, initial_text: str = "", title: str = "粘贴超大文本" - ) -> None: - """ - 显示粘贴窗口(独立线程,非阻塞) - - 每次调用都会创建新的线程和Tk实例,确保不会出现线程冲突 - """ - self._callback = callback - - # 如果已有窗口在运行,先关闭 - if PasteWindow._current_thread and PasteWindow._current_thread.is_alive(): - self.hide() - # 等待线程完全结束 - PasteWindow._current_thread.join(timeout=1.0) - - # 创建新线程 - PasteWindow._current_thread = threading.Thread( - target=self._run_tkinter, args=(initial_text, title), daemon=True - ) - PasteWindow._current_thread.start() - - def hide(self) -> None: - """关闭窗口""" - if self._window and self._window.winfo_exists(): - try: - self._window.quit() # 退出 mainloop - self._window.destroy() - except Exception: - pass - - # 等待线程结束,避免资源冲突 - if PasteWindow._current_thread and PasteWindow._current_thread != threading.current_thread(): - PasteWindow._current_thread.join(timeout=0.5) - - -# 便捷函数 -_paste_window: PasteWindow | None = None - - -def show_paste_window(callback: Callable[[str], None] | None = None, initial_text: str = "") -> None: - """显示粘贴窗口(非阻塞)""" - global _paste_window - if _paste_window is None: - _paste_window = PasteWindow.get_instance() - _paste_window.show(callback=callback, initial_text=initial_text) - - -def close_paste_window() -> None: - """关闭粘贴窗口""" - global _paste_window - if _paste_window: - _paste_window.hide() +import threading +import tkinter as tk +from collections.abc import Callable +from tkinter import scrolledtext + + +class PasteWindow: + """超大文本粘贴窗口,单例模式,每次关闭后完全销毁线程和Tk实例""" + + _instance: PasteWindow | None = None + _current_thread: threading.Thread | None = None + _current_root: tk.Tk | None = None + + def __init__(self): + self._callback: Callable[[str], None] | None = None + self._window: tk.Toplevel | None = None + self._text_widget: scrolledtext.ScrolledText | None = None + self._stats_label: tk.Label | None = None + + @classmethod + def get_instance(cls) -> PasteWindow: + """获取单例实例""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _run_tkinter(self, initial_text: str, title: str) -> None: + """在独立线程中运行 tkinter 事件循环""" + # 创建新的 Tk 实例 + root = tk.Tk() + root.withdraw() # 隐藏主窗口 + PasteWindow._current_root = root + + # 创建粘贴窗口 + window = tk.Toplevel(root) + window.title(title) + window.geometry("800x600") + window.minsize(400, 300) + self._window = window + + # 创建文本区域 + frame = tk.Frame(window) + frame.pack(fill="both", expand=True, padx=10, pady=10) + + text_widget = scrolledtext.ScrolledText(frame, wrap=tk.WORD, font=("Consolas", 11), undo=True, autoseparators=True, maxundo=100) + text_widget.pack(fill="both", expand=True) + self._text_widget = text_widget + + if initial_text: + text_widget.insert(1.0, initial_text) + + # 按钮框架 + button_frame = tk.Frame(window) + button_frame.pack(fill="x", padx=10, pady=(0, 10)) + + # 统计信息标签 + stats_label = tk.Label(button_frame, text="字符数: 0", fg="gray") + stats_label.pack(side="left", padx=5) + self._stats_label = stats_label + + # 按钮 + copy_btn = tk.Button(button_frame, text="复制文本", command=self._copy_to_clipboard) + copy_btn.pack(side="right", padx=5) + + clear_btn = tk.Button(button_frame, text="清空", command=self._clear_text) + clear_btn.pack(side="right", padx=5) + + cancel_btn = tk.Button(button_frame, text="取消", command=self._on_cancel) + cancel_btn.pack(side="right", padx=5) + + confirm_btn = tk.Button(button_frame, text="确认粘贴", command=self._on_confirm, bg="#4CAF50", fg="white") + confirm_btn.pack(side="right", padx=5) + + # 绑定快捷键 + window.bind("", lambda e: self._select_all()) + window.bind("", lambda e: self._on_cancel()) + window.bind("", lambda e: self._on_confirm()) + + # 窗口关闭事件 + window.protocol("WM_DELETE_WINDOW", self._on_cancel) + + # 更新统计信息的绑定 + text_widget.bind("", lambda e: self._update_stats()) + + # 初始更新统计 + self._update_stats() + + # 设置焦点 + text_widget.focus_set() + + # 启动 tkinter 事件循环 + try: + root.mainloop() + except tk.TclError: + pass + finally: + # 清理资源 + try: + if root.winfo_exists(): + root.destroy() + except Exception: + pass + PasteWindow._current_root = None + self._window = None + self._text_widget = None + self._stats_label = None + PasteWindow._current_thread = None + + def _copy_to_clipboard(self) -> None: + """复制到剪贴板""" + content = self.get_text() + if content and self._window: + try: + self._window.clipboard_clear() + self._window.clipboard_append(content) + self._show_temporary_message("已复制") + except Exception: + pass + + def _show_temporary_message(self, message: str) -> None: + """显示临时消息""" + if self._stats_label and self._stats_label.winfo_exists(): + original = self._stats_label.cget("text") + self._stats_label.config(text=f"✓ {message}", fg="green") + if self._window: + self._window.after(1500, lambda: self._stats_label.config(text=original, fg="gray")) + + def _on_confirm(self) -> None: + """确认""" + content = self.get_text() + if self._callback: + self._callback(content) + self.hide() + + def _on_cancel(self) -> None: + """取消""" + self.hide() + + def _update_stats(self) -> None: + """更新统计""" + if self._text_widget and self._text_widget.winfo_exists(): + try: + content = self._text_widget.get(1.0, tk.END) + char_count = len(content) - 1 + line_count = int(self._text_widget.index(tk.END).split(".")[0]) - 1 + kb_size = char_count / 1024 + size_text = f"{kb_size:.1f}KB" if kb_size >= 1 else f"{char_count}B" + self._stats_label.config(text=f"字符数: {char_count:,} ({size_text}) | 行数: {line_count}") + except Exception: + pass + + def _clear_text(self) -> None: + """清空""" + if self._text_widget and self._text_widget.winfo_exists(): + self._text_widget.delete(1.0, tk.END) + self._update_stats() + + def _select_all(self) -> str | None: + """全选""" + if self._text_widget and self._text_widget.winfo_exists(): + self._text_widget.tag_add(tk.SEL, "1.0", tk.END) + self._text_widget.mark_set(tk.INSERT, "1.0") + return "break" + return None + + def get_text(self) -> str: + """获取文本""" + if self._text_widget and self._text_widget.winfo_exists(): + content = self._text_widget.get(1.0, tk.END) + if content.endswith("\n"): + content = content[:-1] + return content + return "" + + def show(self, callback: Callable[[str], None] | None = None, initial_text: str = "", title: str = "粘贴超大文本") -> None: + """ + 显示粘贴窗口(独立线程,非阻塞) + + 每次调用都会创建新的线程和Tk实例,确保不会出现线程冲突 + """ + self._callback = callback + + # 如果已有窗口在运行,先关闭 + if PasteWindow._current_thread and PasteWindow._current_thread.is_alive(): + self.hide() + # 等待线程完全结束 + PasteWindow._current_thread.join(timeout=1.0) + + # 创建新线程 + PasteWindow._current_thread = threading.Thread(target=self._run_tkinter, args=(initial_text, title), daemon=True) + PasteWindow._current_thread.start() + + def hide(self) -> None: + """关闭窗口""" + if self._window and self._window.winfo_exists(): + try: + self._window.quit() # 退出 mainloop + self._window.destroy() + except Exception: + pass + + # 等待线程结束,避免资源冲突 + if PasteWindow._current_thread and PasteWindow._current_thread != threading.current_thread(): + PasteWindow._current_thread.join(timeout=0.5) + + +# 便捷函数 +_paste_window: PasteWindow | None = None + + +def show_paste_window(callback: Callable[[str], None] | None = None, initial_text: str = "") -> None: + """显示粘贴窗口(非阻塞)""" + global _paste_window + if _paste_window is None: + _paste_window = PasteWindow.get_instance() + _paste_window.show(callback=callback, initial_text=initial_text) + + +def close_paste_window() -> None: + """关闭粘贴窗口""" + global _paste_window + if _paste_window: + _paste_window.hide() diff --git a/src/models/skill.py b/src/models/skill.py index 1a7c2f2..9e81015 100644 --- a/src/models/skill.py +++ b/src/models/skill.py @@ -1,127 +1,124 @@ -"""Skill 数据模型.""" - -from __future__ import annotations - -from contextlib import suppress -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -import yaml - - -@dataclass -class SkillInfo: - """Skill 信息模型. - - Attributes: - name: Skill 名称(来自目录名) - description: Skill 描述(来自 SKILL.md 第一行或 skill.txt) - location: Skill 所在目录的绝对路径 - content: SKILL.md 的完整内容 - files: Skill 目录中的其他文件列表(排除 SKILL.md) - enabled: 是否启用(用于配置) - """ - - name: str - description: str = "" - location: str = "" - content: str = "" - files: list[str] = field(default_factory=list) - enabled: bool = True - metadata: dict[str, Any] = field(default_factory=dict) - - @classmethod - def from_dir(cls, skill_dir: Path) -> SkillInfo | None: - """从目录加载 Skill 信息. - - Args: - skill_dir: Skill 目录路径 - - Returns: - SkillInfo 实例,如果目录无效则返回 None - """ - skill_md = skill_dir / "SKILL.md" - skill_txt = skill_dir / "skill.txt" - - # 必须有 SKILL.md 文件 - if not skill_md.exists(): - return None - - try: - content = skill_md.read_text(encoding="utf-8") - except Exception: - return None - - # 解析 YAML frontmatter - name = skill_dir.name # 默认使用目录名 - description = "" - - if content.startswith("---"): - # 解析 YAML frontmatter - parts = content.split("---", 2) - if len(parts) >= 3: - try: - frontmatter = yaml.safe_load(parts[1]) - if frontmatter: - name = frontmatter.get("name", name) - description = frontmatter.get("description", "") - except Exception: - pass - - # 如果没有从 frontmatter 获取到描述,尝试其他方式 - if not description: - # 优先从 skill.txt - if skill_txt.exists(): - with suppress(Exception): - description = skill_txt.read_text(encoding="utf-8").strip() - - if not description: - # 从 SKILL.md 第一行提取标题 - first_line = content.split("\n")[0] if content else "" - if first_line.startswith("#"): - description = first_line.lstrip("#").strip() - else: - description = first_line.strip() or skill_dir.name - - # 收集其他文件 - files: list[str] = [] - try: - for f in skill_dir.iterdir(): - if f.is_file() and f.name not in ("SKILL.md", "skill.txt"): - files.append(f.name) - except Exception: - pass - - return cls( - name=name, - description=description, - location=str(skill_dir), - content=content, - files=sorted(files), - enabled=True, - ) - - def to_dict(self) -> dict[str, Any]: - """转换为字典格式.""" - return { - "name": self.name, - "description": self.description, - "location": self.location, - "files": self.files, - "enabled": self.enabled, - "metadata": self.metadata, - } - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> SkillInfo: - """从字典创建实例.""" - return cls( - name=data.get("name", ""), - description=data.get("description", ""), - location=data.get("location", ""), - content=data.get("content", ""), - files=data.get("files", []), - enabled=data.get("enabled", True), - metadata=data.get("metadata", {}), - ) +"""Skill 数据模型.""" + +from __future__ import annotations + +from contextlib import suppress +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import yaml + + +@dataclass +class SkillInfo: + """Skill 信息模型. + + Attributes: + name: Skill 名称(来自目录名) + description: Skill 描述(来自 SKILL.md 第一行或 skill.txt) + location: Skill 所在目录的绝对路径 + content: SKILL.md 的完整内容 + files: Skill 目录中的其他文件列表(排除 SKILL.md) + enabled: 是否启用(用于配置) + """ + + name: str + description: str = "" + location: str = "" + content: str = "" + files: list[str] = field(default_factory=list) + enabled: bool = True + metadata: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dir(cls, skill_dir: Path) -> SkillInfo | None: + """从目录加载 Skill 信息. + + Args: + skill_dir: Skill 目录路径 + + Returns: + SkillInfo 实例,如果目录无效则返回 None + """ + skill_md = skill_dir / "SKILL.md" + skill_txt = skill_dir / "skill.txt" + + # 必须有 SKILL.md 文件 + if not skill_md.exists(): + return None + + try: + content = skill_md.read_text(encoding="utf-8") + except Exception: + return None + + # 解析 YAML frontmatter + name = skill_dir.name # 默认使用目录名 + description = "" + + if content.startswith("---"): + # 解析 YAML frontmatter + parts = content.split("---", 2) + if len(parts) >= 3: + try: + frontmatter = yaml.safe_load(parts[1]) + if frontmatter: + name = frontmatter.get("name", name) + description = frontmatter.get("description", "") + except Exception: + pass + + # 如果没有从 frontmatter 获取到描述,尝试其他方式 + if not description: + # 优先从 skill.txt + if skill_txt.exists(): + with suppress(Exception): + description = skill_txt.read_text(encoding="utf-8").strip() + + if not description: + # 从 SKILL.md 第一行提取标题 + first_line = content.split("\n")[0] if content else "" + description = first_line.lstrip("#").strip() if first_line.startswith("#") else first_line.strip() or skill_dir.name + + # 收集其他文件 + files: list[str] = [] + try: + for f in skill_dir.iterdir(): + if f.is_file() and f.name not in ("SKILL.md", "skill.txt"): + files.append(f.name) + except Exception: + pass + + return cls( + name=name, + description=description, + location=str(skill_dir), + content=content, + files=sorted(files), + enabled=True, + ) + + def to_dict(self) -> dict[str, Any]: + """转换为字典格式.""" + return { + "name": self.name, + "description": self.description, + "location": self.location, + "files": self.files, + "enabled": self.enabled, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SkillInfo: + """从字典创建实例.""" + return cls( + name=data.get("name", ""), + description=data.get("description", ""), + location=data.get("location", ""), + content=data.get("content", ""), + files=data.get("files", []), + enabled=data.get("enabled", True), + metadata=data.get("metadata", {}), + ) diff --git a/src/models/tools/tool_result.py b/src/models/tools/tool_result.py index 763e911..41107d6 100644 --- a/src/models/tools/tool_result.py +++ b/src/models/tools/tool_result.py @@ -1,143 +1,136 @@ -import json -import os -import warnings -from typing import Any, ClassVar - -from src.utils.string_snapshot import truncate_params_string - - -def to_xml_string(func_name: str, params: dict, data: Any = None, err: str | None = None) -> str: - params = params.copy() - params.pop("self", None) - - try: - messages: list[str] = [] - if data is not None: - messages.append("") - if isinstance(data, str): - messages.append(data) - elif isinstance(data, (dict, list, tuple)): - messages.append(json.dumps(data)) - else: - messages.append(f"{data.__class__.__name__}({data})") - messages.append("") - if err: - messages.extend( - [ - "", - err, - "", - ] - ) - if data is None and not err: - messages.append("没有任何工具调用数据或错误详情, 请提示用户检查工具是否正常") - - temp_result = [ - "", - f"", - *messages, - "", - "", - ] - func_result = str.join("\n", temp_result) - except Exception as e: - import traceback - - func_result = "\n".join( - [ - "", - f"", - f"Error={e.__class__.__name__}({e}, {traceback.format_exc()})", - err if err else "", - "", - "", - ] - ) - return func_result - - -class ToolResult: - """工具执行结果的结构化包装,显式区分成功与失败. - - 所有被 @handle_tool_exceptions 装饰的工具方法均返回此类型, - 调用方可通过 success 标志可靠判断执行状态,无需依赖隐式类型约定. - """ - - __slots__ = ("data", "error", "func_kwargs", "func_name", "response", "success") - - HAD_VALIDATE: ClassVar[bool] = False - MAX_RESULT_LENGTH: ClassVar[int] = int(os.getenv("TOOL_MAX_RESULT_LENGTH", "30000")) - LIST_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_LIST_TRUNCATE_THRESHOLD", "100")) - DICT_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_DICT_TRUNCATE_THRESHOLD", "100")) - - def __init__( - self, success: bool, func_name: str, func_kwargs: dict, data: Any = None, error: str | None = None - ) -> None: - self.success: bool = success - self.func_name: str = func_name - self.func_kwargs: dict = func_kwargs - self.data: Any = data - self.error: str | None = error - self.response: str = to_xml_string(self.func_name, self.func_kwargs, self.data, self.error) - ToolResult._validate_config() - - def __repr__(self) -> str: - if self.success: - return f"ToolResult(success=True, data={self.data!r})" - return f"ToolResult(success=False, error={self.error!r})" - - @property - def status(self) -> str: - return "success" if self.success else "error" - - @classmethod - def _compress_result(cls, result: Any) -> Any: - """压缩过长的结果""" - result_length = len(result) - if isinstance(result, str): - if result_length > cls.MAX_RESULT_LENGTH: - return ( - result[: cls.MAX_RESULT_LENGTH] - + f"... [字符串结果已截断 显示的字符数: {cls.LIST_TRUNCATE_THRESHOLD} / {result_length}]" - ) - elif isinstance(result, (list, tuple)): - if result_length > cls.LIST_TRUNCATE_THRESHOLD: - return [ - *list(result[: cls.LIST_TRUNCATE_THRESHOLD]), - f"... [列表已截断 显示的项: {cls.LIST_TRUNCATE_THRESHOLD} / {result_length}]", - ] - elif isinstance(result, dict) and result_length > cls.DICT_TRUNCATE_THRESHOLD: - compressed = {k: result[k] for k in list(result.keys())[: cls.DICT_TRUNCATE_THRESHOLD]} - compressed["..."] = f"[字典已截断 显示的项: {cls.DICT_TRUNCATE_THRESHOLD} / {result_length}]" - return compressed - - return result - - @classmethod - def _validate_config(cls) -> None: - """验证配置值确保在合理范围内""" - if cls.HAD_VALIDATE: - return None - if cls.MAX_RESULT_LENGTH < 10: - warnings.warn( - f"TOOL_MAX_RESULT_LENGTH 过小({cls.MAX_RESULT_LENGTH}),建议至少为100", UserWarning, stacklevel=2 - ) - cls.MAX_RESULT_LENGTH = 100 - - if cls.LIST_TRUNCATE_THRESHOLD < 10: - warnings.warn( - f"TOOL_LIST_TRUNCATE_THRESHOLD 过小({cls.LIST_TRUNCATE_THRESHOLD}),建议至少为50", - UserWarning, - stacklevel=2, - ) - cls.LIST_TRUNCATE_THRESHOLD = 50 - - if cls.DICT_TRUNCATE_THRESHOLD < 10: - warnings.warn( - f"TOOL_DICT_TRUNCATE_THRESHOLD 过小({cls.DICT_TRUNCATE_THRESHOLD}),建议至少为50", - UserWarning, - stacklevel=2, - ) - cls.DICT_TRUNCATE_THRESHOLD = 50 - cls.HAD_VALIDATE = True - return None +import json +import os +import warnings +from typing import Any, ClassVar + +from src.utils.string_snapshot import truncate_params_string + + +def to_xml_string(func_name: str, params: dict, data: Any = None, err: str | None = None) -> str: + params = params.copy() + params.pop("self", None) + + try: + messages: list[str] = [] + if data is not None: + messages.append("") + if isinstance(data, str): + messages.append(data) + elif isinstance(data, (dict, list, tuple)): + messages.append(json.dumps(data)) + else: + messages.append(f"{data.__class__.__name__}({data})") + messages.append("") + if err: + messages.extend( + [ + "", + err, + "", + ] + ) + if data is None and not err: + messages.append("没有任何工具调用数据或错误详情, 请提示用户检查工具是否正常") + + temp_result = [ + "", + f"", + *messages, + "", + "", + ] + func_result = str.join("\n", temp_result) + except Exception as e: + import traceback + + func_result = "\n".join( + [ + "", + f"", + f"Error={e.__class__.__name__}({e}, {traceback.format_exc()})", + err if err else "", + "", + "", + ] + ) + return func_result + + +class ToolResult: + """工具执行结果的结构化包装,显式区分成功与失败. + + 所有被 @handle_tool_exceptions 装饰的工具方法均返回此类型, + 调用方可通过 success 标志可靠判断执行状态,无需依赖隐式类型约定. + """ + + __slots__ = ("data", "error", "func_kwargs", "func_name", "response", "success") + + HAD_VALIDATE: ClassVar[bool] = False + MAX_RESULT_LENGTH: ClassVar[int] = int(os.getenv("TOOL_MAX_RESULT_LENGTH", "30000")) + LIST_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_LIST_TRUNCATE_THRESHOLD", "100")) + DICT_TRUNCATE_THRESHOLD: ClassVar[int] = int(os.getenv("TOOL_DICT_TRUNCATE_THRESHOLD", "100")) + + def __init__(self, success: bool, func_name: str, func_kwargs: dict, data: Any = None, error: str | None = None) -> None: + self.success: bool = success + self.func_name: str = func_name + self.func_kwargs: dict = func_kwargs + self.data: Any = data + self.error: str | None = error + self.response: str = to_xml_string(self.func_name, self.func_kwargs, self.data, self.error) + ToolResult._validate_config() + + def __repr__(self) -> str: + if self.success: + return f"ToolResult(success=True, data={self.data!r})" + return f"ToolResult(success=False, error={self.error!r})" + + @property + def status(self) -> str: + return "success" if self.success else "error" + + @classmethod + def _compress_result(cls, result: Any) -> Any: + """压缩过长的结果""" + result_length = len(result) + if isinstance(result, str): + if result_length > cls.MAX_RESULT_LENGTH: + return result[: cls.MAX_RESULT_LENGTH] + f"... [字符串结果已截断 显示的字符数: {cls.LIST_TRUNCATE_THRESHOLD} / {result_length}]" + elif isinstance(result, (list, tuple)): + if result_length > cls.LIST_TRUNCATE_THRESHOLD: + return [ + *list(result[: cls.LIST_TRUNCATE_THRESHOLD]), + f"... [列表已截断 显示的项: {cls.LIST_TRUNCATE_THRESHOLD} / {result_length}]", + ] + elif isinstance(result, dict) and result_length > cls.DICT_TRUNCATE_THRESHOLD: + compressed = {k: result[k] for k in list(result.keys())[: cls.DICT_TRUNCATE_THRESHOLD]} + compressed["..."] = f"[字典已截断 显示的项: {cls.DICT_TRUNCATE_THRESHOLD} / {result_length}]" + return compressed + + return result + + @classmethod + def _validate_config(cls) -> None: + """验证配置值确保在合理范围内""" + if cls.HAD_VALIDATE: + return None + if cls.MAX_RESULT_LENGTH < 10: + warnings.warn(f"TOOL_MAX_RESULT_LENGTH 过小({cls.MAX_RESULT_LENGTH}),建议至少为100", UserWarning, stacklevel=2) + cls.MAX_RESULT_LENGTH = 100 + + if cls.LIST_TRUNCATE_THRESHOLD < 10: + warnings.warn( + f"TOOL_LIST_TRUNCATE_THRESHOLD 过小({cls.LIST_TRUNCATE_THRESHOLD}),建议至少为50", + UserWarning, + stacklevel=2, + ) + cls.LIST_TRUNCATE_THRESHOLD = 50 + + if cls.DICT_TRUNCATE_THRESHOLD < 10: + warnings.warn( + f"TOOL_DICT_TRUNCATE_THRESHOLD 过小({cls.DICT_TRUNCATE_THRESHOLD}),建议至少为50", + UserWarning, + stacklevel=2, + ) + cls.DICT_TRUNCATE_THRESHOLD = 50 + cls.HAD_VALIDATE = True + return None diff --git a/src/workspace/path_validator.py b/src/workspace/path_validator.py index 99a4bee..89352ed 100644 --- a/src/workspace/path_validator.py +++ b/src/workspace/path_validator.py @@ -1,184 +1,182 @@ -import os -import re -from pathlib import Path -from typing import ClassVar - -from src.workspace.exclusion_manager import ExclusionManager - - -class WorkspaceBoundaryError(Exception): - """访问工作区外的路径时抛出""" - - pass - - -class PathNotFoundError(Exception): - """工作区内路径不存在时抛出""" - - pass - - -class SensitiveFileError(Exception): - """访问敏感文件时抛出""" - - pass - - -class PathValidator: - """工作区路径安全校验器,防止路径遍历和符号链接逃逸 - - Args: - workspace_root: 工作区根目录,默认为当前目录 - """ - - # 敏感文件匹配模式(从 ExclusionManager 统一来源引用) - SENSITIVE_FILE_PATTERNS: ClassVar[list[re.Pattern]] = [ - re.compile(p) for p in ExclusionManager.SENSITIVE_FILE_PATTERNS - ] - - def __init__(self, workspace_root: str | Path = "."): - """初始化路径验证器. - - Args: - workspace_root: 工作区根目录路径,可以是字符串或 Path 对象 - 所有后续的路径验证都将以此目录为基准 - - Raises: - FileNotFoundError: 当 workspace_root 不存在时抛出 - NotADirectoryError: 当 workspace_root 不是目录时抛出 - - Example: - >>> validator = PathValidator("/app/workspace") - """ - self.root = Path(workspace_root).resolve() - - def resolve_path(self, target: str | Path) -> Path: - """获取路径的绝对路径,不检查文件是否存在. - - 只执行安全边界验证,不要求路径实际存在.适用于需要获取路径引用 - 但文件/目录可能尚未创建的场景(如准备创建新文件). - - Args: - target: 待解析的目标路径,可以是相对路径或绝对路径. - 绝对路径会相对于工作区根目录解析. - - Returns: - Path: 解析后的绝对路径对象,保证位于工作区边界内. - - Raises: - WorkspaceBoundaryError: 当解析后的路径位于工作区根目录之外时抛出. - OSError: 当路径解析过程中发生其他系统错误时抛出. - - Example: - >>> validator = PathValidator("/app/workspace") - >>> # 路径不存在但仍在工作区内 - >>> validator.resolve_path("new_folder/new_file.txt") - PosixPath('/app/workspace/new_folder/new_file.txt') - >>> # 路径越界会报错 - >>> validator.resolve_path("../outside.txt") - WorkspaceBoundaryError: 路径越界: ../outside.txt - """ - path = Path(target) - # 统一转为相对于工作区根目录的绝对路径 - resolved = (self.root / path).resolve() if not path.is_absolute() else path.resolve() - - # 边界守卫:防 .. 逃逸与符号链接越权 - if not str(resolved).startswith(str(self.root) + os.sep) and resolved != self.root: - raise WorkspaceBoundaryError(f"路径越界: {target}") - - # 敏感文件检查 - self._raise_if_sensitive(resolved, target) - - return resolved - - @classmethod - def _raise_if_sensitive(cls, resolved: Path, original_target: str | Path) -> None: - """检查路径是否匹配敏感文件模式.""" - resolved_str = str(resolved).replace(os.sep, "/") - for pattern in cls.SENSITIVE_FILE_PATTERNS: - if pattern.search(resolved_str): - raise SensitiveFileError(f"禁止访问敏感文件: {original_target}") - - def create_file_with_parents(self, target: str | Path, content: str = "") -> Path: - """在工作区内创建文件,自动创建所有不存在的父目录. - - 专门用于在工作区内创建新文件或覆盖已存在文件的场景. - 如果文件已存在,将覆盖其内容.如果需要保留已有内容,请先使用 read 工具确认. - - 执行步骤: - 1. 安全边界验证(防止路径遍历和符号链接逃逸) - 2. 自动创建所有不存在的父目录(权限使用默认755/0o755) - 3. 写入指定内容(默认为空字符串) - - Args: - target: 目标文件路径,可以是相对路径或绝对路径 - content: 要写入的文件内容,默认为空字符串 - - Returns: - Path: 创建成功后的绝对路径对象 - - Raises: - WorkspaceBoundaryError: 当解析后的路径位于工作区根目录之外时抛出 - PermissionError: 当没有权限创建目录或文件时抛出 - OSError: 当创建目录或文件过程中发生其他系统错误时抛出 - - Example: - >>> validator = PathValidator("/app/workspace") - >>> # 创建深层目录结构中的文件 - >>> validator.create_file_with_parents("a/b/c/new_file.txt", "Hello World") - PosixPath('/app/workspace/a/b/c/new_file.txt') - """ - # 1. 边界验证 - resolved = self.resolve_path(target) - - # 2. 确保父目录存在 - parent_dir = resolved.parent - if not parent_dir.exists(): - parent_dir.mkdir(parents=True, exist_ok=True) - - # 3. 写入文件内容 - resolved.write_text(content, encoding="utf-8") - - return resolved - - def validate(self, target: str | Path) -> Path: - """校验路径安全性,返回绝对路径. - - 执行多层安全检查: - 1. 路径规范化(解析 '..' 和符号链接) - 2. 工作区边界验证(防止路径逃逸) - 3. 文件存在性检查 - 4. 读取权限验证 - - Args: - target: 待验证的目标路径,可以是相对路径或绝对路径. - 绝对路径会相对于当前工作目录解析. - - Returns: - Path: 验证通过后的绝对路径对象,保证位于工作区边界内. - - Raises: - WorkspaceBoundaryError: 当解析后的路径位于工作区根目录之外时抛出. - PathNotFoundError: 当目标路径不存在时抛出. - PermissionError: 当目标路径存在但无读取权限时抛出. - OSError: 当路径解析过程中发生其他系统错误时抛出. - - Example: - >>> validator = PathValidator("/app/workspace") - >>> # 有效路径 - >>> validator.validate("src/main.py") - PosixPath('/app/workspace/src/main.py') - >>> # 符号链接验证 - >>> validator.validate("link_to_config") - PosixPath('/app/workspace/config/settings.ini') - """ - # 复用 resolve_path 进行边界验证 - resolved = self.resolve_path(target) - - # 存在性与基础权限 - if not resolved.exists(): - raise PathNotFoundError(f"路径不存在: {target}") - if not os.access(resolved, os.R_OK): - raise PermissionError(f"无读取权限: {target}") # pragma: no cover // 需要专门准备权限, 不测了 - - return resolved +import os +import re +from pathlib import Path +from typing import ClassVar + +from src.workspace.exclusion_manager import ExclusionManager + + +class WorkspaceBoundaryError(Exception): + """访问工作区外的路径时抛出""" + + pass + + +class PathNotFoundError(Exception): + """工作区内路径不存在时抛出""" + + pass + + +class SensitiveFileError(Exception): + """访问敏感文件时抛出""" + + pass + + +class PathValidator: + """工作区路径安全校验器,防止路径遍历和符号链接逃逸 + + Args: + workspace_root: 工作区根目录,默认为当前目录 + """ + + # 敏感文件匹配模式(从 ExclusionManager 统一来源引用) + SENSITIVE_FILE_PATTERNS: ClassVar[list[re.Pattern]] = [re.compile(p) for p in ExclusionManager.SENSITIVE_FILE_PATTERNS] + + def __init__(self, workspace_root: str | Path = "."): + """初始化路径验证器. + + Args: + workspace_root: 工作区根目录路径,可以是字符串或 Path 对象 + 所有后续的路径验证都将以此目录为基准 + + Raises: + FileNotFoundError: 当 workspace_root 不存在时抛出 + NotADirectoryError: 当 workspace_root 不是目录时抛出 + + Example: + >>> validator = PathValidator("/app/workspace") + """ + self.root = Path(workspace_root).resolve() + + def resolve_path(self, target: str | Path) -> Path: + """获取路径的绝对路径,不检查文件是否存在. + + 只执行安全边界验证,不要求路径实际存在.适用于需要获取路径引用 + 但文件/目录可能尚未创建的场景(如准备创建新文件). + + Args: + target: 待解析的目标路径,可以是相对路径或绝对路径. + 绝对路径会相对于工作区根目录解析. + + Returns: + Path: 解析后的绝对路径对象,保证位于工作区边界内. + + Raises: + WorkspaceBoundaryError: 当解析后的路径位于工作区根目录之外时抛出. + OSError: 当路径解析过程中发生其他系统错误时抛出. + + Example: + >>> validator = PathValidator("/app/workspace") + >>> # 路径不存在但仍在工作区内 + >>> validator.resolve_path("new_folder/new_file.txt") + PosixPath('/app/workspace/new_folder/new_file.txt') + >>> # 路径越界会报错 + >>> validator.resolve_path("../outside.txt") + WorkspaceBoundaryError: 路径越界: ../outside.txt + """ + path = Path(target) + # 统一转为相对于工作区根目录的绝对路径 + resolved = (self.root / path).resolve() if not path.is_absolute() else path.resolve() + + # 边界守卫:防 .. 逃逸与符号链接越权 + if not str(resolved).startswith(str(self.root) + os.sep) and resolved != self.root: + raise WorkspaceBoundaryError(f"路径越界: {target}") + + # 敏感文件检查 + self._raise_if_sensitive(resolved, target) + + return resolved + + @classmethod + def _raise_if_sensitive(cls, resolved: Path, original_target: str | Path) -> None: + """检查路径是否匹配敏感文件模式.""" + resolved_str = str(resolved).replace(os.sep, "/") + for pattern in cls.SENSITIVE_FILE_PATTERNS: + if pattern.search(resolved_str): + raise SensitiveFileError(f"禁止访问敏感文件: {original_target}") + + def create_file_with_parents(self, target: str | Path, content: str = "") -> Path: + """在工作区内创建文件,自动创建所有不存在的父目录. + + 专门用于在工作区内创建新文件或覆盖已存在文件的场景. + 如果文件已存在,将覆盖其内容.如果需要保留已有内容,请先使用 read 工具确认. + + 执行步骤: + 1. 安全边界验证(防止路径遍历和符号链接逃逸) + 2. 自动创建所有不存在的父目录(权限使用默认755/0o755) + 3. 写入指定内容(默认为空字符串) + + Args: + target: 目标文件路径,可以是相对路径或绝对路径 + content: 要写入的文件内容,默认为空字符串 + + Returns: + Path: 创建成功后的绝对路径对象 + + Raises: + WorkspaceBoundaryError: 当解析后的路径位于工作区根目录之外时抛出 + PermissionError: 当没有权限创建目录或文件时抛出 + OSError: 当创建目录或文件过程中发生其他系统错误时抛出 + + Example: + >>> validator = PathValidator("/app/workspace") + >>> # 创建深层目录结构中的文件 + >>> validator.create_file_with_parents("a/b/c/new_file.txt", "Hello World") + PosixPath('/app/workspace/a/b/c/new_file.txt') + """ + # 1. 边界验证 + resolved = self.resolve_path(target) + + # 2. 确保父目录存在 + parent_dir = resolved.parent + if not parent_dir.exists(): + parent_dir.mkdir(parents=True, exist_ok=True) + + # 3. 写入文件内容 + resolved.write_text(content, encoding="utf-8") + + return resolved + + def validate(self, target: str | Path) -> Path: + """校验路径安全性,返回绝对路径. + + 执行多层安全检查: + 1. 路径规范化(解析 '..' 和符号链接) + 2. 工作区边界验证(防止路径逃逸) + 3. 文件存在性检查 + 4. 读取权限验证 + + Args: + target: 待验证的目标路径,可以是相对路径或绝对路径. + 绝对路径会相对于当前工作目录解析. + + Returns: + Path: 验证通过后的绝对路径对象,保证位于工作区边界内. + + Raises: + WorkspaceBoundaryError: 当解析后的路径位于工作区根目录之外时抛出. + PathNotFoundError: 当目标路径不存在时抛出. + PermissionError: 当目标路径存在但无读取权限时抛出. + OSError: 当路径解析过程中发生其他系统错误时抛出. + + Example: + >>> validator = PathValidator("/app/workspace") + >>> # 有效路径 + >>> validator.validate("src/main.py") + PosixPath('/app/workspace/src/main.py') + >>> # 符号链接验证 + >>> validator.validate("link_to_config") + PosixPath('/app/workspace/config/settings.ini') + """ + # 复用 resolve_path 进行边界验证 + resolved = self.resolve_path(target) + + # 存在性与基础权限 + if not resolved.exists(): + raise PathNotFoundError(f"路径不存在: {target}") + if not os.access(resolved, os.R_OK): + raise PermissionError(f"无读取权限: {target}") # pragma: no cover // 需要专门准备权限, 不测了 + + return resolved diff --git a/src/workspace/tools/base_tool.py b/src/workspace/tools/base_tool.py index e958559..16d47f8 100644 --- a/src/workspace/tools/base_tool.py +++ b/src/workspace/tools/base_tool.py @@ -1,291 +1,277 @@ -import inspect -from collections.abc import Callable -from pathlib import Path -from typing import Any - -from src.core.file_tracker import FileTracker -from src.models.tools.tool_result import ToolResult -from src.workspace.workspace import Workspace - - -def build_param_list_item(name: str, params: dict[str, Any], description: str = "") -> str: - """Generate a Markdown list item describing a parameter.""" - from src.constants.prompts import clean_type_annotation - - type_str = clean_type_annotation(params.get("annotation", "Any")) - - if params.get("required", False): - req_str = "required" - else: - default_val = params.get("default", "") - req_str = f"optional, default=`{default_val}`" - - desc_suffix = f": {description}" if description else "" - - return f"- **{name}** ({type_str}, {req_str}){desc_suffix}" - - -def convert_param_type(value: str, expected_type: str) -> Any: - """ - 根据期望的类型注解转换参数值 - - Args: - value: 原始字符串值 - expected_type: 期望的类型注解字符串(如 '' 或 'int') - - Returns: - 转换后的值 - """ - # 提取类型名称 - type_name = expected_type.lower() - if " 中的 int - import re - - match = re.search(r"'(\w+)'", expected_type) - if match: - type_name = match.group(1).lower() - - # 类型转换 - if type_name in ("int", "integer"): - try: - return int(value) - except ValueError: - return value - elif type_name in ("float", "double"): - try: - return float(value) - except ValueError: - return value - elif type_name in ("bool", "boolean"): - return value.lower() in ("true", "1", "yes", "on") - elif type_name in ("list", "tuple"): - # 简单列表解析 [1,2,3] 或 ["a","b"] - if value.startswith("[") and value.endswith("]"): - import json - - try: - return json.loads(value) - except json.JSONDecodeError: - # 简单分割 - inner = value[1:-1].strip() - if inner: - return [item.strip().strip("\"'") for item in inner.split(",")] - return [] - return value - elif type_name in ("dict", "dictionary"): - # 简单字典解析 {"key":"value"} - if value.startswith("{") and value.endswith("}"): - import json - - try: - return json.loads(value) - except json.JSONDecodeError: - return value - return value - else: - # 字符串或其他类型 - return value - - -class BaseTool: - def __init__( - self, - workspace: Workspace, - name: str = "", - doc: str = "", - read_permission: bool = True, - write_permission: bool = False, - ): - self.workspace = workspace - # 排除无效工具 - if name is None or len(name) == 0: - raise ValueError(f"注册工具时{self.__class__.__name__}的名称没有有效值") - if doc is None or len(doc) == 0: - raise ValueError(f"注册工具时{self.__class__.__name__}的文档没有有效值") - # 读取权限 - self.read_permission: bool = read_permission - # 写入权限 - self.write_permission: bool = write_permission - self.name: str = name - self.doc: str = doc - self.func: Callable[..., ToolResult] | None = None - self.params: dict[str, Any] | None = None - self.param_descriptions: dict[str, str] = {} - - def to_doc(self) -> str: - """转换为模型可读文档格式""" - lines = [f'', f" {self.doc}"] - if self.params and len(self.params) > 0: - lines.append(" ") - for name, param in self.params.items(): - desc = self.param_descriptions.get(name, "") - lines.append(f" {build_param_list_item(name, param, desc)}") - lines.append(" ") - else: - lines.append(" ") - lines.append("") - return "\n".join(lines) - - def to_func_call(self) -> str: - """将工具转换为标准格式""" - func_call: str = f'\n' - for name, params in self.params.items(): - func_call += ( - f' ' - + ("" if "default" not in params else str(params.get("default"))) - + "\n" - ) - func_call += "" - return func_call - - @staticmethod - def extract_params(func: Callable[..., ToolResult]) -> dict[str, Any]: - """提取函数参数信息""" - sig = inspect.signature(func) - params = {} - for param_name, param in sig.parameters.items(): - if param_name not in ("self", "cls"): - param_info = { - "required": param.default == inspect.Parameter.empty, - "annotation": str(param.annotation) if param.annotation != inspect.Parameter.empty else "Any", - } - if param.default != inspect.Parameter.empty: - param_info["default"] = repr(param.default) - params[param_name] = param_info - return params - - def convert_args(self, kwargs: dict[str, Any]) -> dict[str, Any]: - """ - 根据参数类型注解转换参数值 - - Args: - kwargs: 原始参数字典(值可能都是字符串) - - Returns: - 转换类型后的参数字典 - """ - if not self.params: - return kwargs - - converted = {} - for param_name, param_value in kwargs.items(): - if param_name in self.params: - expected_type = self.params[param_name].get("annotation", "Any") - converted[param_name] = convert_param_type(str(param_value), expected_type) - else: - converted[param_name] = param_value - - return converted - - def _record_read_meta(self, resolved_path: Path) -> None: - try: - meta = FileTracker.get_file_meta(resolved_path) - if meta: - session_id = self.workspace.session_id - if session_id is not None: - rel_path = str(resolved_path.relative_to(self.workspace.root_path)) - self.workspace.db.record_file_read( - session_id, rel_path, meta["mtime"], meta["size"], meta["checksum"] - ) - except Exception: - pass - - def _validate_mtime(self, resolved_path: Path) -> str | None: - """校验文件自上次读取后是否被外部修改.""" - if not resolved_path.exists(): - return None - - session_id = self.workspace.session_id - if session_id is None: - return None - - rel_path = str(resolved_path.relative_to(self.workspace.root_path)) - record = self.workspace.db.get_file_read_record(session_id, rel_path) - if record is None: - return None - - stored_mtime = record[3] - current_mtime = resolved_path.stat().st_mtime - - if abs(current_mtime - stored_mtime) > 0.001: - return ( - f"ERROR: FILE_MODIFIED_EXTERNALLY - " - f'The file "{rel_path}" was modified externally since last read. ' - f'Please re-read the file with the "read" tool before editing it.' - ) - return None - - @staticmethod - def _generate_diff(old_content: str, new_content: str, file_path: str) -> str: - import difflib - - old_lines = old_content.splitlines(keepends=True) - new_lines = new_content.splitlines(keepends=True) - diff = difflib.unified_diff(old_lines, new_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}") - return "".join(diff) - - @classmethod - def make_tool_result_response( - cls, success: bool, kwargs: dict, data: Any = None, error: str | None = None - ) -> ToolResult: - return ToolResult(success=success, func_name=cls.__name__, func_kwargs=kwargs, data=data, error=error) - - @classmethod - def make_success_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: - return cls.make_tool_result_response(success=True, kwargs=kwargs, data=data, error=error) - - @classmethod - def make_failed_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: - return cls.make_tool_result_response(success=False, kwargs=kwargs, data=data, error=error) - - @staticmethod - def handle_tool_exceptions(func) -> Callable[..., ToolResult]: - """工具方法异常处理装饰器 —— 将异常转换为 ToolResult 失败结果""" - from functools import wraps - - from src.workspace.path_validator import PathNotFoundError, SensitiveFileError, WorkspaceBoundaryError - - @wraps(func) - def wrapper(self, *args, **kwargs): - try: - raw = func(self, *args, **kwargs) - # 如果工具内部已返回 ToolResult, 直接透传 - if isinstance(raw, ToolResult): - return raw - # 否则包装为成功结果 - return ToolResult(success=True, func_name=func.__name__, func_kwargs=kwargs, data=raw) - except PathNotFoundError as err1: - return ToolResult( - success=False, - func_name=func.__name__, - func_kwargs=kwargs, - error=f"{err1.__class__.__name__}: {err1}", - ) - except WorkspaceBoundaryError as err2: - return ToolResult( - success=False, - func_name=func.__name__, - func_kwargs=kwargs, - error=f"{err2.__class__.__name__}: {err2}", - ) - except SensitiveFileError as err3: - return ToolResult( - success=False, - func_name=func.__name__, - func_kwargs=kwargs, - error=f"{err3.__class__.__name__}: {err3}", - ) - except PermissionError as err4: - return ToolResult( - success=False, - func_name=func.__name__, - func_kwargs=kwargs, - error=f"{err4.__class__.__name__}: {err4}", - ) - except Exception as err: - return ToolResult( - success=False, func_name=func.__name__, func_kwargs=kwargs, error=f"{err.__class__.__name__}: {err}" - ) - - return wrapper +import inspect +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from src.core.file_tracker import FileTracker +from src.models.tools.tool_result import ToolResult +from src.workspace.workspace import Workspace + + +def build_param_list_item(name: str, params: dict[str, Any], description: str = "") -> str: + """Generate a Markdown list item describing a parameter.""" + from src.constants.prompts import clean_type_annotation + + type_str = clean_type_annotation(params.get("annotation", "Any")) + + if params.get("required", False): + req_str = "required" + else: + default_val = params.get("default", "") + req_str = f"optional, default=`{default_val}`" + + desc_suffix = f": {description}" if description else "" + + return f"- **{name}** ({type_str}, {req_str}){desc_suffix}" + + +def convert_param_type(value: str, expected_type: str) -> Any: + """ + 根据期望的类型注解转换参数值 + + Args: + value: 原始字符串值 + expected_type: 期望的类型注解字符串(如 '' 或 'int') + + Returns: + 转换后的值 + """ + # 提取类型名称 + type_name = expected_type.lower() + if " 中的 int + import re + + match = re.search(r"'(\w+)'", expected_type) + if match: + type_name = match.group(1).lower() + + # 类型转换 + if type_name in ("int", "integer"): + try: + return int(value) + except ValueError: + return value + elif type_name in ("float", "double"): + try: + return float(value) + except ValueError: + return value + elif type_name in ("bool", "boolean"): + return value.lower() in ("true", "1", "yes", "on") + elif type_name in ("list", "tuple"): + # 简单列表解析 [1,2,3] 或 ["a","b"] + if value.startswith("[") and value.endswith("]"): + import json + + try: + return json.loads(value) + except json.JSONDecodeError: + # 简单分割 + inner = value[1:-1].strip() + if inner: + return [item.strip().strip("\"'") for item in inner.split(",")] + return [] + return value + elif type_name in ("dict", "dictionary"): + # 简单字典解析 {"key":"value"} + if value.startswith("{") and value.endswith("}"): + import json + + try: + return json.loads(value) + except json.JSONDecodeError: + return value + return value + else: + # 字符串或其他类型 + return value + + +class BaseTool: + def __init__( + self, + workspace: Workspace, + name: str = "", + doc: str = "", + read_permission: bool = True, + write_permission: bool = False, + ): + self.workspace = workspace + # 排除无效工具 + if name is None or len(name) == 0: + raise ValueError(f"注册工具时{self.__class__.__name__}的名称没有有效值") + if doc is None or len(doc) == 0: + raise ValueError(f"注册工具时{self.__class__.__name__}的文档没有有效值") + # 读取权限 + self.read_permission: bool = read_permission + # 写入权限 + self.write_permission: bool = write_permission + self.name: str = name + self.doc: str = doc + self.func: Callable[..., ToolResult] | None = None + self.params: dict[str, Any] | None = None + self.param_descriptions: dict[str, str] = {} + + def to_doc(self) -> str: + """转换为模型可读文档格式""" + lines = [f'', f" {self.doc}"] + if self.params and len(self.params) > 0: + lines.append(" ") + for name, param in self.params.items(): + desc = self.param_descriptions.get(name, "") + lines.append(f" {build_param_list_item(name, param, desc)}") + lines.append(" ") + else: + lines.append(" ") + lines.append("") + return "\n".join(lines) + + def to_func_call(self) -> str: + """将工具转换为标准格式""" + func_call: str = f'\n' + for name, params in self.params.items(): + func_call += f' ' + ("" if "default" not in params else str(params.get("default"))) + "\n" + func_call += "" + return func_call + + @staticmethod + def extract_params(func: Callable[..., ToolResult]) -> dict[str, Any]: + """提取函数参数信息""" + sig = inspect.signature(func) + params = {} + for param_name, param in sig.parameters.items(): + if param_name not in ("self", "cls"): + param_info = { + "required": param.default == inspect.Parameter.empty, + "annotation": str(param.annotation) if param.annotation != inspect.Parameter.empty else "Any", + } + if param.default != inspect.Parameter.empty: + param_info["default"] = repr(param.default) + params[param_name] = param_info + return params + + def convert_args(self, kwargs: dict[str, Any]) -> dict[str, Any]: + """ + 根据参数类型注解转换参数值 + + Args: + kwargs: 原始参数字典(值可能都是字符串) + + Returns: + 转换类型后的参数字典 + """ + if not self.params: + return kwargs + + converted = {} + for param_name, param_value in kwargs.items(): + if param_name in self.params: + expected_type = self.params[param_name].get("annotation", "Any") + converted[param_name] = convert_param_type(str(param_value), expected_type) + else: + converted[param_name] = param_value + + return converted + + def _record_read_meta(self, resolved_path: Path) -> None: + try: + meta = FileTracker.get_file_meta(resolved_path) + if meta: + session_id = self.workspace.session_id + if session_id is not None: + rel_path = str(resolved_path.relative_to(self.workspace.root_path)) + self.workspace.db.record_file_read(session_id, rel_path, meta["mtime"], meta["size"], meta["checksum"]) + except Exception: + pass + + def _validate_mtime(self, resolved_path: Path) -> str | None: + """校验文件自上次读取后是否被外部修改.""" + if not resolved_path.exists(): + return None + + session_id = self.workspace.session_id + if session_id is None: + return None + + rel_path = str(resolved_path.relative_to(self.workspace.root_path)) + record = self.workspace.db.get_file_read_record(session_id, rel_path) + if record is None: + return None + + stored_mtime = record[3] + current_mtime = resolved_path.stat().st_mtime + + if abs(current_mtime - stored_mtime) > 0.001: + return f'ERROR: FILE_MODIFIED_EXTERNALLY - The file "{rel_path}" was modified externally since last read. Please re-read the file with the "read" tool before editing it.' + return None + + @staticmethod + def _generate_diff(old_content: str, new_content: str, file_path: str) -> str: + import difflib + + old_lines = old_content.splitlines(keepends=True) + new_lines = new_content.splitlines(keepends=True) + diff = difflib.unified_diff(old_lines, new_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}") + return "".join(diff) + + @classmethod + def make_tool_result_response(cls, success: bool, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: + return ToolResult(success=success, func_name=cls.__name__, func_kwargs=kwargs, data=data, error=error) + + @classmethod + def make_success_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: + return cls.make_tool_result_response(success=True, kwargs=kwargs, data=data, error=error) + + @classmethod + def make_failed_response(cls, kwargs: dict, data: Any = None, error: str | None = None) -> ToolResult: + return cls.make_tool_result_response(success=False, kwargs=kwargs, data=data, error=error) + + @staticmethod + def handle_tool_exceptions(func) -> Callable[..., ToolResult]: + """工具方法异常处理装饰器 —— 将异常转换为 ToolResult 失败结果""" + from functools import wraps + + from src.workspace.path_validator import PathNotFoundError, SensitiveFileError, WorkspaceBoundaryError + + @wraps(func) + def wrapper(self, *args, **kwargs): + try: + raw = func(self, *args, **kwargs) + # 如果工具内部已返回 ToolResult, 直接透传 + if isinstance(raw, ToolResult): + return raw + # 否则包装为成功结果 + return ToolResult(success=True, func_name=func.__name__, func_kwargs=kwargs, data=raw) + except PathNotFoundError as err1: + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err1.__class__.__name__}: {err1}", + ) + except WorkspaceBoundaryError as err2: + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err2.__class__.__name__}: {err2}", + ) + except SensitiveFileError as err3: + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err3.__class__.__name__}: {err3}", + ) + except PermissionError as err4: + return ToolResult( + success=False, + func_name=func.__name__, + func_kwargs=kwargs, + error=f"{err4.__class__.__name__}: {err4}", + ) + except Exception as err: + return ToolResult(success=False, func_name=func.__name__, func_kwargs=kwargs, error=f"{err.__class__.__name__}: {err}") + + return wrapper diff --git a/src/workspace/tools/edit_tool.py b/src/workspace/tools/edit_tool.py index b33c154..48bc41e 100644 --- a/src/workspace/tools/edit_tool.py +++ b/src/workspace/tools/edit_tool.py @@ -1,174 +1,153 @@ -"""安全的字符串替换编辑工具 -- 只发布待审核更改.""" - -from pathlib import Path - -from src.models.tool_error_response import ToolErrorResponse -from src.models.tools.tool_result import ToolResult -from src.utils.binary_detector import is_binary_file -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -class EditTool(BaseTool): - """安全的字符串替换编辑工具 — 两阶段提交 (预览 → 审核确认). - - 只计算 diff 并记录 PENDING_AUDIT 快照,不直接修改磁盘. - 由审核提交模块 (AuditCommitter) 在批准后执行实际写入. - """ - - def __init__(self, workspace: Workspace): - super().__init__(workspace, "edit", self.edit.__doc__, write_permission=True) - self.func = self.edit - self.params = BaseTool.extract_params(self.edit) - self.param_descriptions = { - "path": "文件路径", - "old_string": "待替换的字符串", - "new_string": "替换后的字符串", - "max_replacements": "最大替换次数(1~100)", - "context_before": "匹配前的上下文文本", - "context_after": "匹配后的上下文文本", - } - - @BaseTool.handle_tool_exceptions - def edit( - self, - path: str, - old_string: str, - new_string: str, - max_replacements: int = 10, - context_before: str = "", - context_after: str = "", - ) -> ToolResult: - """ - 通过在文件中进行安全的字符串替换编辑文件 - """ - # 1. 参数校验 - if not old_string: - return self.make_failed_response(locals().copy(), error=f"{ValueError('old_string 不能为空')}") - - if max_replacements < 1: - return self.make_failed_response(locals().copy(), error=f"{ValueError('max_replacements 必须 >= 1')}") - if max_replacements > 100: - max_replacements = 100 - - # 2. 路径解析 - source_path = Path(path) - resolved_path: Path = self.workspace.path_validator.resolve_path(source_path) - - if not resolved_path.is_file(): - return self.make_failed_response( - locals().copy(), error=f"{FileNotFoundError(f'文件不存在: {resolved_path}')}" - ) - - if is_binary_file(resolved_path): - return self.make_failed_response( - locals().copy(), error=f"{ValueError(f'禁止编辑二进制文件: {resolved_path}')}" - ) - - # 3. mtime 校验 - mtime_error = self._validate_mtime(resolved_path) - if mtime_error: - return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}") - - # 4. 读取文件内容 - old_content = resolved_path.read_text(encoding="utf-8") - - # 5. 查找匹配 - count = 0 - idx = 0 - while count < max_replacements: - idx = old_content.find(old_string, idx) - if idx == -1: - break - count += 1 - - # 上下文校验 - if context_before or context_after: - ctx_error = self._check_context(old_content, idx, old_string, context_before, context_after, count) - if ctx_error: - return self.make_failed_response( - locals().copy(), error=f"无法修改上下文不匹配的字符串:\n{ctx_error}" - ) - - idx += len(old_string) - - if count == 0: - return self.make_failed_response( - locals().copy(), - error=f"No changes made: old_string not found in file.\nFile: {path}\nSearching for: '{old_string}'", - ) - - # 6. 执行替换(生成新内容) - new_content = old_content.replace(old_string, new_string, count) - - # 7. 生成 diff - rel_path = str(resolved_path.relative_to(self.workspace.root_path)) - diff_content = self._generate_diff(old_content, new_content, rel_path) - - # 8. 记录快照 - from src.core.file_tracker import FileTracker - - old_hash = FileTracker.compute_checksum_from_string(old_content) - new_hash = FileTracker.compute_checksum_from_string(new_content) - session_id = self.workspace.session_id - snapshot_id = self.workspace.db.record_file_snapshot( - rel_path, - old_hash, - new_hash, - diff_content, - audit_status="PENDING_AUDIT", - session_id=session_id, - pending_content=new_content, - ) - - # 9. 返回预览 - return self.make_success_response( - locals().copy(), - ( - "修改已推送到审核系统\n" - f"[Edit Preview]\n" - f"File: {rel_path}\n" - f"Snapshot ID: {snapshot_id}\n" - f"Replacements: {count}\n" - f"Diff:\n{diff_content}" - ), - ) - - @staticmethod - def _check_context( - content: str, - match_start: int, - old_string: str, - context_before: str, - context_after: str, - match_number: int, - ) -> str | None: - """校验匹配处的上下文是否与预期一致.""" - if context_before: - actual_start = max(0, match_start - len(context_before)) - actual_before = content[actual_start:match_start] - if actual_before != context_before: - return ToolErrorResponse( - "EditTool", - ValueError( - f"Match {match_number}: context_before mismatch.\n" - f" Expected: '{context_before}'\n" - f" Actual: '{actual_before}'" - ), - ).to_str() - - if context_after: - after_start = match_start + len(old_string) - after_end = min(len(content), after_start + len(context_after)) - actual_after = content[after_start:after_end] - if actual_after != context_after: - return ToolErrorResponse( - "EditTool", - ValueError( - f"Match {match_number}: context_after mismatch.\n" - f" Expected: '{context_after}'\n" - f" Actual: '{actual_after}'" - ), - ).to_str() - - return None +"""安全的字符串替换编辑工具 -- 只发布待审核更改.""" + +from pathlib import Path + +from src.models.tool_error_response import ToolErrorResponse +from src.models.tools.tool_result import ToolResult +from src.utils.binary_detector import is_binary_file +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + + +class EditTool(BaseTool): + """安全的字符串替换编辑工具 — 两阶段提交 (预览 → 审核确认). + + 只计算 diff 并记录 PENDING_AUDIT 快照,不直接修改磁盘. + 由审核提交模块 (AuditCommitter) 在批准后执行实际写入. + """ + + def __init__(self, workspace: Workspace): + super().__init__(workspace, "edit", self.edit.__doc__, write_permission=True) + self.func = self.edit + self.params = BaseTool.extract_params(self.edit) + self.param_descriptions = { + "path": "文件路径", + "old_string": "待替换的字符串", + "new_string": "替换后的字符串", + "max_replacements": "最大替换次数(1~100)", + "context_before": "匹配前的上下文文本", + "context_after": "匹配后的上下文文本", + } + + @BaseTool.handle_tool_exceptions + def edit( + self, + path: str, + old_string: str, + new_string: str, + max_replacements: int = 10, + context_before: str = "", + context_after: str = "", + ) -> ToolResult: + """ + 通过在文件中进行安全的字符串替换编辑文件 + """ + # 1. 参数校验 + if not old_string: + return self.make_failed_response(locals().copy(), error=f"{ValueError('old_string 不能为空')}") + + if max_replacements < 1: + return self.make_failed_response(locals().copy(), error=f"{ValueError('max_replacements 必须 >= 1')}") + if max_replacements > 100: + max_replacements = 100 + + # 2. 路径解析 + source_path = Path(path) + resolved_path: Path = self.workspace.path_validator.resolve_path(source_path) + + if not resolved_path.is_file(): + return self.make_failed_response(locals().copy(), error=f"{FileNotFoundError(f'文件不存在: {resolved_path}')}") + + if is_binary_file(resolved_path): + return self.make_failed_response(locals().copy(), error=f"{ValueError(f'禁止编辑二进制文件: {resolved_path}')}") + + # 3. mtime 校验 + mtime_error = self._validate_mtime(resolved_path) + if mtime_error: + return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}") + + # 4. 读取文件内容 + old_content = resolved_path.read_text(encoding="utf-8") + + # 5. 查找匹配 + count = 0 + idx = 0 + while count < max_replacements: + idx = old_content.find(old_string, idx) + if idx == -1: + break + count += 1 + + # 上下文校验 + if context_before or context_after: + ctx_error = self._check_context(old_content, idx, old_string, context_before, context_after, count) + if ctx_error: + return self.make_failed_response(locals().copy(), error=f"无法修改上下文不匹配的字符串:\n{ctx_error}") + + idx += len(old_string) + + if count == 0: + return self.make_failed_response( + locals().copy(), + error=f"No changes made: old_string not found in file.\nFile: {path}\nSearching for: '{old_string}'", + ) + + # 6. 执行替换(生成新内容) + new_content = old_content.replace(old_string, new_string, count) + + # 7. 生成 diff + rel_path = str(resolved_path.relative_to(self.workspace.root_path)) + diff_content = self._generate_diff(old_content, new_content, rel_path) + + # 8. 记录快照 + from src.core.file_tracker import FileTracker + + old_hash = FileTracker.compute_checksum_from_string(old_content) + new_hash = FileTracker.compute_checksum_from_string(new_content) + session_id = self.workspace.session_id + snapshot_id = self.workspace.db.record_file_snapshot( + rel_path, + old_hash, + new_hash, + diff_content, + audit_status="PENDING_AUDIT", + session_id=session_id, + pending_content=new_content, + ) + + # 9. 返回预览 + return self.make_success_response( + locals().copy(), + (f"修改已推送到审核系统\n[Edit Preview]\nFile: {rel_path}\nSnapshot ID: {snapshot_id}\nReplacements: {count}\nDiff:\n{diff_content}"), + ) + + @staticmethod + def _check_context( + content: str, + match_start: int, + old_string: str, + context_before: str, + context_after: str, + match_number: int, + ) -> str | None: + """校验匹配处的上下文是否与预期一致.""" + if context_before: + actual_start = max(0, match_start - len(context_before)) + actual_before = content[actual_start:match_start] + if actual_before != context_before: + return ToolErrorResponse( + "EditTool", + ValueError(f"Match {match_number}: context_before mismatch.\n Expected: '{context_before}'\n Actual: '{actual_before}'"), + ).to_str() + + if context_after: + after_start = match_start + len(old_string) + after_end = min(len(content), after_start + len(context_after)) + actual_after = content[after_start:after_end] + if actual_after != context_after: + return ToolErrorResponse( + "EditTool", + ValueError(f"Match {match_number}: context_after mismatch.\n Expected: '{context_after}'\n Actual: '{actual_after}'"), + ).to_str() + + return None diff --git a/src/workspace/tools/exact_search_tool.py b/src/workspace/tools/exact_search_tool.py index 8456b04..3e8fdef 100644 --- a/src/workspace/tools/exact_search_tool.py +++ b/src/workspace/tools/exact_search_tool.py @@ -1,179 +1,166 @@ -import re -from pathlib import Path - -from src.models.tools.tool_result import ToolResult -from src.utils.binary_detector import is_binary_file -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -def _search_exact_in_file(lines: list[str], search_string: str, case_sensitive: bool, whole_word: bool) -> list[dict]: - """在文件中精确搜索""" - matches = [] - - for i, line in enumerate(lines): - line_content = line if case_sensitive else line.lower() - - if whole_word: - # 全词匹配:使用正则表达式 - if case_sensitive: - word_pattern = re.compile(r"\b" + re.escape(search_string) + r"\b") - else: - word_pattern = re.compile(r"\b" + re.escape(search_string) + r"\b", re.IGNORECASE) - - if word_pattern.search(line_content): - matches.append({"line_num": i + 1, "content": line.rstrip("\n\r")}) - else: - # 简单包含匹配 - if search_string in line_content: - matches.append({"line_num": i + 1, "content": line.rstrip("\n\r")}) - - return matches - - -def _format_exact_results( - results: list[dict], pattern: str, limit: int, file_count: int, case_sensitive: bool, whole_word: bool -) -> str: - """格式化精确搜索结果""" - if not results: - return f"未找到匹配字符串 '{pattern}' 的内容" - - total_matches = sum(len(r["matches"]) for r in results) - truncated = total_matches > limit - - output = [ - f"精确搜索: '{pattern}'", - f"大小写敏感: {'是' if case_sensitive else '否'}, 全词匹配: {'是' if whole_word else '否'}", - f"匹配文件数: {file_count}, 匹配项数: {min(total_matches, limit)}", - ] - - if truncated: - output.append(f"⚠️ 结果已截断,仅显示前 {limit} 个匹配项(实际共 {total_matches} 个)") - - output.append("=" * 60) - - displayed_matches = 0 - for file_result in results: - if displayed_matches >= limit: - break - - output.append(f"\n文件: {file_result['file']}") - output.append("-" * 40) - - for match in file_result["matches"]: - if displayed_matches >= limit: - output.append(f"\n... 以及 {total_matches - limit} 个未显示的匹配项") - break - output.append(f" 第 {match['line_num']:4d} 行: {match['content']}") - displayed_matches += 1 - - return "\n".join(output) - - -class ExactSearchTool(BaseTool): - """精确搜索工具,用于安全审计""" - - def __init__(self, workspace: Workspace): - super().__init__(workspace, "exact_search", self.exact_search.__doc__) - self.func = self.exact_search - self.params = BaseTool.extract_params(self.exact_search) - self.param_descriptions = { - "pattern": "搜索字符串", - "path": "搜索文件或文件夹路径", - "case_sensitive": "是否大小写敏感", - "whole_word": "是否全词匹配", - "file_pattern": "文件匹配模式,支持通配符", - "limit": "最大匹配数量限制", - "ignore": "忽略匹配正则的文件或文件夹列表", - } - self._exclusion_manager = workspace.exclusion_manager - - @BaseTool.handle_tool_exceptions - def exact_search( - self, - pattern: str, - path: str = ".", - case_sensitive: bool = True, - whole_word: bool = True, - file_pattern: str = "*", - limit: int = 256, - ignore: list[str] | None = None, - ) -> ToolResult: - """ - 精确搜索字符串 - """ - # 验证搜索路径 - search_path: Path = self.workspace.path_validator.validate(path) - - # 准备搜索字符串 - search_string = pattern if case_sensitive else pattern.lower() - - # 收集忽略模式: 合并默认排除 + 用户传入的 ignore - ignore_patterns = self._exclusion_manager.merge_ignore_regexes(ignore) - - # 搜索结果 - results = [] - file_count = 0 - total_matches = 0 - warnings = [""] - - # 确定要搜索的文件列表(支持单文件或目录) - files_to_search = ( - [search_path] - if search_path.is_file() - else [ - p - for p in search_path.rglob(file_pattern) - if p.is_file() and not self._exclusion_manager.should_exclude_path(p) - ] - ) - - # 遍历所有文件 - for file_path in files_to_search: - if not file_path.is_file(): - continue - - if is_binary_file(file_path): - continue - - # 检查是否达到限制 - if total_matches >= limit: - break - - # 检查是否应该忽略 - should_ignore = False - relative_path = file_path.relative_to(search_path) if search_path.is_dir() else file_path - - for ignore_pattern in ignore_patterns: - if ignore_pattern.search(str(relative_path)): - should_ignore = True - break - - if should_ignore: - continue - - try: - # 读取文件内容 - with open(file_path, encoding="utf-8") as f: - lines = f.readlines() - - # 搜索匹配行 - file_matches = _search_exact_in_file(lines, search_string, case_sensitive, whole_word) - - if file_matches: - results.append({"file": str(file_path), "matches": file_matches}) - file_count += 1 - total_matches += len(file_matches) - - except (OSError, UnicodeDecodeError, PermissionError) as e: - warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}") - continue # 跳过无法读取的文件 - - warnings.append("") - - # 格式化输出 - return self.make_success_response( - kwargs=locals().copy(), - data=_format_exact_results(results, pattern, limit, file_count, case_sensitive, whole_word), - error="\n".join(warnings) if len(warnings) > 2 else None, - ) +import re +from pathlib import Path + +from src.models.tools.tool_result import ToolResult +from src.utils.binary_detector import is_binary_file +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + + +def _search_exact_in_file(lines: list[str], search_string: str, case_sensitive: bool, whole_word: bool) -> list[dict]: + """在文件中精确搜索""" + matches = [] + + for i, line in enumerate(lines): + line_content = line if case_sensitive else line.lower() + + if whole_word: + # 全词匹配:使用正则表达式 + word_pattern = re.compile(r"\b" + re.escape(search_string) + r"\b") if case_sensitive else re.compile(r"\b" + re.escape(search_string) + r"\b", re.IGNORECASE) + + if word_pattern.search(line_content): + matches.append({"line_num": i + 1, "content": line.rstrip("\n\r")}) + else: + # 简单包含匹配 + if search_string in line_content: + matches.append({"line_num": i + 1, "content": line.rstrip("\n\r")}) + + return matches + + +def _format_exact_results(results: list[dict], pattern: str, limit: int, file_count: int, case_sensitive: bool, whole_word: bool) -> str: + """格式化精确搜索结果""" + if not results: + return f"未找到匹配字符串 '{pattern}' 的内容" + + total_matches = sum(len(r["matches"]) for r in results) + truncated = total_matches > limit + + output = [ + f"精确搜索: '{pattern}'", + f"大小写敏感: {'是' if case_sensitive else '否'}, 全词匹配: {'是' if whole_word else '否'}", + f"匹配文件数: {file_count}, 匹配项数: {min(total_matches, limit)}", + ] + + if truncated: + output.append(f"⚠️ 结果已截断,仅显示前 {limit} 个匹配项(实际共 {total_matches} 个)") + + output.append("=" * 60) + + displayed_matches = 0 + for file_result in results: + if displayed_matches >= limit: + break + + output.append(f"\n文件: {file_result['file']}") + output.append("-" * 40) + + for match in file_result["matches"]: + if displayed_matches >= limit: + output.append(f"\n... 以及 {total_matches - limit} 个未显示的匹配项") + break + output.append(f" 第 {match['line_num']:4d} 行: {match['content']}") + displayed_matches += 1 + + return "\n".join(output) + + +class ExactSearchTool(BaseTool): + """精确搜索工具,用于安全审计""" + + def __init__(self, workspace: Workspace): + super().__init__(workspace, "exact_search", self.exact_search.__doc__) + self.func = self.exact_search + self.params = BaseTool.extract_params(self.exact_search) + self.param_descriptions = { + "pattern": "搜索字符串", + "path": "搜索文件或文件夹路径", + "case_sensitive": "是否大小写敏感", + "whole_word": "是否全词匹配", + "file_pattern": "文件匹配模式,支持通配符", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + } + self._exclusion_manager = workspace.exclusion_manager + + @BaseTool.handle_tool_exceptions + def exact_search( + self, + pattern: str, + path: str = ".", + case_sensitive: bool = True, + whole_word: bool = True, + file_pattern: str = "*", + limit: int = 256, + ignore: list[str] | None = None, + ) -> ToolResult: + """ + 精确搜索字符串 + """ + # 验证搜索路径 + search_path: Path = self.workspace.path_validator.validate(path) + + # 准备搜索字符串 + search_string = pattern if case_sensitive else pattern.lower() + + # 收集忽略模式: 合并默认排除 + 用户传入的 ignore + ignore_patterns = self._exclusion_manager.merge_ignore_regexes(ignore) + + # 搜索结果 + results = [] + file_count = 0 + total_matches = 0 + warnings = [""] + + # 确定要搜索的文件列表(支持单文件或目录) + files_to_search = [search_path] if search_path.is_file() else [p for p in search_path.rglob(file_pattern) if p.is_file() and not self._exclusion_manager.should_exclude_path(p)] + + # 遍历所有文件 + for file_path in files_to_search: + if not file_path.is_file(): + continue + + if is_binary_file(file_path): + continue + + # 检查是否达到限制 + if total_matches >= limit: + break + + # 检查是否应该忽略 + should_ignore = False + relative_path = file_path.relative_to(search_path) if search_path.is_dir() else file_path + + for ignore_pattern in ignore_patterns: + if ignore_pattern.search(str(relative_path)): + should_ignore = True + break + + if should_ignore: + continue + + try: + # 读取文件内容 + with open(file_path, encoding="utf-8") as f: + lines = f.readlines() + + # 搜索匹配行 + file_matches = _search_exact_in_file(lines, search_string, case_sensitive, whole_word) + + if file_matches: + results.append({"file": str(file_path), "matches": file_matches}) + file_count += 1 + total_matches += len(file_matches) + + except (OSError, UnicodeDecodeError, PermissionError) as e: + warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}") + continue # 跳过无法读取的文件 + + warnings.append("") + + # 格式化输出 + return self.make_success_response( + kwargs=locals().copy(), + data=_format_exact_results(results, pattern, limit, file_count, case_sensitive, whole_word), + error="\n".join(warnings) if len(warnings) > 2 else None, + ) diff --git a/src/workspace/tools/git_tool.py b/src/workspace/tools/git_tool.py index d64f4ce..d5ce450 100644 --- a/src/workspace/tools/git_tool.py +++ b/src/workspace/tools/git_tool.py @@ -1,207 +1,190 @@ -"""统一的 Git 工具 — 白名单机制,安全封装.""" - -import re -import shlex -import subprocess - -from src.models.tools.tool_result import ToolResult -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - -# 安全命令(只读,直接执行,不需要审核) -_SAFE_COMMANDS = frozenset({"status", "diff", "log", "show"}) - -# 白名单(所有允许的命令) -_ALLOWED_COMMANDS = frozenset( - { - "status", - "diff", - "log", - "add", - "commit", - "restore", - "show", - "branch", - } -) - -# 拦截正则 — 即使白名单允许的也再检查一次 -_BLOCKED_PATTERNS = [ - re.compile(r"\bpush\b"), - re.compile(r"\bremote\b"), - re.compile(r"\breset\s+--hard\b"), - re.compile(r"\bbranch\s+-D\b"), - re.compile(r"\bmerge\b"), - re.compile(r"\brebase\b"), - re.compile(r"\bclean\b"), - re.compile(r"\bcheckout\s+-B\b"), - re.compile(r"\bcherry-pick\b"), - re.compile(r"\btag\b"), - re.compile(r"\bfetch\b"), - re.compile(r"\bpull\b"), -] - - -class GitTool(BaseTool): - """统一的 Git 工具 — 安全的子命令执行. - - 白名单机制: - - 安全命令 (status, diff, log, show): 直接执行,不触发审核 - - 修改命令 (add, commit, restore, branch): 执行并标记 PENDING_AUDIT - - 禁止命令 (push, reset --hard, merge, rebase, ...): 拦截并返回错误 - """ - - def __init__(self, workspace: Workspace): - super().__init__(workspace, "git", self.git.__doc__, read_permission=True) - self.func = self.git - self.params = BaseTool.extract_params(self.git) - self.param_descriptions = { - "command_str": "Git 子命令及其参数,如 'status'、'diff --cached'、'log --oneline -5'", - } - - def git(self, command_str: str) -> ToolResult: - """ - 执行 Git 命令 - """ - if not command_str or not command_str.strip(): - return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError("command_str 不能为空"))) - - try: - tokens = shlex.split(command_str) - except ValueError as e: - return self.make_failed_response(kwargs=locals().copy(), error=str(e)) - - if not tokens: - return self.make_failed_response( - kwargs=locals().copy(), error=str(ValueError(f"无法解析命令: `{command_str}`")) - ) - - base_command = tokens[0] - - # 1. 白名单检查 - if base_command not in _ALLOWED_COMMANDS: - allowed_list = ", ".join(sorted(_ALLOWED_COMMANDS)) - return self.make_failed_response( - kwargs=locals().copy(), - error=( - f"ERROR: Git command '{base_command}' is not in the allowed whitelist.\n" - f"Allowed commands: {allowed_list}" - ), - ) - - # 2. 拦截正则检查 - for pattern in _BLOCKED_PATTERNS: - if pattern.search(command_str): - return self.make_failed_response( - kwargs=locals().copy(), - error=( - f"ERROR: The command was blocked by security policy.\n" - f"Pattern matched: {pattern.pattern}\n" - f"Command: {command_str}" - ), - ) - - # 3. restore 安全检查 — 必须指定文件路径 - if base_command == "restore": - non_flag_args = [t for t in tokens[1:] if not t.startswith("-")] - if not non_flag_args: - return self.make_failed_response( - kwargs=locals().copy(), error=str(ValueError("restore 需要指定文件路径,不允许裸 restore")) - ) - for arg in non_flag_args: - stripped = arg.strip() - if stripped in (".", "*", "all") or stripped.startswith("*"): - return self.make_failed_response( - kwargs=locals().copy(), error=str(ValueError("restore 需要指定具体文件路径,不允许使用通配符")) - ) - - # 4. 执行命令 - try: - env = {**__import__("os").environ, "GIT_PAGER": "cat", "GIT_TERMINAL_PROMPT": "0"} - result = subprocess.run( - ["git", *tokens], - capture_output=True, - text=True, - timeout=30, - cwd=str(self.workspace.root_path), - env=env, - ) - except FileNotFoundError: - return self.make_failed_response(kwargs=locals().copy(), error=str(OSError("Git 未安装或不在系统 PATH 中"))) - except subprocess.TimeoutExpired as time_out_exception: - return self.make_failed_response( - kwargs=locals().copy(), error=f"TimeoutExpired(Git 命令执行超时: {time_out_exception})" - ) - - # 5. 处理输出 — 总是保留 stdout 和 stderr, 即使 returncode != 0 (如 git diff --exit-code) - output_parts = [] - if result.stdout: - output_parts.append(result.stdout.rstrip("\n")) - if result.stderr: - output_parts.append(result.stderr.rstrip("\n")) - - if result.returncode != 0: - stderr = (result.stderr or "").strip() - error_msg = f"Git command exited with code {result.returncode}" - if stderr: - error_msg += f":\n{stderr}" - # 保留 stdout 在 data 中, 同时返回 error - return self.make_failed_response( - kwargs=locals().copy(), - data="\n".join(output_parts) if output_parts else "(no output)", - error=error_msg, - ) - - if result.stdout is None and not result.stderr: - # HACK: subprocess.run with capture_output=True returns None for stdout - # on this platform. Fallback: re-run with explicit PIPE (bytes mode). - # 这是Ruff的`UP022`规则报的lint警告:建议用`capture_output=True`替代显式设置`stdout=PIPE, stderr=PIPE` - # 但这里的场景特殊——正是`capture_output=True, text=True`导致stdout为`None`, - # 才需要回退到bytes模式的手动PIPE, 所以这是一个**有意为之的例外**,不应遵循该建议. - _result2 = subprocess.run( # noqa: UP022 - ["git", *tokens], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - timeout=30, - cwd=str(self.workspace.root_path), - env={**__import__("os").environ, "GIT_PAGER": "cat", "GIT_TERMINAL_PROMPT": "0"}, - ) - out = _result2.stdout.decode("utf-8", errors="replace") if _result2.stdout else "" - err = _result2.stderr.decode("utf-8", errors="replace") if _result2.stderr else "" - if out: - output_parts.append(out.rstrip("\n")) - if err: - output_parts.append(err.rstrip("\n")) - if _result2.returncode != 0: - error_msg2 = f"Git command exited with code {_result2.returncode}" - if err: - error_msg2 += f":\n{err}" - return self.make_failed_response( - kwargs=locals().copy(), - data="\n".join(output_parts) if output_parts else "(no output)", - error=error_msg2, - ) - - return self.make_success_response( - kwargs=locals().copy(), data="\n".join(output_parts) if output_parts else "(no output)" - ) - - @staticmethod - def is_safe_command(command_str: str) -> bool: - """判断一个 git 命令是否安全(只读,不需要审核). - - Args: - command_str: 完整的 git 命令字符串 - - Returns: - 如果是安全命令返回 True,否则返回 False - """ - if not command_str or not command_str.strip(): - return False - try: - tokens = shlex.split(command_str) - except ValueError: - return False - if not tokens: - return False - return tokens[0] in _SAFE_COMMANDS +"""统一的 Git 工具 — 白名单机制,安全封装.""" + +import re +import shlex +import subprocess + +from src.models.tools.tool_result import ToolResult +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + +# 安全命令(只读,直接执行,不需要审核) +_SAFE_COMMANDS = frozenset({"status", "diff", "log", "show"}) + +# 白名单(所有允许的命令) +_ALLOWED_COMMANDS = frozenset( + { + "status", + "diff", + "log", + "add", + "commit", + "restore", + "show", + "branch", + } +) + +# 拦截正则 — 即使白名单允许的也再检查一次 +_BLOCKED_PATTERNS = [ + re.compile(r"\bpush\b"), + re.compile(r"\bremote\b"), + re.compile(r"\breset\s+--hard\b"), + re.compile(r"\bbranch\s+-D\b"), + re.compile(r"\bmerge\b"), + re.compile(r"\brebase\b"), + re.compile(r"\bclean\b"), + re.compile(r"\bcheckout\s+-B\b"), + re.compile(r"\bcherry-pick\b"), + re.compile(r"\btag\b"), + re.compile(r"\bfetch\b"), + re.compile(r"\bpull\b"), +] + + +class GitTool(BaseTool): + """统一的 Git 工具 — 安全的子命令执行. + + 白名单机制: + - 安全命令 (status, diff, log, show): 直接执行,不触发审核 + - 修改命令 (add, commit, restore, branch): 执行并标记 PENDING_AUDIT + - 禁止命令 (push, reset --hard, merge, rebase, ...): 拦截并返回错误 + """ + + def __init__(self, workspace: Workspace): + super().__init__(workspace, "git", self.git.__doc__, read_permission=True) + self.func = self.git + self.params = BaseTool.extract_params(self.git) + self.param_descriptions = { + "command_str": "Git 子命令及其参数,如 'status'、'diff --cached'、'log --oneline -5'", + } + + def git(self, command_str: str) -> ToolResult: + """ + 执行 Git 命令 + """ + if not command_str or not command_str.strip(): + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError("command_str 不能为空"))) + + try: + tokens = shlex.split(command_str) + except ValueError as e: + return self.make_failed_response(kwargs=locals().copy(), error=str(e)) + + if not tokens: + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError(f"无法解析命令: `{command_str}`"))) + + base_command = tokens[0] + + # 1. 白名单检查 + if base_command not in _ALLOWED_COMMANDS: + allowed_list = ", ".join(sorted(_ALLOWED_COMMANDS)) + return self.make_failed_response( + kwargs=locals().copy(), + error=(f"ERROR: Git command '{base_command}' is not in the allowed whitelist.\nAllowed commands: {allowed_list}"), + ) + + # 2. 拦截正则检查 + for pattern in _BLOCKED_PATTERNS: + if pattern.search(command_str): + return self.make_failed_response( + kwargs=locals().copy(), + error=(f"ERROR: The command was blocked by security policy.\nPattern matched: {pattern.pattern}\nCommand: {command_str}"), + ) + + # 3. restore 安全检查 — 必须指定文件路径 + if base_command == "restore": + non_flag_args = [t for t in tokens[1:] if not t.startswith("-")] + if not non_flag_args: + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError("restore 需要指定文件路径,不允许裸 restore"))) + for arg in non_flag_args: + stripped = arg.strip() + if stripped in (".", "*", "all") or stripped.startswith("*"): + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError("restore 需要指定具体文件路径,不允许使用通配符"))) + + # 4. 执行命令 + try: + env = {**__import__("os").environ, "GIT_PAGER": "cat", "GIT_TERMINAL_PROMPT": "0"} + result = subprocess.run( + ["git", *tokens], + capture_output=True, + text=True, + timeout=30, + cwd=str(self.workspace.root_path), + env=env, + ) + except FileNotFoundError: + return self.make_failed_response(kwargs=locals().copy(), error=str(OSError("Git 未安装或不在系统 PATH 中"))) + except subprocess.TimeoutExpired as time_out_exception: + return self.make_failed_response(kwargs=locals().copy(), error=f"TimeoutExpired(Git 命令执行超时: {time_out_exception})") + + # 5. 处理输出 — 总是保留 stdout 和 stderr, 即使 returncode != 0 (如 git diff --exit-code) + output_parts = [] + if result.stdout: + output_parts.append(result.stdout.rstrip("\n")) + if result.stderr: + output_parts.append(result.stderr.rstrip("\n")) + + if result.returncode != 0: + stderr = (result.stderr or "").strip() + error_msg = f"Git command exited with code {result.returncode}" + if stderr: + error_msg += f":\n{stderr}" + # 保留 stdout 在 data 中, 同时返回 error + return self.make_failed_response( + kwargs=locals().copy(), + data="\n".join(output_parts) if output_parts else "(no output)", + error=error_msg, + ) + + if result.stdout is None and not result.stderr: + # HACK: subprocess.run with capture_output=True returns None for stdout + # on this platform. Fallback: re-run with explicit PIPE (bytes mode). + # 这是Ruff的`UP022`规则报的lint警告:建议用`capture_output=True`替代显式设置`stdout=PIPE, stderr=PIPE` + # 但这里的场景特殊——正是`capture_output=True, text=True`导致stdout为`None`, + # 才需要回退到bytes模式的手动PIPE, 所以这是一个**有意为之的例外**,不应遵循该建议. + _result2 = subprocess.run( # noqa: UP022 + ["git", *tokens], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=30, + cwd=str(self.workspace.root_path), + env={**__import__("os").environ, "GIT_PAGER": "cat", "GIT_TERMINAL_PROMPT": "0"}, + ) + out = _result2.stdout.decode("utf-8", errors="replace") if _result2.stdout else "" + err = _result2.stderr.decode("utf-8", errors="replace") if _result2.stderr else "" + if out: + output_parts.append(out.rstrip("\n")) + if err: + output_parts.append(err.rstrip("\n")) + if _result2.returncode != 0: + error_msg2 = f"Git command exited with code {_result2.returncode}" + if err: + error_msg2 += f":\n{err}" + return self.make_failed_response( + kwargs=locals().copy(), + data="\n".join(output_parts) if output_parts else "(no output)", + error=error_msg2, + ) + + return self.make_success_response(kwargs=locals().copy(), data="\n".join(output_parts) if output_parts else "(no output)") + + @staticmethod + def is_safe_command(command_str: str) -> bool: + """判断一个 git 命令是否安全(只读,不需要审核). + + Args: + command_str: 完整的 git 命令字符串 + + Returns: + 如果是安全命令返回 True,否则返回 False + """ + if not command_str or not command_str.strip(): + return False + try: + tokens = shlex.split(command_str) + except ValueError: + return False + if not tokens: + return False + return tokens[0] in _SAFE_COMMANDS diff --git a/src/workspace/tools/glob_tool.py b/src/workspace/tools/glob_tool.py index f47274a..50bed9d 100644 --- a/src/workspace/tools/glob_tool.py +++ b/src/workspace/tools/glob_tool.py @@ -1,36 +1,32 @@ -from itertools import islice - -from src.models.tools.tool_result import ToolResult -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -class GlobTool(BaseTool): - def __init__(self, workspace: Workspace): - super().__init__(workspace, "glob", self.glob.__doc__) - self.func = self.glob - self.params = BaseTool.extract_params(self.glob) - self.param_descriptions = { - "pattern": "通配符", - "path": "目录路径", - "max_ret": "最多返回多少条检索结果", - } - self._exclusion_manager = workspace.exclusion_manager - - @BaseTool.handle_tool_exceptions - def glob(self, pattern: str, path: str = ".", max_ret: int = 1000) -> ToolResult: - """ - 在工作区内按通配符模式匹配并列出所有路径,带[Folder]或[File]的类型标记. 失败时返回错误信息 - """ - root_path = self.workspace.path_validator.validate(path) - if not root_path.is_dir(): - return self.make_failed_response(kwargs=locals().copy(), error=f"{root_path}不是一个文件夹路径") - - return self.make_success_response( - kwargs=locals().copy(), - data=[ - f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" - for item in islice(root_path.glob(pattern), max_ret) - if not self._exclusion_manager.should_exclude_path(item) - ], - ) +from itertools import islice + +from src.models.tools.tool_result import ToolResult +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + + +class GlobTool(BaseTool): + def __init__(self, workspace: Workspace): + super().__init__(workspace, "glob", self.glob.__doc__) + self.func = self.glob + self.params = BaseTool.extract_params(self.glob) + self.param_descriptions = { + "pattern": "通配符", + "path": "目录路径", + "max_ret": "最多返回多少条检索结果", + } + self._exclusion_manager = workspace.exclusion_manager + + @BaseTool.handle_tool_exceptions + def glob(self, pattern: str, path: str = ".", max_ret: int = 1000) -> ToolResult: + """ + 在工作区内按通配符模式匹配并列出所有路径,带[Folder]或[File]的类型标记. 失败时返回错误信息 + """ + root_path = self.workspace.path_validator.validate(path) + if not root_path.is_dir(): + return self.make_failed_response(kwargs=locals().copy(), error=f"{root_path}不是一个文件夹路径") + + return self.make_success_response( + kwargs=locals().copy(), + data=[f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" for item in islice(root_path.glob(pattern), max_ret) if not self._exclusion_manager.should_exclude_path(item)], + ) diff --git a/src/workspace/tools/ls_tool.py b/src/workspace/tools/ls_tool.py index cbb176a..65aa43a 100644 --- a/src/workspace/tools/ls_tool.py +++ b/src/workspace/tools/ls_tool.py @@ -1,33 +1,29 @@ -from pathlib import Path - -from src.models.tools.tool_result import ToolResult -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -class LsTool(BaseTool): - def __init__(self, workspace: Workspace): - super().__init__(workspace, "ls", self.ls.__doc__) - self.func = self.ls - self.params = BaseTool.extract_params(self.ls) - self.param_descriptions = { - "path": "目录路径", - } - self._exclusion_manager = workspace.exclusion_manager - - @BaseTool.handle_tool_exceptions - def ls(self, path: str = ".") -> ToolResult: - """ - 列出指定目录下的文件和文件夹. 返回相对路径列表, 并标记[Folder]或[File] - """ - folder_path: Path = self.workspace.path_validator.validate(path) - if not folder_path.is_dir(): - return self.make_failed_response(kwargs=locals().copy(), error=f'参数错误: "{folder_path}"不是一个目录') - return self.make_success_response( - kwargs=locals().copy(), - data=[ - f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" - for item in folder_path.iterdir() - if not self._exclusion_manager.should_exclude_path(item) - ], - ) +from pathlib import Path + +from src.models.tools.tool_result import ToolResult +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + + +class LsTool(BaseTool): + def __init__(self, workspace: Workspace): + super().__init__(workspace, "ls", self.ls.__doc__) + self.func = self.ls + self.params = BaseTool.extract_params(self.ls) + self.param_descriptions = { + "path": "目录路径", + } + self._exclusion_manager = workspace.exclusion_manager + + @BaseTool.handle_tool_exceptions + def ls(self, path: str = ".") -> ToolResult: + """ + 列出指定目录下的文件和文件夹. 返回相对路径列表, 并标记[Folder]或[File] + """ + folder_path: Path = self.workspace.path_validator.validate(path) + if not folder_path.is_dir(): + return self.make_failed_response(kwargs=locals().copy(), error=f'参数错误: "{folder_path}"不是一个目录') + return self.make_success_response( + kwargs=locals().copy(), + data=[f"{'[Folder]' if item.is_dir() else '[File]'} {item.relative_to(self.workspace.root_path)}" for item in folder_path.iterdir() if not self._exclusion_manager.should_exclude_path(item)], + ) diff --git a/src/workspace/tools/read_tool.py b/src/workspace/tools/read_tool.py index 7bac9d8..e533cad 100644 --- a/src/workspace/tools/read_tool.py +++ b/src/workspace/tools/read_tool.py @@ -1,108 +1,98 @@ -import os -from pathlib import Path - -from src.models.tools.tool_result import ToolResult -from src.utils.binary_detector import is_binary_file -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - -# 最大文件读取大小, 超过此大小的文件将被拒绝读取(默认 10MB) -MAX_FILE_SIZE = int(os.getenv("TOOL_READ_MAX_FILE_SIZE", str(10 * 1024 * 1024))) - - -def _resolve_index(idx: int, total: int) -> int: - """Resolve a 1-based or negative index to a clamped 1-based line number.""" - if idx < 0: - idx = total + 1 + idx - if idx < 1: - return 1 - if idx > total: - return total - return idx - - -class ReadTool(BaseTool): - def __init__(self, workspace: Workspace): - super().__init__(workspace, "read", self.read.__doc__) - self.func = self.read - self.params = BaseTool.extract_params(self.read) - self.param_descriptions = { - "path": "文件路径", - "start": "起始行号(1开始; 负数表示倒数, -1=最后一行)", - "end": "结束行号(1开始; 负数表示倒数, -1=最后一行)", - "context": "扩展结果行数范围 行数范围最终为(start-context, end+context)", - "encoding": "编码", - } - - @BaseTool.handle_tool_exceptions - def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encoding: str = "utf-8") -> ToolResult: - """ - 读取文件内容, 返回带行号的格式化内容 - """ - file_path: Path = self.workspace.path_validator.validate(path) - - if not file_path.is_file(): - return self.make_failed_response( - kwargs=locals().copy(), error=str(ValueError(f"读取文件{file_path}时未读取到完整文件")) - ) - - if is_binary_file(file_path): - return self.make_failed_response( - kwargs=locals().copy(), - error=str(ValueError(f"无法读取二进制文件: {file_path}. 请使用二进制安全工具或转换为 base64.")), - ) - - file_size = file_path.stat().st_size - if file_size > MAX_FILE_SIZE: - return self.make_failed_response( - kwargs=locals().copy(), - error=str( - ValueError( - f"文件过大 ({file_size} 字节), 超过最大限制 ({MAX_FILE_SIZE} 字节): {file_path}. " - f"请使用范围参数 (start/end) 分批读取." - ) - ), - ) - - with open(file_path, encoding=encoding) as f: - lines = f.readlines() - - total_lines = len(lines) - - if total_lines == 0: - header = f"\n[文件: {file_path}]\n[行 0-0 / 共 0 行]\n" - separator = "-" * 80 + "\n" - self._record_read_meta(file_path) - return self.make_success_response(kwargs=locals().copy(), data=header + separator) - - context = max(0, context) - - actual_start = _resolve_index(start, total_lines) - context - actual_end = _resolve_index(end, total_lines) + context - - if actual_start < 1: - actual_start = 1 - if actual_end > total_lines: - actual_end = total_lines - - if actual_end < actual_start: - return self.make_failed_response( - kwargs=locals().copy(), - error=( - f"错误:解析后的结束行 {actual_end} 小于起始行 {actual_start} " - f"(原始参数: start={start}, end={end}, context={context})" - ), - ) - - result_lines = [] - for i in range(actual_start - 1, actual_end): - line_num = i + 1 - content = lines[i].rstrip("\n\r") - result_lines.append(f"{line_num:6d} | {content}") - - header = f"\n[文件: {file_path}]\n[行 {actual_start}-{actual_end} / 共 {total_lines} 行]\n" - separator = "-" * 80 + "\n" - - self._record_read_meta(file_path) - - return self.make_success_response(kwargs=locals().copy(), data=header + separator + "\n".join(result_lines)) +import os +from pathlib import Path + +from src.models.tools.tool_result import ToolResult +from src.utils.binary_detector import is_binary_file +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + +# 最大文件读取大小, 超过此大小的文件将被拒绝读取(默认 10MB) +MAX_FILE_SIZE = int(os.getenv("TOOL_READ_MAX_FILE_SIZE", str(10 * 1024 * 1024))) + + +def _resolve_index(idx: int, total: int) -> int: + """Resolve a 1-based or negative index to a clamped 1-based line number.""" + if idx < 0: + idx = total + 1 + idx + if idx < 1: + return 1 + if idx > total: + return total + return idx + + +class ReadTool(BaseTool): + def __init__(self, workspace: Workspace): + super().__init__(workspace, "read", self.read.__doc__) + self.func = self.read + self.params = BaseTool.extract_params(self.read) + self.param_descriptions = { + "path": "文件路径", + "start": "起始行号(1开始; 负数表示倒数, -1=最后一行)", + "end": "结束行号(1开始; 负数表示倒数, -1=最后一行)", + "context": "扩展结果行数范围 行数范围最终为(start-context, end+context)", + "encoding": "编码", + } + + @BaseTool.handle_tool_exceptions + def read(self, path: str, start: int = 1, end: int = -1, context: int = 0, encoding: str = "utf-8") -> ToolResult: + """ + 读取文件内容, 返回带行号的格式化内容 + """ + file_path: Path = self.workspace.path_validator.validate(path) + + if not file_path.is_file(): + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError(f"读取文件{file_path}时未读取到完整文件"))) + + if is_binary_file(file_path): + return self.make_failed_response( + kwargs=locals().copy(), + error=str(ValueError(f"无法读取二进制文件: {file_path}. 请使用二进制安全工具或转换为 base64.")), + ) + + file_size = file_path.stat().st_size + if file_size > MAX_FILE_SIZE: + return self.make_failed_response( + kwargs=locals().copy(), + error=str(ValueError(f"文件过大 ({file_size} 字节), 超过最大限制 ({MAX_FILE_SIZE} 字节): {file_path}. 请使用范围参数 (start/end) 分批读取.")), + ) + + with open(file_path, encoding=encoding) as f: + lines = f.readlines() + + total_lines = len(lines) + + if total_lines == 0: + header = f"\n[文件: {file_path}]\n[行 0-0 / 共 0 行]\n" + separator = "-" * 80 + "\n" + self._record_read_meta(file_path) + return self.make_success_response(kwargs=locals().copy(), data=header + separator) + + context = max(0, context) + + actual_start = _resolve_index(start, total_lines) - context + actual_end = _resolve_index(end, total_lines) + context + + if actual_start < 1: + actual_start = 1 + if actual_end > total_lines: + actual_end = total_lines + + if actual_end < actual_start: + return self.make_failed_response( + kwargs=locals().copy(), + error=(f"错误:解析后的结束行 {actual_end} 小于起始行 {actual_start} (原始参数: start={start}, end={end}, context={context})"), + ) + + result_lines = [] + for i in range(actual_start - 1, actual_end): + line_num = i + 1 + content = lines[i].rstrip("\n\r") + result_lines.append(f"{line_num:6d} | {content}") + + header = f"\n[文件: {file_path}]\n[行 {actual_start}-{actual_end} / 共 {total_lines} 行]\n" + separator = "-" * 80 + "\n" + + self._record_read_meta(file_path) + + return self.make_success_response(kwargs=locals().copy(), data=header + separator + "\n".join(result_lines)) diff --git a/src/workspace/tools/regex_search_tool.py b/src/workspace/tools/regex_search_tool.py index 7ec6fad..70c3ed1 100644 --- a/src/workspace/tools/regex_search_tool.py +++ b/src/workspace/tools/regex_search_tool.py @@ -1,207 +1,199 @@ -import re -from pathlib import Path - -from src.models.tools.tool_result import ToolResult -from src.utils.binary_detector import is_binary_file -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -def _search_in_file(lines: list[str], regex: re.Pattern, context: int) -> list[dict]: - """在文件行中搜索匹配""" - matches = [] - line_count = len(lines) - - for i, line in enumerate(lines): - if regex.search(line): - # 计算上下文行范围 - start_line = max(0, i - context) - end_line = min(line_count, i + context + 1) - - # 收集上下文行 - context_lines = [] - for j in range(start_line, end_line): - is_match = j == i - context_lines.append({"line_num": j + 1, "content": lines[j].rstrip("\n\r"), "is_match": is_match}) - - matches.append({"line_num": i + 1, "content": line.rstrip("\n\r"), "context": context_lines}) - - return matches - - -def _format_regex_results(results: list[dict], pattern: str, limit: int, file_count: int) -> str: - """格式化正则搜索结果""" - if not results: - return f"未找到匹配正则表达式 '{pattern}' 的内容" - - total_matches = sum(len(r["matches"]) for r in results) - truncated = total_matches > limit - - output = [f"正则表达式搜索: '{pattern}'", f"匹配文件数: {file_count}, 匹配项数: {min(total_matches, limit)}"] - - if truncated: - output.append(f"⚠️ 结果已截断,仅显示前 {limit} 个匹配项(实际共 {total_matches} 个)") - - output.append("=" * 60) - - displayed_matches = 0 - for file_result in results: - if displayed_matches >= limit: - break - - output.append(f"\n文件: {file_result['file']}") - output.append("-" * 40) - - for match in file_result["matches"]: - if displayed_matches >= limit: - output.append(f"\n... 以及 {total_matches - limit} 个未显示的匹配项") - break - - if "context" in match: - output.append(f"第 {match['line_num']} 行匹配:") - for ctx_line in match["context"]: - # 增强包裹标记:匹配行用 >>> 和 <<< 包裹 - if ctx_line["is_match"]: - # 提取匹配的具体内容并用标记包裹 - content = ctx_line["content"] - # 使用正则找到实际匹配的部分并包裹 - try: - regex = re.compile(pattern) - # 找到所有匹配位置并添加标记 - matches_positions = list(regex.finditer(content)) - if matches_positions: - # 从后往前插入标记,避免位置偏移 - result_parts = [] - last_pos = 0 - for m in matches_positions: - result_parts.append(content[last_pos : m.start()]) - result_parts.append(f">>>{m.group()}<<<") - last_pos = m.end() - result_parts.append(content[last_pos:]) - marked_content = "".join(result_parts) - else: - marked_content = f">>>{content}<<<" - except re.error: - marked_content = f">>>{content}<<<" - - output.append(f" >>> L{ctx_line['line_num']:4d}: {marked_content}") - else: - output.append(f" L{ctx_line['line_num']:4d}: {ctx_line['content']}") - else: - # 旧格式兼容 - output.append(f" 第 {match['line_num']:4d} 行: >>>{match['content']}<<<") - output.append("") - displayed_matches += 1 - - return "\n".join(output) - - -class RegexSearchTool(BaseTool): - """正则表达式搜索工具""" - - def __init__(self, workspace: Workspace): - super().__init__(workspace, "regex_search", self.regex_search.__doc__) - self.func = self.regex_search - self.params = BaseTool.extract_params(self.regex_search) - self.param_descriptions = { - "pattern": "正则表达式模式", - "path": "搜索文件或文件夹路径", - "context": "显示匹配行的上下文行数", - "file_pattern": "文件匹配模式,支持通配符", - "limit": "最大匹配数量限制", - "ignore": "忽略匹配正则的文件或文件夹列表", - } - self._exclusion_manager = workspace.exclusion_manager - - @BaseTool.handle_tool_exceptions - def regex_search( - self, - pattern: str, - path: str = ".", - context: int = 3, - file_pattern: str = "*", - limit: int = 256, - ignore: list[str] | None = None, - ) -> ToolResult: - """ - 使用正则表达式搜索文件内容, 支持上下文显示、文件过滤和忽略路径, 返回匹配详情; 适合代码与文档探索 - """ - # 验证搜索路径 - search_path: Path = self.workspace.path_validator.validate(path) - - # 编译正则表达式 - try: - regex = re.compile(pattern) - except re.error as e: - return self.make_failed_response(kwargs=locals().copy(), error=f"无效的正则表达式: {e}") - - # 收集忽略模式: 合并默认排除 + 用户传入的 ignore - ignore_patterns = self._exclusion_manager.merge_ignore_regexes(ignore) - - # 搜索结果 - results = [] - file_count = 0 - total_matches = 0 - warnings = [""] - - # 确定要搜索的文件列表(支持单文件或目录) - files_to_search = ( - [search_path] - if search_path.is_file() - else [ - p - for p in search_path.rglob(file_pattern) - if p.is_file() and not self._exclusion_manager.should_exclude_path(p) - ] - ) - - # 遍历文件 - for file_path in files_to_search: - if not file_path.is_file(): - continue - - if is_binary_file(file_path): - continue - - # 检查是否达到限制 - if total_matches >= limit: - break - - # 检查是否应该忽略该文件或文件夹 - should_ignore = False - relative_path = file_path.relative_to(search_path) if search_path.is_dir() else file_path - - for ignore_pattern in ignore_patterns: - # 检查是否匹配忽略模式 - if ignore_pattern.search(str(relative_path)): - should_ignore = True - break - - if should_ignore: - continue - - try: - # 读取文件内容 - with open(file_path, encoding="utf-8") as f: - lines = f.readlines() - - # 搜索匹配行 - file_results = _search_in_file(lines, regex, context) - - if file_results: - results.append({"file": str(file_path), "matches": file_results}) - file_count += 1 - total_matches += len(file_results) - - except (OSError, UnicodeDecodeError, PermissionError) as e: - warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}") - continue # 跳过无法读取的文件 - - warnings.append("") - - # 格式化输出 - return self.make_success_response( - kwargs=locals().copy(), - data=_format_regex_results(results, pattern, limit, file_count), - error="\n".join(warnings) if len(warnings) > 2 else None, - ) +import re +from pathlib import Path + +from src.models.tools.tool_result import ToolResult +from src.utils.binary_detector import is_binary_file +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + + +def _search_in_file(lines: list[str], regex: re.Pattern, context: int) -> list[dict]: + """在文件行中搜索匹配""" + matches = [] + line_count = len(lines) + + for i, line in enumerate(lines): + if regex.search(line): + # 计算上下文行范围 + start_line = max(0, i - context) + end_line = min(line_count, i + context + 1) + + # 收集上下文行 + context_lines = [] + for j in range(start_line, end_line): + is_match = j == i + context_lines.append({"line_num": j + 1, "content": lines[j].rstrip("\n\r"), "is_match": is_match}) + + matches.append({"line_num": i + 1, "content": line.rstrip("\n\r"), "context": context_lines}) + + return matches + + +def _format_regex_results(results: list[dict], pattern: str, limit: int, file_count: int) -> str: + """格式化正则搜索结果""" + if not results: + return f"未找到匹配正则表达式 '{pattern}' 的内容" + + total_matches = sum(len(r["matches"]) for r in results) + truncated = total_matches > limit + + output = [f"正则表达式搜索: '{pattern}'", f"匹配文件数: {file_count}, 匹配项数: {min(total_matches, limit)}"] + + if truncated: + output.append(f"⚠️ 结果已截断,仅显示前 {limit} 个匹配项(实际共 {total_matches} 个)") + + output.append("=" * 60) + + displayed_matches = 0 + for file_result in results: + if displayed_matches >= limit: + break + + output.append(f"\n文件: {file_result['file']}") + output.append("-" * 40) + + for match in file_result["matches"]: + if displayed_matches >= limit: + output.append(f"\n... 以及 {total_matches - limit} 个未显示的匹配项") + break + + if "context" in match: + output.append(f"第 {match['line_num']} 行匹配:") + for ctx_line in match["context"]: + # 增强包裹标记:匹配行用 >>> 和 <<< 包裹 + if ctx_line["is_match"]: + # 提取匹配的具体内容并用标记包裹 + content = ctx_line["content"] + # 使用正则找到实际匹配的部分并包裹 + try: + regex = re.compile(pattern) + # 找到所有匹配位置并添加标记 + matches_positions = list(regex.finditer(content)) + if matches_positions: + # 从后往前插入标记,避免位置偏移 + result_parts = [] + last_pos = 0 + for m in matches_positions: + result_parts.append(content[last_pos : m.start()]) + result_parts.append(f">>>{m.group()}<<<") + last_pos = m.end() + result_parts.append(content[last_pos:]) + marked_content = "".join(result_parts) + else: + marked_content = f">>>{content}<<<" + except re.error: + marked_content = f">>>{content}<<<" + + output.append(f" >>> L{ctx_line['line_num']:4d}: {marked_content}") + else: + output.append(f" L{ctx_line['line_num']:4d}: {ctx_line['content']}") + else: + # 旧格式兼容 + output.append(f" 第 {match['line_num']:4d} 行: >>>{match['content']}<<<") + output.append("") + displayed_matches += 1 + + return "\n".join(output) + + +class RegexSearchTool(BaseTool): + """正则表达式搜索工具""" + + def __init__(self, workspace: Workspace): + super().__init__(workspace, "regex_search", self.regex_search.__doc__) + self.func = self.regex_search + self.params = BaseTool.extract_params(self.regex_search) + self.param_descriptions = { + "pattern": "正则表达式模式", + "path": "搜索文件或文件夹路径", + "context": "显示匹配行的上下文行数", + "file_pattern": "文件匹配模式,支持通配符", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + } + self._exclusion_manager = workspace.exclusion_manager + + @BaseTool.handle_tool_exceptions + def regex_search( + self, + pattern: str, + path: str = ".", + context: int = 3, + file_pattern: str = "*", + limit: int = 256, + ignore: list[str] | None = None, + ) -> ToolResult: + """ + 使用正则表达式搜索文件内容, 支持上下文显示、文件过滤和忽略路径, 返回匹配详情; 适合代码与文档探索 + """ + # 验证搜索路径 + search_path: Path = self.workspace.path_validator.validate(path) + + # 编译正则表达式 + try: + regex = re.compile(pattern) + except re.error as e: + return self.make_failed_response(kwargs=locals().copy(), error=f"无效的正则表达式: {e}") + + # 收集忽略模式: 合并默认排除 + 用户传入的 ignore + ignore_patterns = self._exclusion_manager.merge_ignore_regexes(ignore) + + # 搜索结果 + results = [] + file_count = 0 + total_matches = 0 + warnings = [""] + + # 确定要搜索的文件列表(支持单文件或目录) + files_to_search = [search_path] if search_path.is_file() else [p for p in search_path.rglob(file_pattern) if p.is_file() and not self._exclusion_manager.should_exclude_path(p)] + + # 遍历文件 + for file_path in files_to_search: + if not file_path.is_file(): + continue + + if is_binary_file(file_path): + continue + + # 检查是否达到限制 + if total_matches >= limit: + break + + # 检查是否应该忽略该文件或文件夹 + should_ignore = False + relative_path = file_path.relative_to(search_path) if search_path.is_dir() else file_path + + for ignore_pattern in ignore_patterns: + # 检查是否匹配忽略模式 + if ignore_pattern.search(str(relative_path)): + should_ignore = True + break + + if should_ignore: + continue + + try: + # 读取文件内容 + with open(file_path, encoding="utf-8") as f: + lines = f.readlines() + + # 搜索匹配行 + file_results = _search_in_file(lines, regex, context) + + if file_results: + results.append({"file": str(file_path), "matches": file_results}) + file_count += 1 + total_matches += len(file_results) + + except (OSError, UnicodeDecodeError, PermissionError) as e: + warnings.append(f"在文件{file_path}搜索匹配行时出错: {e}") + continue # 跳过无法读取的文件 + + warnings.append("") + + # 格式化输出 + return self.make_success_response( + kwargs=locals().copy(), + data=_format_regex_results(results, pattern, limit, file_count), + error="\n".join(warnings) if len(warnings) > 2 else None, + ) diff --git a/src/workspace/tools/stat_tool.py b/src/workspace/tools/stat_tool.py index f6caeae..3997694 100644 --- a/src/workspace/tools/stat_tool.py +++ b/src/workspace/tools/stat_tool.py @@ -1,146 +1,134 @@ -import stat as stat_constants -from datetime import datetime -from pathlib import Path - -from src.models.tools.tool_result import ToolResult -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -class StatTool(BaseTool): - """获取文件或目录的详细信息""" - - def __init__(self, workspace: Workspace): - super().__init__(workspace, "stat", self.stat.__doc__) - self.func = self.stat - self.params = BaseTool.extract_params(self.stat) - self.param_descriptions = { - "path": "文件或目录路径", - } - - @BaseTool.handle_tool_exceptions - def stat(self, path: str = ".") -> ToolResult: - """ - 获取工作区内文件或目录的详细信息,包括大小、行数(仅文件)、修改时间、权限等 - """ - # 验证路径 - target_path: Path = self.workspace.path_validator.validate(path) - - # 获取基本 stat 信息 - path_stat = target_path.stat() - - # 构建输出 - output = [ - f"路径: {target_path.relative_to(self.workspace.root_path)}", - f"绝对路径: {target_path.resolve()}", - f"类型: {'目录' if target_path.is_dir() else '文件' if target_path.is_file() else '其他'}", - ] - - # 大小信息 - size_bytes = path_stat.st_size - if size_bytes < 1024: - size_str = f"{size_bytes} B" - elif size_bytes < 1024 * 1024: - size_str = f"{size_bytes / 1024:.2f} KB" - elif size_bytes < 1024 * 1024 * 1024: - size_str = f"{size_bytes / (1024 * 1024):.2f} MB" - else: - size_str = f"{size_bytes / (1024 * 1024 * 1024):.2f} GB" - - output.append(f"大小: {size_str} ({size_bytes} bytes)") - - # 行数(仅对文件有效) - if target_path.is_file(): - try: - with open(target_path, encoding="utf-8") as f: - line_count = sum(1 for _ in f) - output.append(f"行数: {line_count}") - except UnicodeDecodeError, PermissionError, OSError: - output.append("行数: 无法读取(二进制文件或编码错误)") - - # 时间信息 - def format_timestamp(timestamp: float) -> str: - dt = datetime.fromtimestamp(timestamp) - return dt.strftime("%Y-%m-%d %H:%M:%S") - - output.append(f"创建时间: {format_timestamp(path_stat.st_ctime)}") - output.append(f"修改时间: {format_timestamp(path_stat.st_mtime)}") - output.append(f"访问时间: {format_timestamp(path_stat.st_atime)}") - - # 权限信息 - mode = path_stat.st_mode - - # 文件类型 - if stat_constants.S_ISDIR(mode): - file_type = "d" - elif stat_constants.S_ISREG(mode): - file_type = "-" - elif stat_constants.S_ISLNK(mode): - file_type = "l" - elif stat_constants.S_ISCHR(mode): - file_type = "c" - elif stat_constants.S_ISBLK(mode): - file_type = "b" - elif stat_constants.S_ISFIFO(mode): - file_type = "p" - elif stat_constants.S_ISSOCK(mode): - file_type = "s" - else: - file_type = "?" - - # 所有者权限 - owner = ( - ("r" if mode & stat_constants.S_IRUSR else "-") - + ("w" if mode & stat_constants.S_IWUSR else "-") - + ("x" if mode & stat_constants.S_IXUSR else "-") - ) - # 组权限 - group = ( - ("r" if mode & stat_constants.S_IRGRP else "-") - + ("w" if mode & stat_constants.S_IWGRP else "-") - + ("x" if mode & stat_constants.S_IXGRP else "-") - ) - # 其他用户权限 - other = ( - ("r" if mode & stat_constants.S_IROTH else "-") - + ("w" if mode & stat_constants.S_IWOTH else "-") - + ("x" if mode & stat_constants.S_IXOTH else "-") - ) - - permissions = f"{file_type}{owner}{group}{other}" - output.append(f"权限: {permissions}") - - # 数值权限(八进制) - numeric_perms = oct(mode & 0o777)[2:] - output.append(f"权限(八进制): {numeric_perms}") - - # 所有者信息 - try: - uid = path_stat.st_uid - gid = path_stat.st_gid - output.append(f"所有者UID: {uid}, GID: {gid}") - except AttributeError: - pass # Windows 可能没有 uid/gid - - # 链接数 - output.append(f"硬链接数: {path_stat.st_nlink}") - - # 如果是符号链接,显示目标 - if target_path.is_symlink(): - try: - link_target = target_path.resolve() - output.append(f"符号链接指向: {link_target}") - except OSError: - output.append("符号链接指向: 无法解析") - - # 如果是目录,显示子项数量 - if target_path.is_dir(): - try: - items = list(target_path.iterdir()) - dir_count = sum(1 for item in items if item.is_dir()) - file_count = sum(1 for item in items if item.is_file()) - output.append(f"目录内容: {len(items)} 项({dir_count} 个目录, {file_count} 个文件)") - except PermissionError: - output.append("目录内容: 无法访问") - - return self.make_success_response(kwargs=locals().copy(), data="\n".join(output)) +import stat as stat_constants +from datetime import datetime +from pathlib import Path + +from src.models.tools.tool_result import ToolResult +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + + +class StatTool(BaseTool): + """获取文件或目录的详细信息""" + + def __init__(self, workspace: Workspace): + super().__init__(workspace, "stat", self.stat.__doc__) + self.func = self.stat + self.params = BaseTool.extract_params(self.stat) + self.param_descriptions = { + "path": "文件或目录路径", + } + + @BaseTool.handle_tool_exceptions + def stat(self, path: str = ".") -> ToolResult: + """ + 获取工作区内文件或目录的详细信息,包括大小、行数(仅文件)、修改时间、权限等 + """ + # 验证路径 + target_path: Path = self.workspace.path_validator.validate(path) + + # 获取基本 stat 信息 + path_stat = target_path.stat() + + # 构建输出 + output = [ + f"路径: {target_path.relative_to(self.workspace.root_path)}", + f"绝对路径: {target_path.resolve()}", + f"类型: {'目录' if target_path.is_dir() else '文件' if target_path.is_file() else '其他'}", + ] + + # 大小信息 + size_bytes = path_stat.st_size + if size_bytes < 1024: + size_str = f"{size_bytes} B" + elif size_bytes < 1024 * 1024: + size_str = f"{size_bytes / 1024:.2f} KB" + elif size_bytes < 1024 * 1024 * 1024: + size_str = f"{size_bytes / (1024 * 1024):.2f} MB" + else: + size_str = f"{size_bytes / (1024 * 1024 * 1024):.2f} GB" + + output.append(f"大小: {size_str} ({size_bytes} bytes)") + + # 行数(仅对文件有效) + if target_path.is_file(): + try: + with open(target_path, encoding="utf-8") as f: + line_count = sum(1 for _ in f) + output.append(f"行数: {line_count}") + except UnicodeDecodeError, PermissionError, OSError: + output.append("行数: 无法读取(二进制文件或编码错误)") + + # 时间信息 + def format_timestamp(timestamp: float) -> str: + dt = datetime.fromtimestamp(timestamp) + return dt.strftime("%Y-%m-%d %H:%M:%S") + + output.append(f"创建时间: {format_timestamp(path_stat.st_ctime)}") + output.append(f"修改时间: {format_timestamp(path_stat.st_mtime)}") + output.append(f"访问时间: {format_timestamp(path_stat.st_atime)}") + + # 权限信息 + mode = path_stat.st_mode + + # 文件类型 + if stat_constants.S_ISDIR(mode): + file_type = "d" + elif stat_constants.S_ISREG(mode): + file_type = "-" + elif stat_constants.S_ISLNK(mode): + file_type = "l" + elif stat_constants.S_ISCHR(mode): + file_type = "c" + elif stat_constants.S_ISBLK(mode): + file_type = "b" + elif stat_constants.S_ISFIFO(mode): + file_type = "p" + elif stat_constants.S_ISSOCK(mode): + file_type = "s" + else: + file_type = "?" + + # 所有者权限 + owner = ("r" if mode & stat_constants.S_IRUSR else "-") + ("w" if mode & stat_constants.S_IWUSR else "-") + ("x" if mode & stat_constants.S_IXUSR else "-") + # 组权限 + group = ("r" if mode & stat_constants.S_IRGRP else "-") + ("w" if mode & stat_constants.S_IWGRP else "-") + ("x" if mode & stat_constants.S_IXGRP else "-") + # 其他用户权限 + other = ("r" if mode & stat_constants.S_IROTH else "-") + ("w" if mode & stat_constants.S_IWOTH else "-") + ("x" if mode & stat_constants.S_IXOTH else "-") + + permissions = f"{file_type}{owner}{group}{other}" + output.append(f"权限: {permissions}") + + # 数值权限(八进制) + numeric_perms = oct(mode & 0o777)[2:] + output.append(f"权限(八进制): {numeric_perms}") + + # 所有者信息 + try: + uid = path_stat.st_uid + gid = path_stat.st_gid + output.append(f"所有者UID: {uid}, GID: {gid}") + except AttributeError: + pass # Windows 可能没有 uid/gid + + # 链接数 + output.append(f"硬链接数: {path_stat.st_nlink}") + + # 如果是符号链接,显示目标 + if target_path.is_symlink(): + try: + link_target = target_path.resolve() + output.append(f"符号链接指向: {link_target}") + except OSError: + output.append("符号链接指向: 无法解析") + + # 如果是目录,显示子项数量 + if target_path.is_dir(): + try: + items = list(target_path.iterdir()) + dir_count = sum(1 for item in items if item.is_dir()) + file_count = sum(1 for item in items if item.is_file()) + output.append(f"目录内容: {len(items)} 项({dir_count} 个目录, {file_count} 个文件)") + except PermissionError: + output.append("目录内容: 无法访问") + + return self.make_success_response(kwargs=locals().copy(), data="\n".join(output)) diff --git a/src/workspace/tools/symbol_ref_tool.py b/src/workspace/tools/symbol_ref_tool.py index 4f42291..ecc73c6 100644 --- a/src/workspace/tools/symbol_ref_tool.py +++ b/src/workspace/tools/symbol_ref_tool.py @@ -1,421 +1,414 @@ -"""符号引用查找工具 - 查找函数、类、变量等的定义和引用""" - -import re -from pathlib import Path - -from src.models.tools.tool_result import ToolResult -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -def _get_file_pattern_by_language(language: str) -> str: - """根据语言获取默认的文件匹配模式""" - lang_patterns = { - "python": "*.py", - "javascript": "*.js", - "typescript": "*.ts", - "markdown": "*.md", - "general": "*", - } - return lang_patterns.get(language, "*") - - -def _generate_patterns(symbol_name: str, language: str, include_def: bool, include_ref: bool) -> list[dict[str, str]]: - """根据语言生成搜索模式""" - patterns = [] - - # 转义特殊字符 - escaped_name = re.escape(symbol_name) - - if language == "python": - if include_def: - # 函数定义: def func_name( - patterns.append({"pattern": rf"^\s*def\s+{escaped_name}\s*\(", "type": "definition_function"}) - # 类定义: class ClassName: - patterns.append({"pattern": rf"^\s*class\s+{escaped_name}\s*[:\(]", "type": "definition_class"}) - # 变量/属性定义 - patterns.append({"pattern": rf"^\s*{escaped_name}\s*=", "type": "definition_variable"}) - # 方法定义 - patterns.append({"pattern": rf"^\s*def\s+{escaped_name}\s*\(self", "type": "definition_method"}) - - if include_ref: - # 函数调用: func_name( - patterns.append({"pattern": rf"{escaped_name}\s*\(", "type": "reference_call"}) - # 类实例化: ClassName( - patterns.append({"pattern": rf"{escaped_name}\s*\(", "type": "reference_instantiation"}) - # 属性访问: .symbol_name 或 symbol_name. - patterns.append({"pattern": rf"\.{escaped_name}\b|\b{escaped_name}\.", "type": "reference_attribute"}) - # 导入语句: from module import symbol_name - patterns.append( - { - "pattern": rf"\bimport\s+.*\b{escaped_name}\b|\bfrom\s+.*\s+import\s+.*\b{escaped_name}\b", - "type": "reference_import", - } - ) - - elif language in ["javascript", "typescript"]: - if include_def: - # 函数定义: function func_name() 或 const func_name = - patterns.append( - { - "pattern": ( - rf"(function\s+{escaped_name}\s*\(|const\s+{escaped_name}\s*=" - rf"\s*(function|\()|let\s+{escaped_name}\s*=\s*(function|\())" - ), - "type": "definition_function", - } - ) - # 类定义: class ClassName - patterns.append( - { - "pattern": rf"class\s+{escaped_name}\s*\{{|class\s+{escaped_name}\s+extends", - "type": "definition_class", - } - ) - # 变量定义: const/let/var symbol_name = - patterns.append({"pattern": rf"\b(const|let|var)\s+{escaped_name}\s*=", "type": "definition_variable"}) - # 导出定义: export ... - patterns.append({"pattern": rf"export\s+.*\b{escaped_name}\b", "type": "definition_export"}) - - if include_ref: - # 函数调用/方法调用 - patterns.append({"pattern": rf"{escaped_name}\s*\(", "type": "reference_call"}) - # 属性访问 - patterns.append({"pattern": rf"\.{escaped_name}\b|\b{escaped_name}\.", "type": "reference_property"}) - # 导入语句 - patterns.append( - { - "pattern": ( - rf"import\s+.*\b{escaped_name}\b|import\s+\{{\s*.*\b{escaped_name}" - rf"\b.*\s*\}}|require\s*\(.*\b{escaped_name}\b" - ), - "type": "reference_import", - } - ) - - elif language == "markdown": - # Markdown 中的标题、链接、代码块引用 - if include_def: - patterns.append({"pattern": rf"^#+\s+.*\b{escaped_name}\b", "type": "definition_heading"}) - if include_ref: - patterns.append( - { - "pattern": rf"\[{escaped_name}\]\(|\[{escaped_name}\]\[|\[{escaped_name}\]\:|`{escaped_name}`", - "type": "reference_link_or_code", - } - ) - - else: # general - # 通用模式: 作为独立单词出现 - patterns.append({"pattern": rf"\b{escaped_name}\b", "type": "general_reference"}) - - # 去重: 根据pattern去重 - unique_patterns = [] - seen = set() - for p in patterns: - if p["pattern"] not in seen: - seen.add(p["pattern"]) - unique_patterns.append(p) - - return unique_patterns - - -def _search_all_patterns( - workspace: Workspace, - patterns: list[dict], - search_path: str, - symbol_name: str, - context_lines: int, - limit: int, - file_pattern: str, - ignore: list[str] | None, -) -> list[dict]: - """单次文件遍历搜索所有模式, 返回按文件分组且带上下文的结果. - - 相比旧实现的优势: - - 文件系统只遍历 1 次(旧: N 次, N=模式数量) - - 每个文件只读取 1 次(旧: N 次) - - 直接使用结构化数据, 无需 格式化→正则解析 的反模式 - - 真正实现上下文行读取(旧实现仅将匹配行本身作为上下文) - """ - # 编译所有正则模式 - compiled_patterns: list[tuple[re.Pattern, str]] = [] - for p in patterns: - try: - regex = re.compile(p["pattern"], re.IGNORECASE) - compiled_patterns.append((regex, p["type"])) - except re.error: - continue - - if not compiled_patterns: - return [] - - # 单次遍历搜索: 所有模式在一次文件遍历中完成 - matches = workspace.search_content_multi_pattern( - patterns=compiled_patterns, - folder_path=search_path, - file_pattern=file_pattern, - max_workers=4, - ignore=ignore, - ) - - if not matches: - return [] - - # 应用 limit 截断 - matches = matches[:limit] - - # 按文件分组并构建带上下文的结果 - return _build_results_with_context(matches, context_lines, symbol_name, workspace.root_path) - - -def _build_results_with_context( - matches: list[dict], - context_lines: int, - symbol_name: str, - root_path: Path, -) -> list[dict]: - """从扁平匹配列表构建按文件分组、带上下文的结果""" - context_lines = max(0, context_lines) - - # 按文件分组, 保留首次出现的顺序 - file_matches: dict[str, list[dict]] = {} - file_order: list[str] = [] - for m in matches: - f = m["file"] - if f not in file_matches: - file_matches[f] = [] - file_order.append(f) - file_matches[f].append(m) - - results = [] - for file_rel in file_order: - file_match_list = file_matches[file_rel] - - # 收集需要读取的行号范围(匹配行 ± 上下文行) - needed_lines: set[int] = set() - for m in file_match_list: - for delta in range(-context_lines, context_lines + 1): - needed_lines.add(m["line_num"] + delta) - - # 仅读取需要的行(不加载整个文件到内存) - line_cache: dict[int, str] = {} - file_full_path = root_path / file_rel - try: - if needed_lines: - max_needed = max(needed_lines) - with open(file_full_path, encoding="utf-8") as f: - for line_num, line in enumerate(f, 1): - if line_num in needed_lines: - line_cache[line_num] = line.rstrip("\n\r") - if line_num >= max_needed: - break - except UnicodeDecodeError, PermissionError, OSError: - pass - - # 为每个匹配项构建上下文 - built_matches = [] - for m in file_match_list: - match_line_num = m["line_num"] - context = [] - for delta in range(-context_lines, context_lines + 1): - ctx_line_num = match_line_num + delta - if ctx_line_num in line_cache: - context.append( - { - "line_num": ctx_line_num, - "content": line_cache[ctx_line_num], - "is_match": delta == 0, - } - ) - - built_matches.append( - { - "line_num": match_line_num, - "content": m["content"], - "context": context, - "match_type": m["pattern_type"], - "symbol_name": symbol_name, - } - ) - - results.append( - { - "file": file_rel, - "matches": built_matches, - "type": file_match_list[0]["pattern_type"], - } - ) - - return results - - -def _detect_language(search_path: Path, specified_lang: str) -> str: - """检测项目的主要语言""" - if specified_lang != "auto": - return specified_lang - - # 检查项目文件 - lang_indicators = { - "python": ["*.py", "requirements.txt", "pyproject.toml", "setup.py"], - "javascript": ["*.js", "package.json", "*.jsx"], - "typescript": ["*.ts", "*.tsx", "tsconfig.json"], - "markdown": ["*.md", "*.markdown"], - } - - lang_scores = {lang: 0 for lang in lang_indicators} - - for lang, patterns in lang_indicators.items(): - for pattern in patterns: - matches = list(search_path.rglob(pattern)) - lang_scores[lang] += len(matches) - - # 选择得分最高的语言 - if max(lang_scores.values()) > 0: - return max(lang_scores, key=lang_scores.get) - - return "general" - - -def _get_type_label(match_type: str) -> str: - """获取匹配类型的友好标签""" - labels = { - "definition_function": "函数定义", - "definition_class": "类定义", - "definition_method": "方法定义", - "definition_variable": "变量定义", - "definition_export": "导出定义", - "definition_heading": "标题定义", - "reference_call": "函数/方法调用", - "reference_instantiation": "实例化", - "reference_attribute": "属性访问", - "reference_import": "导入引用", - "reference_property": "属性访问", - "reference_link_or_code": "链接或代码引用", - "general_reference": "通用引用", - } - return labels.get(match_type, f"{match_type}") - - -def _format_results(results: list[dict], symbol_name: str, language: str, limit: int) -> str: - """格式化搜索结果""" - if not results: - return ( - f"未找到符号 '{symbol_name}' 的定义或引用" - f"(语言: {language})\n\n提示: 可以尝试指定不同的语言类型或调整搜索路径" - ) - - total_matches = sum(len(r["matches"]) for r in results) - truncated = total_matches > limit - - output = [ - f"符号引用查找: '{symbol_name}'", - f"语言: {language}", - f"匹配文件数: {len(results)}, 匹配项数: {min(total_matches, limit)}", - ] - - if truncated: - output.append(f"[WARN] 结果已截断,仅显示前 {limit} 个匹配项(实际共 {total_matches} 个)") - - output.append("=" * 80) - - displayed_matches = 0 - for file_result in results: - if displayed_matches >= limit: - break - - output.append(f"\n文件: {file_result['file']}") - output.append(f"匹配类型: {file_result.get('type', 'unknown')}") - output.append("-" * 80) - - for match in file_result["matches"]: - if displayed_matches >= limit: - output.append(f"\n... 以及 {total_matches - limit} 个未显示的匹配项") - break - - # 类型标签 - type_label = _get_type_label(match.get("match_type", "general_reference")) - output.append(f"{type_label} 第 {match['line_num']} 行:") - - # 显示上下文 - for ctx_line in match.get("context", []): - prefix = ">>>" if ctx_line.get("is_match", False) else " " - output.append(f" {prefix} L{ctx_line['line_num']:4d}: {ctx_line['content']}") - - output.append("") # 空行分隔 - displayed_matches += 1 - - return "\n".join(output) - - -class SymbolRefTool(BaseTool): - """查找符号引用工具 - 定位函数、类、变量等的定义和引用位置""" - - def __init__(self, workspace: Workspace): - super().__init__(workspace, "symbol_ref", self.symbol_ref.__doc__, read_permission=True, write_permission=False) - self.func = self.symbol_ref - self.params = BaseTool.extract_params(self.symbol_ref) - self.param_descriptions = { - "symbol_name": "要查找的符号名称(如函数名、类名、变量名)", - "path": "搜索文件或文件夹路径", - "language": "语言类型(auto/python/javascript/typescript/markdown/general)", - "include_definitions": "是否包含定义位置", - "include_references": "是否包含引用位置", - "context_lines": "显示匹配行的上下文行数", - "limit": "最大匹配数量限制", - "ignore": "忽略匹配正则的文件或文件夹列表", - "file_pattern": "文件匹配模式(如 *.py),默认根据语言自动选择", - } - - @BaseTool.handle_tool_exceptions - def symbol_ref( - self, - symbol_name: str, - path: str = ".", - language: str = "auto", - include_definitions: bool = True, - include_references: bool = True, - context_lines: int = 2, - limit: int = 256, - ignore: list[str] | None = None, - file_pattern: str | None = None, - ) -> ToolResult: - """ - 查找符号(函数、类、变量等)的定义和引用位置, 适用于代码探索、重构影响分析、理解代码结构等场景. - """ - # 验证搜索路径 - search_path: Path = self.workspace.path_validator.validate(path) - if not search_path.exists(): - return self.make_failed_response(kwargs=locals().copy(), error=f"路径不存在: {path}") - - # 自动检测语言 - detected_lang = _detect_language(search_path, language) - - # 确定文件匹配模式 - if file_pattern is None: - file_pattern = _get_file_pattern_by_language(detected_lang) - - # 生成搜索模式 - patterns = _generate_patterns(symbol_name, detected_lang, include_definitions, include_references) - - if not patterns: - return self.make_failed_response( - kwargs=locals().copy(), error=f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})" - ) - - # 使用并发搜索执行所有模式 - all_results = _search_all_patterns( - self.workspace, - patterns, - path, - symbol_name, - context_lines, - limit, - file_pattern, - ignore, - ) - - # 格式化输出 - return self.make_success_response( - kwargs=locals().copy(), data=_format_results(all_results, symbol_name, detected_lang, limit) - ) +"""符号引用查找工具 - 查找函数、类、变量等的定义和引用""" + +import re +from pathlib import Path + +from src.models.tools.tool_result import ToolResult +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + + +def _get_file_pattern_by_language(language: str) -> str: + """根据语言获取默认的文件匹配模式""" + lang_patterns = { + "python": "*.py", + "javascript": "*.js", + "typescript": "*.ts", + "markdown": "*.md", + "general": "*", + } + return lang_patterns.get(language, "*") + + +def _generate_patterns(symbol_name: str, language: str, include_def: bool, include_ref: bool) -> list[dict[str, str]]: + """根据语言生成搜索模式""" + patterns = [] + + # 转义特殊字符 + escaped_name = re.escape(symbol_name) + + if language == "python": + if include_def: + # 函数定义: def func_name( + patterns.append({"pattern": rf"^\s*def\s+{escaped_name}\s*\(", "type": "definition_function"}) + # 类定义: class ClassName: + patterns.append({"pattern": rf"^\s*class\s+{escaped_name}\s*[:\(]", "type": "definition_class"}) + # 变量/属性定义 + patterns.append({"pattern": rf"^\s*{escaped_name}\s*=", "type": "definition_variable"}) + # 方法定义 + patterns.append({"pattern": rf"^\s*def\s+{escaped_name}\s*\(self", "type": "definition_method"}) + + if include_ref: + # 函数调用: func_name( + patterns.append({"pattern": rf"{escaped_name}\s*\(", "type": "reference_call"}) + # 类实例化: ClassName( + patterns.append({"pattern": rf"{escaped_name}\s*\(", "type": "reference_instantiation"}) + # 属性访问: .symbol_name 或 symbol_name. + patterns.append({"pattern": rf"\.{escaped_name}\b|\b{escaped_name}\.", "type": "reference_attribute"}) + # 导入语句: from module import symbol_name + patterns.append( + { + "pattern": rf"\bimport\s+.*\b{escaped_name}\b|\bfrom\s+.*\s+import\s+.*\b{escaped_name}\b", + "type": "reference_import", + } + ) + + elif language in ["javascript", "typescript"]: + if include_def: + # 函数定义: function func_name() 或 const func_name = + patterns.append( + { + "pattern": ( + rf"(function\s+{escaped_name}\s*\(|const\s+{escaped_name}\s*=" + rf"\s*(function|\()|let\s+{escaped_name}\s*=\s*(function|\())" + ), + "type": "definition_function", + } + ) + # 类定义: class ClassName + patterns.append( + { + "pattern": rf"class\s+{escaped_name}\s*\{{|class\s+{escaped_name}\s+extends", + "type": "definition_class", + } + ) + # 变量定义: const/let/var symbol_name = + patterns.append({"pattern": rf"\b(const|let|var)\s+{escaped_name}\s*=", "type": "definition_variable"}) + # 导出定义: export ... + patterns.append({"pattern": rf"export\s+.*\b{escaped_name}\b", "type": "definition_export"}) + + if include_ref: + # 函数调用/方法调用 + patterns.append({"pattern": rf"{escaped_name}\s*\(", "type": "reference_call"}) + # 属性访问 + patterns.append({"pattern": rf"\.{escaped_name}\b|\b{escaped_name}\.", "type": "reference_property"}) + # 导入语句 + patterns.append( + { + "pattern": ( + rf"import\s+.*\b{escaped_name}\b|import\s+\{{\s*.*\b{escaped_name}" + rf"\b.*\s*\}}|require\s*\(.*\b{escaped_name}\b" + ), + "type": "reference_import", + } + ) + + elif language == "markdown": + # Markdown 中的标题、链接、代码块引用 + if include_def: + patterns.append({"pattern": rf"^#+\s+.*\b{escaped_name}\b", "type": "definition_heading"}) + if include_ref: + patterns.append( + { + "pattern": rf"\[{escaped_name}\]\(|\[{escaped_name}\]\[|\[{escaped_name}\]\:|`{escaped_name}`", + "type": "reference_link_or_code", + } + ) + + else: # general + # 通用模式: 作为独立单词出现 + patterns.append({"pattern": rf"\b{escaped_name}\b", "type": "general_reference"}) + + # 去重: 根据pattern去重 + unique_patterns = [] + seen = set() + for p in patterns: + if p["pattern"] not in seen: + seen.add(p["pattern"]) + unique_patterns.append(p) + + return unique_patterns + + +def _search_all_patterns( + workspace: Workspace, + patterns: list[dict], + search_path: str, + symbol_name: str, + context_lines: int, + limit: int, + file_pattern: str, + ignore: list[str] | None, +) -> list[dict]: + """单次文件遍历搜索所有模式, 返回按文件分组且带上下文的结果. + + 相比旧实现的优势: + - 文件系统只遍历 1 次(旧: N 次, N=模式数量) + - 每个文件只读取 1 次(旧: N 次) + - 直接使用结构化数据, 无需 格式化→正则解析 的反模式 + - 真正实现上下文行读取(旧实现仅将匹配行本身作为上下文) + """ + # 编译所有正则模式 + compiled_patterns: list[tuple[re.Pattern, str]] = [] + for p in patterns: + try: + regex = re.compile(p["pattern"], re.IGNORECASE) + compiled_patterns.append((regex, p["type"])) + except re.error: + continue + + if not compiled_patterns: + return [] + + # 单次遍历搜索: 所有模式在一次文件遍历中完成 + matches = workspace.search_content_multi_pattern( + patterns=compiled_patterns, + folder_path=search_path, + file_pattern=file_pattern, + max_workers=4, + ignore=ignore, + ) + + if not matches: + return [] + + # 应用 limit 截断 + matches = matches[:limit] + + # 按文件分组并构建带上下文的结果 + return _build_results_with_context(matches, context_lines, symbol_name, workspace.root_path) + + +def _build_results_with_context( + matches: list[dict], + context_lines: int, + symbol_name: str, + root_path: Path, +) -> list[dict]: + """从扁平匹配列表构建按文件分组、带上下文的结果""" + context_lines = max(0, context_lines) + + # 按文件分组, 保留首次出现的顺序 + file_matches: dict[str, list[dict]] = {} + file_order: list[str] = [] + for m in matches: + f = m["file"] + if f not in file_matches: + file_matches[f] = [] + file_order.append(f) + file_matches[f].append(m) + + results = [] + for file_rel in file_order: + file_match_list = file_matches[file_rel] + + # 收集需要读取的行号范围(匹配行 ± 上下文行) + needed_lines: set[int] = set() + for m in file_match_list: + for delta in range(-context_lines, context_lines + 1): + needed_lines.add(m["line_num"] + delta) + + # 仅读取需要的行(不加载整个文件到内存) + line_cache: dict[int, str] = {} + file_full_path = root_path / file_rel + try: + if needed_lines: + max_needed = max(needed_lines) + with open(file_full_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + if line_num in needed_lines: + line_cache[line_num] = line.rstrip("\n\r") + if line_num >= max_needed: + break + except UnicodeDecodeError, PermissionError, OSError: + pass + + # 为每个匹配项构建上下文 + built_matches = [] + for m in file_match_list: + match_line_num = m["line_num"] + context = [] + for delta in range(-context_lines, context_lines + 1): + ctx_line_num = match_line_num + delta + if ctx_line_num in line_cache: + context.append( + { + "line_num": ctx_line_num, + "content": line_cache[ctx_line_num], + "is_match": delta == 0, + } + ) + + built_matches.append( + { + "line_num": match_line_num, + "content": m["content"], + "context": context, + "match_type": m["pattern_type"], + "symbol_name": symbol_name, + } + ) + + results.append( + { + "file": file_rel, + "matches": built_matches, + "type": file_match_list[0]["pattern_type"], + } + ) + + return results + + +def _detect_language(search_path: Path, specified_lang: str) -> str: + """检测项目的主要语言""" + if specified_lang != "auto": + return specified_lang + + # 检查项目文件 + lang_indicators = { + "python": ["*.py", "requirements.txt", "pyproject.toml", "setup.py"], + "javascript": ["*.js", "package.json", "*.jsx"], + "typescript": ["*.ts", "*.tsx", "tsconfig.json"], + "markdown": ["*.md", "*.markdown"], + } + + lang_scores = {lang: 0 for lang in lang_indicators} + + for lang, patterns in lang_indicators.items(): + for pattern in patterns: + matches = list(search_path.rglob(pattern)) + lang_scores[lang] += len(matches) + + # 选择得分最高的语言 + if max(lang_scores.values()) > 0: + return max(lang_scores, key=lang_scores.get) + + return "general" + + +def _get_type_label(match_type: str) -> str: + """获取匹配类型的友好标签""" + labels = { + "definition_function": "函数定义", + "definition_class": "类定义", + "definition_method": "方法定义", + "definition_variable": "变量定义", + "definition_export": "导出定义", + "definition_heading": "标题定义", + "reference_call": "函数/方法调用", + "reference_instantiation": "实例化", + "reference_attribute": "属性访问", + "reference_import": "导入引用", + "reference_property": "属性访问", + "reference_link_or_code": "链接或代码引用", + "general_reference": "通用引用", + } + return labels.get(match_type, f"{match_type}") + + +def _format_results(results: list[dict], symbol_name: str, language: str, limit: int) -> str: + """格式化搜索结果""" + if not results: + return f"未找到符号 '{symbol_name}' 的定义或引用(语言: {language})\n\n提示: 可以尝试指定不同的语言类型或调整搜索路径" + + total_matches = sum(len(r["matches"]) for r in results) + truncated = total_matches > limit + + output = [ + f"符号引用查找: '{symbol_name}'", + f"语言: {language}", + f"匹配文件数: {len(results)}, 匹配项数: {min(total_matches, limit)}", + ] + + if truncated: + output.append(f"[WARN] 结果已截断,仅显示前 {limit} 个匹配项(实际共 {total_matches} 个)") + + output.append("=" * 80) + + displayed_matches = 0 + for file_result in results: + if displayed_matches >= limit: + break + + output.append(f"\n文件: {file_result['file']}") + output.append(f"匹配类型: {file_result.get('type', 'unknown')}") + output.append("-" * 80) + + for match in file_result["matches"]: + if displayed_matches >= limit: + output.append(f"\n... 以及 {total_matches - limit} 个未显示的匹配项") + break + + # 类型标签 + type_label = _get_type_label(match.get("match_type", "general_reference")) + output.append(f"{type_label} 第 {match['line_num']} 行:") + + # 显示上下文 + for ctx_line in match.get("context", []): + prefix = ">>>" if ctx_line.get("is_match", False) else " " + output.append(f" {prefix} L{ctx_line['line_num']:4d}: {ctx_line['content']}") + + output.append("") # 空行分隔 + displayed_matches += 1 + + return "\n".join(output) + + +class SymbolRefTool(BaseTool): + """查找符号引用工具 - 定位函数、类、变量等的定义和引用位置""" + + def __init__(self, workspace: Workspace): + super().__init__(workspace, "symbol_ref", self.symbol_ref.__doc__, read_permission=True, write_permission=False) + self.func = self.symbol_ref + self.params = BaseTool.extract_params(self.symbol_ref) + self.param_descriptions = { + "symbol_name": "要查找的符号名称(如函数名、类名、变量名)", + "path": "搜索文件或文件夹路径", + "language": "语言类型(auto/python/javascript/typescript/markdown/general)", + "include_definitions": "是否包含定义位置", + "include_references": "是否包含引用位置", + "context_lines": "显示匹配行的上下文行数", + "limit": "最大匹配数量限制", + "ignore": "忽略匹配正则的文件或文件夹列表", + "file_pattern": "文件匹配模式(如 *.py),默认根据语言自动选择", + } + + @BaseTool.handle_tool_exceptions + def symbol_ref( + self, + symbol_name: str, + path: str = ".", + language: str = "auto", + include_definitions: bool = True, + include_references: bool = True, + context_lines: int = 2, + limit: int = 256, + ignore: list[str] | None = None, + file_pattern: str | None = None, + ) -> ToolResult: + """ + 查找符号(函数、类、变量等)的定义和引用位置, 适用于代码探索、重构影响分析、理解代码结构等场景. + """ + # 验证搜索路径 + search_path: Path = self.workspace.path_validator.validate(path) + if not search_path.exists(): + return self.make_failed_response(kwargs=locals().copy(), error=f"路径不存在: {path}") + + # 自动检测语言 + detected_lang = _detect_language(search_path, language) + + # 确定文件匹配模式 + if file_pattern is None: + file_pattern = _get_file_pattern_by_language(detected_lang) + + # 生成搜索模式 + patterns = _generate_patterns(symbol_name, detected_lang, include_definitions, include_references) + + if not patterns: + return self.make_failed_response(kwargs=locals().copy(), error=f"无法为符号 '{symbol_name}' 生成有效的搜索模式(语言: {detected_lang})") + + # 使用并发搜索执行所有模式 + all_results = _search_all_patterns( + self.workspace, + patterns, + path, + symbol_name, + context_lines, + limit, + file_pattern, + ignore, + ) + + # 格式化输出 + return self.make_success_response(kwargs=locals().copy(), data=_format_results(all_results, symbol_name, detected_lang, limit)) diff --git a/src/workspace/tools/write_tool.py b/src/workspace/tools/write_tool.py index fe1a09c..4ec44f0 100644 --- a/src/workspace/tools/write_tool.py +++ b/src/workspace/tools/write_tool.py @@ -1,76 +1,66 @@ -from pathlib import Path - -from src.core.file_tracker import FileTracker -from src.models.tools.tool_result import ToolResult -from src.utils.binary_detector import is_binary_file -from src.workspace.tools.base_tool import BaseTool -from src.workspace.workspace import Workspace - - -class WriteTool(BaseTool): - def __init__(self, workspace: Workspace): - super().__init__(workspace, "write", self.write.__doc__, write_permission=True) - self.func = self.write - self.params = BaseTool.extract_params(self.write) - self.param_descriptions = { - "path": "文件路径", - "content": "写入内容", - } - - @BaseTool.handle_tool_exceptions - def write(self, path: str, content: str = "") -> ToolResult: - """ - 写入文件内容, 如文件不存在则创建(含父目录) - """ - source_path = Path(path) - path: Path = self.workspace.path_validator.resolve_path(source_path) - - if path.exists() and path.is_dir(): - return self.make_failed_response( - kwargs=locals().copy(), error=str(ValueError(f"路径 {path} 是一个目录,无法写入")) - ) - - if is_binary_file(path): - return self.make_failed_response( - kwargs=locals().copy(), error=str(ValueError(f"禁止写入二进制文件: {path}")) - ) - - mtime_error = self._validate_mtime(path) - if mtime_error: - return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}") - - old_content = "" - old_meta = None - if path.exists() and path.is_file(): - old_meta = FileTracker.get_file_meta(path) - try: - old_content = path.read_text(encoding="utf-8") - except Exception: - old_content = "" - - rel_path = str(path.relative_to(self.workspace.root_path)) - old_hash = old_meta.get("checksum") if old_meta else None - new_hash = FileTracker.compute_checksum_from_string(content) - diff_content = self._generate_diff(old_content, content, rel_path) - - session_id = self.workspace.session_id - snapshot_id = self.workspace.db.record_file_snapshot( - rel_path, - old_hash, - new_hash, - diff_content, - audit_status="PENDING_AUDIT", - session_id=session_id, - pending_content=content, - ) - - return self.make_success_response( - kwargs=locals().copy(), - data=( - f"修改已推送到审核系统\n" - f"[Write Preview]\n" - f"File: {rel_path}\n" - f"Snapshot ID: {snapshot_id}\n" - f"Diff:\n{diff_content}" - ), - ) +from pathlib import Path + +from src.core.file_tracker import FileTracker +from src.models.tools.tool_result import ToolResult +from src.utils.binary_detector import is_binary_file +from src.workspace.tools.base_tool import BaseTool +from src.workspace.workspace import Workspace + + +class WriteTool(BaseTool): + def __init__(self, workspace: Workspace): + super().__init__(workspace, "write", self.write.__doc__, write_permission=True) + self.func = self.write + self.params = BaseTool.extract_params(self.write) + self.param_descriptions = { + "path": "文件路径", + "content": "写入内容", + } + + @BaseTool.handle_tool_exceptions + def write(self, path: str, content: str = "") -> ToolResult: + """ + 写入文件内容, 如文件不存在则创建(含父目录) + """ + source_path = Path(path) + path: Path = self.workspace.path_validator.resolve_path(source_path) + + if path.exists() and path.is_dir(): + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError(f"路径 {path} 是一个目录,无法写入"))) + + if is_binary_file(path): + return self.make_failed_response(kwargs=locals().copy(), error=str(ValueError(f"禁止写入二进制文件: {path}"))) + + mtime_error = self._validate_mtime(path) + if mtime_error: + return self.make_failed_response(locals().copy(), error=f"无法编辑被修改过的文件:\n{mtime_error}") + + old_content = "" + old_meta = None + if path.exists() and path.is_file(): + old_meta = FileTracker.get_file_meta(path) + try: + old_content = path.read_text(encoding="utf-8") + except Exception: + old_content = "" + + rel_path = str(path.relative_to(self.workspace.root_path)) + old_hash = old_meta.get("checksum") if old_meta else None + new_hash = FileTracker.compute_checksum_from_string(content) + diff_content = self._generate_diff(old_content, content, rel_path) + + session_id = self.workspace.session_id + snapshot_id = self.workspace.db.record_file_snapshot( + rel_path, + old_hash, + new_hash, + diff_content, + audit_status="PENDING_AUDIT", + session_id=session_id, + pending_content=content, + ) + + return self.make_success_response( + kwargs=locals().copy(), + data=(f"修改已推送到审核系统\n[Write Preview]\nFile: {rel_path}\nSnapshot ID: {snapshot_id}\nDiff:\n{diff_content}"), + ) diff --git a/src/workspace/workspace.py b/src/workspace/workspace.py index 295b2d6..08ee6b0 100644 --- a/src/workspace/workspace.py +++ b/src/workspace/workspace.py @@ -1,297 +1,289 @@ -import re -import sys -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import date -from pathlib import Path - -from src.models.tool_error_response import ToolErrorResponse -from src.workspace.exclusion_manager import ExclusionManager -from src.workspace.path_validator import PathNotFoundError, PathValidator, WorkspaceBoundaryError - - -def _highlight_matches(line: str, regex: re.Pattern) -> str: - """ - 高亮显示行中的匹配部分(内部方法) - - Args: - line: 原始行内容 - regex: 编译后的正则表达式 - - Returns: - 带有 **匹配** 标记的行内容 - """ - - def replacer(match): - return f"**{match.group(0)}**" - - try: - highlighted = regex.sub(replacer, line) - return highlighted - except Exception: - return line - - -class Workspace: - _instance: Workspace = None - - def __new__(cls, path: str): - if not cls._instance: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self, path: str): - if self._initialized: - return - self.root_path = Path(path).resolve() - self.path_validator: PathValidator = PathValidator(self.root_path) - self.exclusion_manager: ExclusionManager = ExclusionManager(self.root_path) - self.is_git_repo: bool = (self.root_path / ".git").is_dir() - self.platform: str = sys.platform - self.date: str = date.today().strftime("%y-%m-%d") - self._db = None - self._current_session_id: int | None = None - self._initialized = True - - @property - def db(self): - if self._db is None: - from src.core.database_manager import DatabaseManager - - self._db = DatabaseManager(str(self.root_path)) - return self._db - - @property - def session_id(self) -> int | None: - """当前会话 ID 的公开 getter —— 工具类通过此接口访问, 避免直接访问私有属性.""" - return self._current_session_id - - @session_id.setter - def session_id(self, value: int | None) -> None: - self._current_session_id = value - - def search_content( - self, - pattern: str, - folder_path: str = ".", - exclude_dirs: list[str] | None = None, - file_pattern: str = "*", - max_workers: int = 4, - case_sensitive: bool = False, - ) -> str: - """在工作区内递归搜索文件内容(正则匹配),支持排除目录、并发读取和匹配高亮.返回格式化搜索结果或错误""" - try: - path = self.path_validator.validate(folder_path) - - # 初始化排除目录集合: 合并默认排除 + 用户传入排除 - if exclude_dirs is not None: - exclude_set = set(exclude_dirs) | self.exclusion_manager.excluded_dir_names - else: - exclude_set = self.exclusion_manager.excluded_dir_names - - # 编译正则表达式 - flags = 0 if case_sensitive else re.IGNORECASE - try: - regex = re.compile(pattern, flags) - except re.error as e: - return f"错误:正则表达式无效 - {e}" - - # 收集所有要搜索的文件(支持单文件或目录) - files_to_search = [] - if path.is_file(): - files_to_search = [path] - else: - for file_path in path.rglob(file_pattern): - if file_path.is_file(): - # 检查是否在排除目录中 - should_exclude = False - for parent in file_path.parents: - if parent.name in exclude_set: - should_exclude = True - break - if not should_exclude: - files_to_search.append(file_path) - - if not files_to_search: - return f"在 {folder_path} 中没有找到匹配 {file_pattern} 的文件" - - # 异步搜索文件 - results = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = { - executor.submit(self._search_in_file, file_path, regex): file_path for file_path in files_to_search - } - - for future in as_completed(futures): - file_path = futures[future] - try: - file_results = future.result() - if file_results: - results.extend(file_results) - except Exception as e: - results.append(f"[错误] {file_path.relative_to(self.root_path)}: {e!s}") - - # 格式化输出 - if not results: - return f"未找到匹配 '{pattern}' 的内容" - - output_lines = [ - f"搜索模式: {pattern}", - f"搜索路径: {folder_path}", - f"排除目录: {', '.join(sorted(exclude_set)) if exclude_set else '无'}", - f"文件模式: {file_pattern}", - f"匹配文件数: {len(set(r[0] for r in results if not r[0].startswith('[错误]')))}", - f"匹配行数: {len(results)}", - "-" * 80, - "", - ] - - current_file = None - for file_rel, line_num, line_content in results: - if file_rel != current_file: - current_file = file_rel - output_lines.append(f"\n[文件] {current_file}") - output_lines.append("-" * 40) - - # 高亮显示匹配的部分 - highlighted = _highlight_matches(line_content, regex) - output_lines.append(f" {line_num:4d} | {highlighted}") - - return "\n".join(output_lines) - - except PathNotFoundError as err1: - return ToolErrorResponse(self.search_content.__name__, err1).to_str() - except WorkspaceBoundaryError as err2: - return ToolErrorResponse(self.search_content.__name__, err2).to_str() - except PermissionError as err3: - return ToolErrorResponse(self.search_content.__name__, err3).to_str() - except Exception as err: - return ToolErrorResponse(self.search_content.__name__, err).to_str() - - def _search_in_file(self, file_path: Path, regex: re.Pattern) -> list[tuple]: - """ - 在单个文件中搜索匹配内容(内部方法) - - Returns: - List of (relative_path, line_number, line_content) - """ - results = [] - relative_path = str(file_path.relative_to(self.root_path)) - - try: - with open(file_path, encoding="utf-8") as f: - for line_num, line in enumerate(f, 1): - if regex.search(line): - results.append((relative_path, line_num, line.rstrip("\n\r"))) - except UnicodeDecodeError, PermissionError: - # 跳过无法读取的二进制文件或无权限的文件 - pass - except Exception: - pass - - return results - - def search_content_multi_pattern( - self, - patterns: list[tuple[re.Pattern, str]], - folder_path: str = ".", - file_pattern: str = "*", - max_workers: int = 4, - ignore: list[str] | None = None, - ) -> list[dict]: - """单次文件遍历匹配多个正则模式, 直接返回结构化数据. - - 与 search_content 的区别: - - 接受多个已编译的正则 + 类型标签, 一次遍历全部匹配 - - 返回结构化 list[dict] 而非格式化字符串, 消除下游正则解析反模式 - - 不做高亮处理(高亮是展示层关注点, 不应混入数据层) - - Args: - patterns: [(compiled_regex, type_label), ...] 已编译正则及其类型标签 - folder_path: 搜索起始路径 - file_pattern: 文件通配符(如 "*.py") - max_workers: 并发读取文件的线程数 - ignore: 忽略路径正则列表 - - Returns: - [{"file": str, "line_num": int, "content": str, "pattern_type": str}, ...] - 按文件路径 → 行号排序, 同一行多个模式匹配则每个各一条记录 - """ - try: - path = self.path_validator.validate(folder_path) - - # 预编译 ignore 正则: 合并默认排除 + 用户传入的 ignore - ignore_res: list[re.Pattern] = self.exclusion_manager.merge_ignore_regexes(ignore) - - # 收集文件(一次遍历) - files_to_search: list[Path] = [] - if path.is_file(): - files_to_search = [path] - else: - for file_path in path.rglob(file_pattern): - if file_path.is_file(): - rel = str(file_path.relative_to(self.root_path)) - if any(ir.search(rel) for ir in ignore_res): - continue - files_to_search.append(file_path) - - if not files_to_search: - return [] - - # 并发搜索所有文件, 每个文件内一次读取、一次测试所有模式 - all_matches: list[dict] = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = { - executor.submit(self._search_multi_in_file, file_path, patterns): file_path - for file_path in files_to_search - } - for future in as_completed(futures): - try: - file_results = future.result() - if file_results: - all_matches.extend(file_results) - except Exception: - pass - - # 按文件路径 → 行号排序 - all_matches.sort(key=lambda m: (m["file"], m["line_num"])) - return all_matches - - except PathNotFoundError, WorkspaceBoundaryError, PermissionError: - return [] - except Exception: - return [] - - def _search_multi_in_file(self, file_path: Path, patterns: list[tuple[re.Pattern, str]]) -> list[dict]: - """在单个文件中一次读取、逐行测试所有模式. - - Args: - file_path: 文件绝对路径 - patterns: [(compiled_regex, type_label), ...] - - Returns: - 匹配项列表, 同一行多个模式匹配时每个各返回一条 - """ - results: list[dict] = [] - relative_path = str(file_path.relative_to(self.root_path)) - - try: - with open(file_path, encoding="utf-8") as f: - for line_num, line in enumerate(f, 1): - stripped = line.rstrip("\n\r") - for regex, pattern_type in patterns: - if regex.search(stripped): - results.append( - { - "file": relative_path, - "line_num": line_num, - "content": stripped, - "pattern_type": pattern_type, - } - ) - except UnicodeDecodeError, PermissionError: - pass - except Exception: - pass - - return results +import re +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import date +from pathlib import Path + +from src.models.tool_error_response import ToolErrorResponse +from src.workspace.exclusion_manager import ExclusionManager +from src.workspace.path_validator import PathNotFoundError, PathValidator, WorkspaceBoundaryError + + +def _highlight_matches(line: str, regex: re.Pattern) -> str: + """ + 高亮显示行中的匹配部分(内部方法) + + Args: + line: 原始行内容 + regex: 编译后的正则表达式 + + Returns: + 带有 **匹配** 标记的行内容 + """ + + def replacer(match): + return f"**{match.group(0)}**" + + try: + highlighted = regex.sub(replacer, line) + return highlighted + except Exception: + return line + + +class Workspace: + _instance: Workspace = None + + def __new__(cls, path: str): + if not cls._instance: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, path: str): + if self._initialized: + return + self.root_path = Path(path).resolve() + self.path_validator: PathValidator = PathValidator(self.root_path) + self.exclusion_manager: ExclusionManager = ExclusionManager(self.root_path) + self.is_git_repo: bool = (self.root_path / ".git").is_dir() + self.platform: str = sys.platform + self.date: str = date.today().strftime("%y-%m-%d") + self._db = None + self._current_session_id: int | None = None + self._initialized = True + + @property + def db(self): + if self._db is None: + from src.core.database_manager import DatabaseManager + + self._db = DatabaseManager(str(self.root_path)) + return self._db + + @property + def session_id(self) -> int | None: + """当前会话 ID 的公开 getter —— 工具类通过此接口访问, 避免直接访问私有属性.""" + return self._current_session_id + + @session_id.setter + def session_id(self, value: int | None) -> None: + self._current_session_id = value + + def search_content( + self, + pattern: str, + folder_path: str = ".", + exclude_dirs: list[str] | None = None, + file_pattern: str = "*", + max_workers: int = 4, + case_sensitive: bool = False, + ) -> str: + """在工作区内递归搜索文件内容(正则匹配),支持排除目录、并发读取和匹配高亮.返回格式化搜索结果或错误""" + try: + path = self.path_validator.validate(folder_path) + + # 初始化排除目录集合: 合并默认排除 + 用户传入排除 + exclude_set = set(exclude_dirs) | self.exclusion_manager.excluded_dir_names if exclude_dirs is not None else self.exclusion_manager.excluded_dir_names + + # 编译正则表达式 + flags = 0 if case_sensitive else re.IGNORECASE + try: + regex = re.compile(pattern, flags) + except re.error as e: + return f"错误:正则表达式无效 - {e}" + + # 收集所有要搜索的文件(支持单文件或目录) + files_to_search = [] + if path.is_file(): + files_to_search = [path] + else: + for file_path in path.rglob(file_pattern): + if file_path.is_file(): + # 检查是否在排除目录中 + should_exclude = False + for parent in file_path.parents: + if parent.name in exclude_set: + should_exclude = True + break + if not should_exclude: + files_to_search.append(file_path) + + if not files_to_search: + return f"在 {folder_path} 中没有找到匹配 {file_pattern} 的文件" + + # 异步搜索文件 + results = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(self._search_in_file, file_path, regex): file_path for file_path in files_to_search} + + for future in as_completed(futures): + file_path = futures[future] + try: + file_results = future.result() + if file_results: + results.extend(file_results) + except Exception as e: + results.append(f"[错误] {file_path.relative_to(self.root_path)}: {e!s}") + + # 格式化输出 + if not results: + return f"未找到匹配 '{pattern}' 的内容" + + output_lines = [ + f"搜索模式: {pattern}", + f"搜索路径: {folder_path}", + f"排除目录: {', '.join(sorted(exclude_set)) if exclude_set else '无'}", + f"文件模式: {file_pattern}", + f"匹配文件数: {len(set(r[0] for r in results if not r[0].startswith('[错误]')))}", + f"匹配行数: {len(results)}", + "-" * 80, + "", + ] + + current_file = None + for file_rel, line_num, line_content in results: + if file_rel != current_file: + current_file = file_rel + output_lines.append(f"\n[文件] {current_file}") + output_lines.append("-" * 40) + + # 高亮显示匹配的部分 + highlighted = _highlight_matches(line_content, regex) + output_lines.append(f" {line_num:4d} | {highlighted}") + + return "\n".join(output_lines) + + except PathNotFoundError as err1: + return ToolErrorResponse(self.search_content.__name__, err1).to_str() + except WorkspaceBoundaryError as err2: + return ToolErrorResponse(self.search_content.__name__, err2).to_str() + except PermissionError as err3: + return ToolErrorResponse(self.search_content.__name__, err3).to_str() + except Exception as err: + return ToolErrorResponse(self.search_content.__name__, err).to_str() + + def _search_in_file(self, file_path: Path, regex: re.Pattern) -> list[tuple]: + """ + 在单个文件中搜索匹配内容(内部方法) + + Returns: + List of (relative_path, line_number, line_content) + """ + results = [] + relative_path = str(file_path.relative_to(self.root_path)) + + try: + with open(file_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + if regex.search(line): + results.append((relative_path, line_num, line.rstrip("\n\r"))) + except UnicodeDecodeError, PermissionError: + # 跳过无法读取的二进制文件或无权限的文件 + pass + except Exception: + pass + + return results + + def search_content_multi_pattern( + self, + patterns: list[tuple[re.Pattern, str]], + folder_path: str = ".", + file_pattern: str = "*", + max_workers: int = 4, + ignore: list[str] | None = None, + ) -> list[dict]: + """单次文件遍历匹配多个正则模式, 直接返回结构化数据. + + 与 search_content 的区别: + - 接受多个已编译的正则 + 类型标签, 一次遍历全部匹配 + - 返回结构化 list[dict] 而非格式化字符串, 消除下游正则解析反模式 + - 不做高亮处理(高亮是展示层关注点, 不应混入数据层) + + Args: + patterns: [(compiled_regex, type_label), ...] 已编译正则及其类型标签 + folder_path: 搜索起始路径 + file_pattern: 文件通配符(如 "*.py") + max_workers: 并发读取文件的线程数 + ignore: 忽略路径正则列表 + + Returns: + [{"file": str, "line_num": int, "content": str, "pattern_type": str}, ...] + 按文件路径 → 行号排序, 同一行多个模式匹配则每个各一条记录 + """ + try: + path = self.path_validator.validate(folder_path) + + # 预编译 ignore 正则: 合并默认排除 + 用户传入的 ignore + ignore_res: list[re.Pattern] = self.exclusion_manager.merge_ignore_regexes(ignore) + + # 收集文件(一次遍历) + files_to_search: list[Path] = [] + if path.is_file(): + files_to_search = [path] + else: + for file_path in path.rglob(file_pattern): + if file_path.is_file(): + rel = str(file_path.relative_to(self.root_path)) + if any(ir.search(rel) for ir in ignore_res): + continue + files_to_search.append(file_path) + + if not files_to_search: + return [] + + # 并发搜索所有文件, 每个文件内一次读取、一次测试所有模式 + all_matches: list[dict] = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(self._search_multi_in_file, file_path, patterns): file_path for file_path in files_to_search} + for future in as_completed(futures): + try: + file_results = future.result() + if file_results: + all_matches.extend(file_results) + except Exception: + pass + + # 按文件路径 → 行号排序 + all_matches.sort(key=lambda m: (m["file"], m["line_num"])) + return all_matches + + except PathNotFoundError, WorkspaceBoundaryError, PermissionError: + return [] + except Exception: + return [] + + def _search_multi_in_file(self, file_path: Path, patterns: list[tuple[re.Pattern, str]]) -> list[dict]: + """在单个文件中一次读取、逐行测试所有模式. + + Args: + file_path: 文件绝对路径 + patterns: [(compiled_regex, type_label), ...] + + Returns: + 匹配项列表, 同一行多个模式匹配时每个各返回一条 + """ + results: list[dict] = [] + relative_path = str(file_path.relative_to(self.root_path)) + + try: + with open(file_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + stripped = line.rstrip("\n\r") + for regex, pattern_type in patterns: + if regex.search(stripped): + results.append( + { + "file": relative_path, + "line_num": line_num, + "content": stripped, + "pattern_type": pattern_type, + } + ) + except UnicodeDecodeError, PermissionError: + pass + except Exception: + pass + + return results diff --git a/tests/core/test_audit_committer.py b/tests/core/test_audit_committer.py index d8b9838..57d0350 100644 --- a/tests/core/test_audit_committer.py +++ b/tests/core/test_audit_committer.py @@ -1,173 +1,169 @@ -import time -from pathlib import Path - -import pytest - -from src.core.audit_committer import AuditCommitter -from src.core.database_manager import DatabaseManager -from src.workspace.workspace import Workspace - - -@pytest.fixture(autouse=True) -def reset_singletons(): - Workspace._instance = None - DatabaseManager.reset_instances() - yield - Workspace._instance = None - DatabaseManager.reset_instances() - - -@pytest.fixture -def workspace(tmp_path: Path) -> Workspace: - ws = Workspace(str(tmp_path)) - ws._current_session_id = ws.db.create_session(name="test_session") - return ws - - -@pytest.fixture -def committer(workspace: Workspace) -> AuditCommitter: - return AuditCommitter(workspace) - - -def _create_pending_snapshot(workspace: Workspace, file_path: str, pending_content: str) -> int: - """Helper to create a PENDING_AUDIT snapshot.""" - return workspace.db.record_file_snapshot( - file_path, - "old_hash", - "new_hash", - "diff_content", - audit_status="PENDING_AUDIT", - pending_content=pending_content, - session_id=workspace._current_session_id, - ) - - -class TestCommitNewFile: - def test_approve_creates_new_file(self, committer: AuditCommitter, workspace: Workspace): - snapshot_id = _create_pending_snapshot(workspace, "new_file.txt", "hello world") - result = committer.commit(snapshot_id, approved=True) - - assert "已批准" in result - target = workspace.root_path / "new_file.txt" - assert target.is_file() - assert target.read_text(encoding="utf-8") == "hello world" - - def test_approve_updates_audit_status(self, committer: AuditCommitter, workspace: Workspace): - snapshot_id = _create_pending_snapshot(workspace, "new_file.txt", "content") - committer.commit(snapshot_id, approved=True) - - snap = workspace.db.get_snapshot_by_id(snapshot_id) - assert snap is not None - assert snap[7] == "APPROVED" - - def test_approve_creates_file_in_subdir(self, committer: AuditCommitter, workspace: Workspace): - snapshot_id = _create_pending_snapshot(workspace, "sub/dir/new_file.txt", "nested") - committer.commit(snapshot_id, approved=True) - - target = workspace.root_path / "sub" / "dir" / "new_file.txt" - assert target.is_file() - assert target.read_text(encoding="utf-8") == "nested" - - -class TestCommitExistingFile: - def test_approve_existing_file_with_read(self, committer: AuditCommitter, workspace: Workspace): - target = workspace.root_path / "test.txt" - target.write_text("original", encoding="utf-8") - - workspace.db.record_file_read( - workspace._current_session_id, "test.txt", target.stat().st_mtime, target.stat().st_size, "hash" - ) - - snapshot_id = _create_pending_snapshot(workspace, "test.txt", "updated content") - result = committer.commit(snapshot_id, approved=True) - - assert "已批准" in result - assert target.read_text(encoding="utf-8") == "updated content" - - def test_approve_existing_file_no_read_record(self, committer: AuditCommitter, workspace: Workspace): - target = workspace.root_path / "test.txt" - target.write_text("original", encoding="utf-8") - - snapshot_id = _create_pending_snapshot(workspace, "test.txt", "updated content") - result = committer.commit(snapshot_id, approved=True) - - assert "已批准" in result - assert target.read_text(encoding="utf-8") == "updated content" - - def test_approve_mtime_mismatch_fails(self, committer: AuditCommitter, workspace: Workspace): - target = workspace.root_path / "test.txt" - target.write_text("original", encoding="utf-8") - - workspace.db.record_file_read( - workspace._current_session_id, "test.txt", target.stat().st_mtime, target.stat().st_size, "hash" - ) - - # Modify file externally - time.sleep(0.1) - target.write_text("modified externally", encoding="utf-8") - - snapshot_id = _create_pending_snapshot(workspace, "test.txt", "should fail") - result = committer.commit(snapshot_id, approved=True) - - assert "ERROR" in result or "已被外部修改" in result - assert target.read_text(encoding="utf-8") == "modified externally" - - -class TestCommitReject: - def test_reject_does_not_write(self, committer: AuditCommitter, workspace: Workspace): - snapshot_id = _create_pending_snapshot(workspace, "should_not_exist.txt", "content") - result = committer.commit(snapshot_id, approved=False) - - assert "已拒绝" in result - assert not (workspace.root_path / "should_not_exist.txt").is_file() - - def test_reject_updates_status(self, committer: AuditCommitter, workspace: Workspace): - snapshot_id = _create_pending_snapshot(workspace, "test.txt", "content") - committer.commit(snapshot_id, approved=False) - - snap = workspace.db.get_snapshot_by_id(snapshot_id) - assert snap is not None - assert snap[7] == "REJECTED" - - -class TestCommitEdgeCases: - def test_commit_nonexistent_snapshot(self, committer: AuditCommitter): - result = committer.commit(9999, approved=True) - assert "不存在" in result - - def test_commit_already_approved(self, committer: AuditCommitter, workspace: Workspace): - snapshot_id = _create_pending_snapshot(workspace, "test.txt", "content") - committer.commit(snapshot_id, approved=True) - - result = committer.commit(snapshot_id, approved=True) - assert "已处理" in result - - def test_commit_already_rejected(self, committer: AuditCommitter, workspace: Workspace): - snapshot_id = _create_pending_snapshot(workspace, "test.txt", "content") - committer.commit(snapshot_id, approved=False) - - result = committer.commit(snapshot_id, approved=False) - assert "已处理" in result - - -class TestCommitBinaryProtection: - """测试审计提交器的二进制文件安全网.""" - - def test_commit_binary_ext_blocked(self, committer: AuditCommitter, workspace: Workspace): - """批准写入 .png 文件应被安全网拦截.""" - snapshot_id = _create_pending_snapshot(workspace, "image.png", "fake png") - result = committer.commit(snapshot_id, approved=True) - - assert "二进制文件" in result - assert not (workspace.root_path / "image.png").is_file() - # 状态应被标记为 REJECTED - snap = workspace.db.get_snapshot_by_id(snapshot_id) - assert snap[7] == "REJECTED" - - def test_commit_text_file_not_blocked(self, committer: AuditCommitter, workspace: Workspace): - """批准写入文本文件应正常通过.""" - snapshot_id = _create_pending_snapshot(workspace, "notes.txt", "hello") - result = committer.commit(snapshot_id, approved=True) - - assert "已批准" in result - assert (workspace.root_path / "notes.txt").read_text(encoding="utf-8") == "hello" +import time +from pathlib import Path + +import pytest + +from src.core.audit_committer import AuditCommitter +from src.core.database_manager import DatabaseManager +from src.workspace.workspace import Workspace + + +@pytest.fixture(autouse=True) +def reset_singletons(): + Workspace._instance = None + DatabaseManager.reset_instances() + yield + Workspace._instance = None + DatabaseManager.reset_instances() + + +@pytest.fixture +def workspace(tmp_path: Path) -> Workspace: + ws = Workspace(str(tmp_path)) + ws._current_session_id = ws.db.create_session(name="test_session") + return ws + + +@pytest.fixture +def committer(workspace: Workspace) -> AuditCommitter: + return AuditCommitter(workspace) + + +def _create_pending_snapshot(workspace: Workspace, file_path: str, pending_content: str) -> int: + """Helper to create a PENDING_AUDIT snapshot.""" + return workspace.db.record_file_snapshot( + file_path, + "old_hash", + "new_hash", + "diff_content", + audit_status="PENDING_AUDIT", + pending_content=pending_content, + session_id=workspace._current_session_id, + ) + + +class TestCommitNewFile: + def test_approve_creates_new_file(self, committer: AuditCommitter, workspace: Workspace): + snapshot_id = _create_pending_snapshot(workspace, "new_file.txt", "hello world") + result = committer.commit(snapshot_id, approved=True) + + assert "已批准" in result + target = workspace.root_path / "new_file.txt" + assert target.is_file() + assert target.read_text(encoding="utf-8") == "hello world" + + def test_approve_updates_audit_status(self, committer: AuditCommitter, workspace: Workspace): + snapshot_id = _create_pending_snapshot(workspace, "new_file.txt", "content") + committer.commit(snapshot_id, approved=True) + + snap = workspace.db.get_snapshot_by_id(snapshot_id) + assert snap is not None + assert snap[7] == "APPROVED" + + def test_approve_creates_file_in_subdir(self, committer: AuditCommitter, workspace: Workspace): + snapshot_id = _create_pending_snapshot(workspace, "sub/dir/new_file.txt", "nested") + committer.commit(snapshot_id, approved=True) + + target = workspace.root_path / "sub" / "dir" / "new_file.txt" + assert target.is_file() + assert target.read_text(encoding="utf-8") == "nested" + + +class TestCommitExistingFile: + def test_approve_existing_file_with_read(self, committer: AuditCommitter, workspace: Workspace): + target = workspace.root_path / "test.txt" + target.write_text("original", encoding="utf-8") + + workspace.db.record_file_read(workspace._current_session_id, "test.txt", target.stat().st_mtime, target.stat().st_size, "hash") + + snapshot_id = _create_pending_snapshot(workspace, "test.txt", "updated content") + result = committer.commit(snapshot_id, approved=True) + + assert "已批准" in result + assert target.read_text(encoding="utf-8") == "updated content" + + def test_approve_existing_file_no_read_record(self, committer: AuditCommitter, workspace: Workspace): + target = workspace.root_path / "test.txt" + target.write_text("original", encoding="utf-8") + + snapshot_id = _create_pending_snapshot(workspace, "test.txt", "updated content") + result = committer.commit(snapshot_id, approved=True) + + assert "已批准" in result + assert target.read_text(encoding="utf-8") == "updated content" + + def test_approve_mtime_mismatch_fails(self, committer: AuditCommitter, workspace: Workspace): + target = workspace.root_path / "test.txt" + target.write_text("original", encoding="utf-8") + + workspace.db.record_file_read(workspace._current_session_id, "test.txt", target.stat().st_mtime, target.stat().st_size, "hash") + + # Modify file externally + time.sleep(0.1) + target.write_text("modified externally", encoding="utf-8") + + snapshot_id = _create_pending_snapshot(workspace, "test.txt", "should fail") + result = committer.commit(snapshot_id, approved=True) + + assert "ERROR" in result or "已被外部修改" in result + assert target.read_text(encoding="utf-8") == "modified externally" + + +class TestCommitReject: + def test_reject_does_not_write(self, committer: AuditCommitter, workspace: Workspace): + snapshot_id = _create_pending_snapshot(workspace, "should_not_exist.txt", "content") + result = committer.commit(snapshot_id, approved=False) + + assert "已拒绝" in result + assert not (workspace.root_path / "should_not_exist.txt").is_file() + + def test_reject_updates_status(self, committer: AuditCommitter, workspace: Workspace): + snapshot_id = _create_pending_snapshot(workspace, "test.txt", "content") + committer.commit(snapshot_id, approved=False) + + snap = workspace.db.get_snapshot_by_id(snapshot_id) + assert snap is not None + assert snap[7] == "REJECTED" + + +class TestCommitEdgeCases: + def test_commit_nonexistent_snapshot(self, committer: AuditCommitter): + result = committer.commit(9999, approved=True) + assert "不存在" in result + + def test_commit_already_approved(self, committer: AuditCommitter, workspace: Workspace): + snapshot_id = _create_pending_snapshot(workspace, "test.txt", "content") + committer.commit(snapshot_id, approved=True) + + result = committer.commit(snapshot_id, approved=True) + assert "已处理" in result + + def test_commit_already_rejected(self, committer: AuditCommitter, workspace: Workspace): + snapshot_id = _create_pending_snapshot(workspace, "test.txt", "content") + committer.commit(snapshot_id, approved=False) + + result = committer.commit(snapshot_id, approved=False) + assert "已处理" in result + + +class TestCommitBinaryProtection: + """测试审计提交器的二进制文件安全网.""" + + def test_commit_binary_ext_blocked(self, committer: AuditCommitter, workspace: Workspace): + """批准写入 .png 文件应被安全网拦截.""" + snapshot_id = _create_pending_snapshot(workspace, "image.png", "fake png") + result = committer.commit(snapshot_id, approved=True) + + assert "二进制文件" in result + assert not (workspace.root_path / "image.png").is_file() + # 状态应被标记为 REJECTED + snap = workspace.db.get_snapshot_by_id(snapshot_id) + assert snap[7] == "REJECTED" + + def test_commit_text_file_not_blocked(self, committer: AuditCommitter, workspace: Workspace): + """批准写入文本文件应正常通过.""" + snapshot_id = _create_pending_snapshot(workspace, "notes.txt", "hello") + result = committer.commit(snapshot_id, approved=True) + + assert "已批准" in result + assert (workspace.root_path / "notes.txt").read_text(encoding="utf-8") == "hello" diff --git a/tests/core/test_database_manager.py b/tests/core/test_database_manager.py index 9099879..ef31173 100644 --- a/tests/core/test_database_manager.py +++ b/tests/core/test_database_manager.py @@ -1,508 +1,504 @@ -import time -from pathlib import Path - -import pytest - -from src.core.database_manager import DatabaseManager - - -@pytest.fixture(autouse=True) -def isolate_database_manager(): - DatabaseManager.reset_instances() - yield - DatabaseManager.reset_instances() - - -@pytest.fixture -def workspace_root(tmp_path: Path) -> str: - return str(tmp_path / "workspace") - - -@pytest.fixture -def db(workspace_root: str) -> DatabaseManager: - root = Path(workspace_root) - root.mkdir(parents=True, exist_ok=True) - return DatabaseManager(workspace_root) - - -class TestSingleton: - def test_same_path_returns_same_instance(self, workspace_root: str): - root = Path(workspace_root) - root.mkdir(parents=True, exist_ok=True) - db1 = DatabaseManager(workspace_root) - db2 = DatabaseManager(workspace_root) - assert db1 is db2 - - def test_different_paths_return_different_instances(self, tmp_path: Path): - root1 = tmp_path / "ws1" - root2 = tmp_path / "ws2" - root1.mkdir() - root2.mkdir() - db1 = DatabaseManager(str(root1)) - db2 = DatabaseManager(str(root2)) - assert db1 is not db2 - - -class TestDirectoryAndDbCreation: - def test_directory_created_on_init(self, db: DatabaseManager, workspace_root: str): - assert (Path(workspace_root) / ".ManualAid").is_dir() - - def test_db_file_exists_after_init(self, db: DatabaseManager): - assert db.db_path.is_file() - - def test_wal_mode_enabled(self, db: DatabaseManager): - row = db.fetchone("PRAGMA journal_mode") - assert row is not None - assert row[0] == "wal" - - -class TestTableCreation: - def test_sessions_table_exists(self, db: DatabaseManager): - rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table' AND name='sessions'") - assert len(rows) == 1 - - def test_tool_calls_table_exists(self, db: DatabaseManager): - rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table' AND name='tool_calls'") - assert len(rows) == 1 - - def test_file_read_records_table_exists(self, db: DatabaseManager): - rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table' AND name='file_read_records'") - assert len(rows) == 1 - - def test_file_snapshots_table_exists(self, db: DatabaseManager): - rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table' AND name='file_snapshots'") - assert len(rows) == 1 - - -class TestSessionLifecycle: - def test_create_session(self, db: DatabaseManager): - session_id = db.create_session(name="test_session") - assert isinstance(session_id, int) - assert session_id > 0 - - def test_close_session_updates_duration(self, db: DatabaseManager): - session_id = db.create_session(name="test_session") - time.sleep(0.05) - db.close_session(session_id) - - row = db.fetchone("SELECT duration FROM sessions WHERE id = ?", (session_id,)) - assert row is not None - assert row[0] > 0 - - def test_close_nonexistent_session(self, db: DatabaseManager): - db.close_session(9999) - - -class TestToolCallLogging: - def test_log_tool_call(self, db: DatabaseManager): - session_id = db.create_session() - call_id = db.log_tool_call(session_id, "read", "abc123", duration_ms=50.0, status="success") - assert isinstance(call_id, int) - assert call_id > 0 - - def test_log_tool_call_stored_correctly(self, db: DatabaseManager): - session_id = db.create_session() - db.log_tool_call(session_id, "write", "def456", duration_ms=120.0, status="error", audit_status="none") - - row = db.fetchone( - "SELECT func_name, kwargs, duration_ms, status FROM tool_calls WHERE session_id = ?", - (session_id,), - ) - assert row is not None - assert row[0] == "write" - assert row[1] == "def456" - assert row[2] == 120.0 - assert row[3] == "error" - - def test_update_tool_call_status(self, db: DatabaseManager): - session_id = db.create_session() - call_id = db.log_tool_call(session_id, "edit", "xyz") - - db.update_tool_call_status(call_id, "error", "PENDING_AUDIT") - - row = db.fetchone("SELECT status, audit_status FROM tool_calls WHERE id = ?", (call_id,)) - assert row is not None - assert row[0] == "error" - assert row[1] == "PENDING_AUDIT" - - -class TestFileReadRecords: - def test_record_file_read(self, db: DatabaseManager): - session_id = db.create_session() - db.record_file_read(session_id, "src/main.py", 1234567890.5, 1024, "abc123hash") - - row = db.get_file_read_record(session_id, "src/main.py") - assert row is not None - assert row[3] == 1234567890.5 - assert row[4] == 1024 - assert row[5] == "abc123hash" - - def test_record_file_read_upsert(self, db: DatabaseManager): - session_id = db.create_session() - db.record_file_read(session_id, "src/main.py", 1000.0, 100, "hash1") - db.record_file_read(session_id, "src/main.py", 2000.0, 200, "hash2") - - row = db.get_file_read_record(session_id, "src/main.py") - assert row is not None - assert row[3] == 2000.0 - assert row[4] == 200 - assert row[5] == "hash2" - assert row[7] == 2 - - def test_get_nonexistent_read_record(self, db: DatabaseManager): - session_id = db.create_session() - row = db.get_file_read_record(session_id, "nonexistent.py") - assert row is None - - -class TestFileSnapshots: - def test_record_file_snapshot(self, db: DatabaseManager): - snapshot_id = db.record_file_snapshot( - "src/main.py", "old_hash", "new_hash", "--- a/src/main.py\n+++ b/src/main.py\n" - ) - assert isinstance(snapshot_id, int) - assert snapshot_id > 0 - - def test_snapshot_stored_correctly(self, db: DatabaseManager): - session_id = db.create_session() - db.record_file_snapshot( - "src/main.py", None, "new_hash", "diff content", audit_status="PENDING_AUDIT", session_id=session_id - ) - - rows = db.fetchall("SELECT file_path, old_hash, new_hash, diff_content, audit_status FROM file_snapshots") - assert len(rows) == 1 - assert rows[0][0] == "src/main.py" - assert rows[0][1] is None - assert rows[0][2] == "new_hash" - assert rows[0][3] == "diff content" - assert rows[0][4] == "PENDING_AUDIT" - - def test_update_snapshot_audit(self, db: DatabaseManager): - snapshot_id = db.record_file_snapshot("src/main.py", "old", "new", "diff") - - db.update_snapshot_audit(snapshot_id, "APPROVED") - - row = db.fetchone("SELECT audit_status FROM file_snapshots WHERE id = ?", (snapshot_id,)) - assert row is not None - assert row[0] == "APPROVED" - - def test_get_pending_audits(self, db: DatabaseManager): - db.record_file_snapshot("a.py", None, "h1", "diff1") - db.record_file_snapshot("b.py", "h0", "h2", "diff2") - db.record_file_snapshot("c.py", "h0", "h3", "diff3", audit_status="APPROVED") - - pending = db.get_pending_audits() - assert len(pending) == 2 - - def test_snapshot_with_session_id(self, db: DatabaseManager): - session_id = db.create_session() - db.record_file_snapshot("src/main.py", "old", "new", "diff", session_id=session_id) - - row = db.fetchone("SELECT session_id FROM file_snapshots WHERE file_path = 'src/main.py'") - assert row is not None - assert row[0] == session_id - - -class TestPendingContentMigration: - def test_pending_content_column_exists(self, db: DatabaseManager): - cols = db.fetchall("PRAGMA table_info(file_snapshots)") - col_names = [c[1] for c in cols] - assert "pending_content" in col_names - - def test_pending_content_default_empty(self, db: DatabaseManager): - snapshot_id = db.record_file_snapshot("src/main.py", "old", "new", "diff") - row = db.fetchone("SELECT pending_content FROM file_snapshots WHERE id = ?", (snapshot_id,)) - assert row is not None - assert row[0] == "" - - def test_pending_content_stored(self, db: DatabaseManager): - session_id = db.create_session() - db.execute( - "UPDATE file_snapshots SET pending_content = ? WHERE id = ?", - ("new file content", 1), - ) - # Use record_file_snapshot with explicit pending_content via direct SQL - # since the method currently doesn't have pending_content param - snapshot_id = db.record_file_snapshot( - "src/main.py", - "old_hash", - "new_hash", - "diff_content", - audit_status="PENDING_AUDIT", - session_id=session_id, - ) - db.execute( - "UPDATE file_snapshots SET pending_content = ? WHERE id = ?", - ("pending content here", snapshot_id), - ) - row = db.fetchone("SELECT pending_content FROM file_snapshots WHERE id = ?", (snapshot_id,)) - assert row is not None - assert row[0] == "pending content here" - - -class TestGetSnapshotById: - def test_get_existing_snapshot(self, db: DatabaseManager): - session_id = db.create_session() - db.record_file_snapshot( - "src/main.py", - "old_hash", - "new_hash", - "diff_content", - audit_status="PENDING_AUDIT", - session_id=session_id, - ) - # Set pending_content directly - db.execute( - "UPDATE file_snapshots SET pending_content = ? WHERE id = ?", - ("pending content", 1), - ) - - row = db.get_snapshot_by_id(1) - assert row is not None - assert row[1] == "src/main.py" - assert row[8] == "pending content" - - def test_get_nonexistent_snapshot(self, db: DatabaseManager): - row = db.get_snapshot_by_id(9999) - assert row is None - - -class TestGetSnapshotsByAuditStatus: - def test_filter_by_status(self, db: DatabaseManager): - db.record_file_snapshot("a.py", None, "h1", "diff1") - db.record_file_snapshot("b.py", "h0", "h2", "diff2", audit_status="APPROVED") - db.record_file_snapshot("c.py", "h0", "h3", "diff3", audit_status="REJECTED") - - pending = db.get_snapshots_by_audit_status("PENDING_AUDIT") - approved = db.get_snapshots_by_audit_status("APPROVED") - rejected = db.get_snapshots_by_audit_status("REJECTED") - - assert len(pending) == 1 - assert pending[0][1] == "a.py" - assert len(approved) == 1 - assert approved[0][1] == "b.py" - assert len(rejected) == 1 - assert rejected[0][1] == "c.py" - - def test_no_matching_status(self, db: DatabaseManager): - rows = db.get_snapshots_by_audit_status("NONEXISTENT_STATUS") - assert len(rows) == 0 - - -class TestThreadSafety: - def test_concurrent_writes(self, db: DatabaseManager): - import threading - - errors = [] - - def write_session(name): - try: - db.create_session(name=name) - except Exception as e: - errors.append(e) - finally: - db.close() # 确保每个线程关闭自己的连接,避免 ResourceWarning - - threads = [threading.Thread(target=write_session, args=(f"thread_{i}",)) for i in range(5)] - for t in threads: - t.start() - for t in threads: - t.join() - - assert len(errors) == 0 - rows = db.fetchall("SELECT COUNT(*) FROM sessions") - assert rows[0][0] == 5 - - -class TestClose: - def test_close_and_reopen(self, workspace_root: str): - root = Path(workspace_root) - root.mkdir(parents=True, exist_ok=True) - db1 = DatabaseManager(workspace_root) - db1.create_session("first") - db1.close() - - db2 = DatabaseManager(workspace_root) - rows = db2.fetchall("SELECT COUNT(*) FROM sessions") - assert rows[0][0] == 1 - - -class TestSessionSummary: - def test_summary_returns_empty_for_nonexistent_session(self, db: DatabaseManager): - summary = db.get_session_summary(9999) - assert summary == {} - - def test_summary_with_no_tool_calls(self, db: DatabaseManager): - session_id = db.create_session(name="empty_session") - summary = db.get_session_summary(session_id) - assert summary["name"] == "empty_session" - assert summary["total_calls"] == 0 - assert summary["success_count"] == 0 - assert summary["fail_count"] == 0 - assert summary["success_rate"] == 0.0 - - def test_summary_with_tool_calls(self, db: DatabaseManager): - session_id = db.create_session(name="test") - db.log_tool_call(session_id, "read", "hash1", status="success") - db.log_tool_call(session_id, "write", "hash2", status="success") - db.log_tool_call(session_id, "edit", "hash3", status="error") - - summary = db.get_session_summary(session_id) - assert summary["total_calls"] == 3 - assert summary["success_count"] == 2 - assert summary["fail_count"] == 1 - assert summary["success_rate"] == pytest.approx(66.6667, rel=0.01) - - def test_summary_duration(self, db: DatabaseManager): - session_id = db.create_session(name="dur_test") - time.sleep(0.05) - db.close_session(session_id) - summary = db.get_session_summary(session_id) - assert summary["duration"] > 0 - - -class TestAllSessions: - def test_get_all_sessions_empty(self, db: DatabaseManager): - sessions = db.get_all_sessions() - assert len(sessions) == 0 - - def test_get_all_sessions_ordering(self, db: DatabaseManager): - id1 = db.create_session(name="first") - id2 = db.create_session(name="second") - id3 = db.create_session(name="third") - - sessions = db.get_all_sessions() - assert len(sessions) == 3 - # Most recent first - assert sessions[0][0] == id3 - assert sessions[1][0] == id2 - assert sessions[2][0] == id1 - - def test_get_all_sessions_returns_columns(self, db: DatabaseManager): - db.create_session(name="check_columns") - sessions = db.get_all_sessions() - row = sessions[0] - assert len(row) == 4 # id, name, created_at, duration - - -class TestRenameSession: - def test_rename_session(self, db: DatabaseManager): - session_id = db.create_session(name="original") - db.rename_session(session_id, "renamed") - row = db.fetchone("SELECT name FROM sessions WHERE id = ?", (session_id,)) - assert row is not None - assert row[0] == "renamed" - - def test_rename_nonexistent_session(self, db: DatabaseManager): - db.rename_session(9999, "ghost") # Should not raise - - -class TestDeleteSession: - def test_delete_session_removes_session(self, db: DatabaseManager): - session_id = db.create_session(name="to_delete") - db.delete_session(session_id) - row = db.fetchone("SELECT id FROM sessions WHERE id = ?", (session_id,)) - assert row is None - - def test_delete_session_removes_tool_calls(self, db: DatabaseManager): - session_id = db.create_session() - db.log_tool_call(session_id, "read", "hash1") - db.log_tool_call(session_id, "write", "hash2") - - # Verify tool calls exist - rows = db.fetchall("SELECT COUNT(*) FROM tool_calls WHERE session_id = ?", (session_id,)) - assert rows[0][0] == 2 - - db.delete_session(session_id) - - rows = db.fetchall("SELECT COUNT(*) FROM tool_calls WHERE session_id = ?", (session_id,)) - assert rows[0][0] == 0 - - def test_delete_session_removes_snapshots(self, db: DatabaseManager): - session_id = db.create_session() - db.record_file_snapshot("a.py", "old", "new", "diff", session_id=session_id) - - rows = db.fetchall("SELECT COUNT(*) FROM file_snapshots WHERE session_id = ?", (session_id,)) - assert rows[0][0] == 1 - - db.delete_session(session_id) - - rows = db.fetchall("SELECT COUNT(*) FROM file_snapshots WHERE session_id = ?", (session_id,)) - assert rows[0][0] == 0 - - def test_delete_nonexistent_session(self, db: DatabaseManager): - db.delete_session(9999) # Should not raise - - def test_delete_session_keeps_other_sessions(self, db: DatabaseManager): - sid1 = db.create_session() - sid2 = db.create_session() - db.log_tool_call(sid1, "read", "h1") - db.log_tool_call(sid2, "write", "h2") - - db.delete_session(sid1) - - remaining = db.fetchall("SELECT id FROM sessions") - assert len(remaining) == 1 - assert remaining[0][0] == sid2 - - -class TestToolUsageRanking: - def test_ranking_empty(self, db: DatabaseManager): - ranking = db.get_tool_usage_ranking() - assert len(ranking) == 0 - - def test_ranking_global(self, db: DatabaseManager): - sid1 = db.create_session() - sid2 = db.create_session() - - # Session 1: 3 reads, 2 writes - for _ in range(3): - db.log_tool_call(sid1, "read", "h", duration_ms=10.0) - for _ in range(2): - db.log_tool_call(sid1, "write", "h", duration_ms=20.0) - - # Session 2: 1 read, 1 glob - db.log_tool_call(sid2, "read", "h", duration_ms=5.0) - db.log_tool_call(sid2, "glob", "h", duration_ms=15.0) - - ranking = db.get_tool_usage_ranking(limit=10) - assert len(ranking) == 3 - # read=4, write=2, glob=1 - assert ranking[0][0] == "read" - assert ranking[0][1] == 4 - assert ranking[1][0] == "write" - assert ranking[1][1] == 2 - assert ranking[2][0] == "glob" - assert ranking[2][1] == 1 - - def test_ranking_per_session(self, db: DatabaseManager): - sid = db.create_session() - db.log_tool_call(sid, "read", "h", duration_ms=5.0) - db.log_tool_call(sid, "read", "h", duration_ms=10.0) - db.log_tool_call(sid, "glob", "h", duration_ms=15.0) - - ranking = db.get_tool_usage_ranking(session_id=sid) - assert len(ranking) == 2 - assert ranking[0][0] == "read" - assert ranking[0][1] == 2 - assert ranking[1][0] == "glob" - assert ranking[1][1] == 1 - - def test_ranking_avg_duration(self, db: DatabaseManager): - sid = db.create_session() - db.log_tool_call(sid, "read", "h", duration_ms=10.0) - db.log_tool_call(sid, "read", "h", duration_ms=30.0) - - ranking = db.get_tool_usage_ranking(session_id=sid) - assert len(ranking) == 1 - assert ranking[0][0] == "read" - assert ranking[0][1] == 2 - assert ranking[0][2] == pytest.approx(20.0) - assert ranking[0][3] == pytest.approx(40.0) - - def test_ranking_limit(self, db: DatabaseManager): - sid = db.create_session() - for tool in ["a", "b", "c", "d", "e"]: - db.log_tool_call(sid, tool, "h") - ranking = db.get_tool_usage_ranking(session_id=sid, limit=3) - assert len(ranking) == 3 +import time +from pathlib import Path + +import pytest + +from src.core.database_manager import DatabaseManager + + +@pytest.fixture(autouse=True) +def isolate_database_manager(): + DatabaseManager.reset_instances() + yield + DatabaseManager.reset_instances() + + +@pytest.fixture +def workspace_root(tmp_path: Path) -> str: + return str(tmp_path / "workspace") + + +@pytest.fixture +def db(workspace_root: str) -> DatabaseManager: + root = Path(workspace_root) + root.mkdir(parents=True, exist_ok=True) + return DatabaseManager(workspace_root) + + +class TestSingleton: + def test_same_path_returns_same_instance(self, workspace_root: str): + root = Path(workspace_root) + root.mkdir(parents=True, exist_ok=True) + db1 = DatabaseManager(workspace_root) + db2 = DatabaseManager(workspace_root) + assert db1 is db2 + + def test_different_paths_return_different_instances(self, tmp_path: Path): + root1 = tmp_path / "ws1" + root2 = tmp_path / "ws2" + root1.mkdir() + root2.mkdir() + db1 = DatabaseManager(str(root1)) + db2 = DatabaseManager(str(root2)) + assert db1 is not db2 + + +class TestDirectoryAndDbCreation: + def test_directory_created_on_init(self, db: DatabaseManager, workspace_root: str): + assert (Path(workspace_root) / ".ManualAid").is_dir() + + def test_db_file_exists_after_init(self, db: DatabaseManager): + assert db.db_path.is_file() + + def test_wal_mode_enabled(self, db: DatabaseManager): + row = db.fetchone("PRAGMA journal_mode") + assert row is not None + assert row[0] == "wal" + + +class TestTableCreation: + def test_sessions_table_exists(self, db: DatabaseManager): + rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table' AND name='sessions'") + assert len(rows) == 1 + + def test_tool_calls_table_exists(self, db: DatabaseManager): + rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table' AND name='tool_calls'") + assert len(rows) == 1 + + def test_file_read_records_table_exists(self, db: DatabaseManager): + rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table' AND name='file_read_records'") + assert len(rows) == 1 + + def test_file_snapshots_table_exists(self, db: DatabaseManager): + rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table' AND name='file_snapshots'") + assert len(rows) == 1 + + +class TestSessionLifecycle: + def test_create_session(self, db: DatabaseManager): + session_id = db.create_session(name="test_session") + assert isinstance(session_id, int) + assert session_id > 0 + + def test_close_session_updates_duration(self, db: DatabaseManager): + session_id = db.create_session(name="test_session") + time.sleep(0.05) + db.close_session(session_id) + + row = db.fetchone("SELECT duration FROM sessions WHERE id = ?", (session_id,)) + assert row is not None + assert row[0] > 0 + + def test_close_nonexistent_session(self, db: DatabaseManager): + db.close_session(9999) + + +class TestToolCallLogging: + def test_log_tool_call(self, db: DatabaseManager): + session_id = db.create_session() + call_id = db.log_tool_call(session_id, "read", "abc123", duration_ms=50.0, status="success") + assert isinstance(call_id, int) + assert call_id > 0 + + def test_log_tool_call_stored_correctly(self, db: DatabaseManager): + session_id = db.create_session() + db.log_tool_call(session_id, "write", "def456", duration_ms=120.0, status="error", audit_status="none") + + row = db.fetchone( + "SELECT func_name, kwargs, duration_ms, status FROM tool_calls WHERE session_id = ?", + (session_id,), + ) + assert row is not None + assert row[0] == "write" + assert row[1] == "def456" + assert row[2] == 120.0 + assert row[3] == "error" + + def test_update_tool_call_status(self, db: DatabaseManager): + session_id = db.create_session() + call_id = db.log_tool_call(session_id, "edit", "xyz") + + db.update_tool_call_status(call_id, "error", "PENDING_AUDIT") + + row = db.fetchone("SELECT status, audit_status FROM tool_calls WHERE id = ?", (call_id,)) + assert row is not None + assert row[0] == "error" + assert row[1] == "PENDING_AUDIT" + + +class TestFileReadRecords: + def test_record_file_read(self, db: DatabaseManager): + session_id = db.create_session() + db.record_file_read(session_id, "src/main.py", 1234567890.5, 1024, "abc123hash") + + row = db.get_file_read_record(session_id, "src/main.py") + assert row is not None + assert row[3] == 1234567890.5 + assert row[4] == 1024 + assert row[5] == "abc123hash" + + def test_record_file_read_upsert(self, db: DatabaseManager): + session_id = db.create_session() + db.record_file_read(session_id, "src/main.py", 1000.0, 100, "hash1") + db.record_file_read(session_id, "src/main.py", 2000.0, 200, "hash2") + + row = db.get_file_read_record(session_id, "src/main.py") + assert row is not None + assert row[3] == 2000.0 + assert row[4] == 200 + assert row[5] == "hash2" + assert row[7] == 2 + + def test_get_nonexistent_read_record(self, db: DatabaseManager): + session_id = db.create_session() + row = db.get_file_read_record(session_id, "nonexistent.py") + assert row is None + + +class TestFileSnapshots: + def test_record_file_snapshot(self, db: DatabaseManager): + snapshot_id = db.record_file_snapshot("src/main.py", "old_hash", "new_hash", "--- a/src/main.py\n+++ b/src/main.py\n") + assert isinstance(snapshot_id, int) + assert snapshot_id > 0 + + def test_snapshot_stored_correctly(self, db: DatabaseManager): + session_id = db.create_session() + db.record_file_snapshot("src/main.py", None, "new_hash", "diff content", audit_status="PENDING_AUDIT", session_id=session_id) + + rows = db.fetchall("SELECT file_path, old_hash, new_hash, diff_content, audit_status FROM file_snapshots") + assert len(rows) == 1 + assert rows[0][0] == "src/main.py" + assert rows[0][1] is None + assert rows[0][2] == "new_hash" + assert rows[0][3] == "diff content" + assert rows[0][4] == "PENDING_AUDIT" + + def test_update_snapshot_audit(self, db: DatabaseManager): + snapshot_id = db.record_file_snapshot("src/main.py", "old", "new", "diff") + + db.update_snapshot_audit(snapshot_id, "APPROVED") + + row = db.fetchone("SELECT audit_status FROM file_snapshots WHERE id = ?", (snapshot_id,)) + assert row is not None + assert row[0] == "APPROVED" + + def test_get_pending_audits(self, db: DatabaseManager): + db.record_file_snapshot("a.py", None, "h1", "diff1") + db.record_file_snapshot("b.py", "h0", "h2", "diff2") + db.record_file_snapshot("c.py", "h0", "h3", "diff3", audit_status="APPROVED") + + pending = db.get_pending_audits() + assert len(pending) == 2 + + def test_snapshot_with_session_id(self, db: DatabaseManager): + session_id = db.create_session() + db.record_file_snapshot("src/main.py", "old", "new", "diff", session_id=session_id) + + row = db.fetchone("SELECT session_id FROM file_snapshots WHERE file_path = 'src/main.py'") + assert row is not None + assert row[0] == session_id + + +class TestPendingContentMigration: + def test_pending_content_column_exists(self, db: DatabaseManager): + cols = db.fetchall("PRAGMA table_info(file_snapshots)") + col_names = [c[1] for c in cols] + assert "pending_content" in col_names + + def test_pending_content_default_empty(self, db: DatabaseManager): + snapshot_id = db.record_file_snapshot("src/main.py", "old", "new", "diff") + row = db.fetchone("SELECT pending_content FROM file_snapshots WHERE id = ?", (snapshot_id,)) + assert row is not None + assert row[0] == "" + + def test_pending_content_stored(self, db: DatabaseManager): + session_id = db.create_session() + db.execute( + "UPDATE file_snapshots SET pending_content = ? WHERE id = ?", + ("new file content", 1), + ) + # Use record_file_snapshot with explicit pending_content via direct SQL + # since the method currently doesn't have pending_content param + snapshot_id = db.record_file_snapshot( + "src/main.py", + "old_hash", + "new_hash", + "diff_content", + audit_status="PENDING_AUDIT", + session_id=session_id, + ) + db.execute( + "UPDATE file_snapshots SET pending_content = ? WHERE id = ?", + ("pending content here", snapshot_id), + ) + row = db.fetchone("SELECT pending_content FROM file_snapshots WHERE id = ?", (snapshot_id,)) + assert row is not None + assert row[0] == "pending content here" + + +class TestGetSnapshotById: + def test_get_existing_snapshot(self, db: DatabaseManager): + session_id = db.create_session() + db.record_file_snapshot( + "src/main.py", + "old_hash", + "new_hash", + "diff_content", + audit_status="PENDING_AUDIT", + session_id=session_id, + ) + # Set pending_content directly + db.execute( + "UPDATE file_snapshots SET pending_content = ? WHERE id = ?", + ("pending content", 1), + ) + + row = db.get_snapshot_by_id(1) + assert row is not None + assert row[1] == "src/main.py" + assert row[8] == "pending content" + + def test_get_nonexistent_snapshot(self, db: DatabaseManager): + row = db.get_snapshot_by_id(9999) + assert row is None + + +class TestGetSnapshotsByAuditStatus: + def test_filter_by_status(self, db: DatabaseManager): + db.record_file_snapshot("a.py", None, "h1", "diff1") + db.record_file_snapshot("b.py", "h0", "h2", "diff2", audit_status="APPROVED") + db.record_file_snapshot("c.py", "h0", "h3", "diff3", audit_status="REJECTED") + + pending = db.get_snapshots_by_audit_status("PENDING_AUDIT") + approved = db.get_snapshots_by_audit_status("APPROVED") + rejected = db.get_snapshots_by_audit_status("REJECTED") + + assert len(pending) == 1 + assert pending[0][1] == "a.py" + assert len(approved) == 1 + assert approved[0][1] == "b.py" + assert len(rejected) == 1 + assert rejected[0][1] == "c.py" + + def test_no_matching_status(self, db: DatabaseManager): + rows = db.get_snapshots_by_audit_status("NONEXISTENT_STATUS") + assert len(rows) == 0 + + +class TestThreadSafety: + def test_concurrent_writes(self, db: DatabaseManager): + import threading + + errors = [] + + def write_session(name): + try: + db.create_session(name=name) + except Exception as e: + errors.append(e) + finally: + db.close() # 确保每个线程关闭自己的连接,避免 ResourceWarning + + threads = [threading.Thread(target=write_session, args=(f"thread_{i}",)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + rows = db.fetchall("SELECT COUNT(*) FROM sessions") + assert rows[0][0] == 5 + + +class TestClose: + def test_close_and_reopen(self, workspace_root: str): + root = Path(workspace_root) + root.mkdir(parents=True, exist_ok=True) + db1 = DatabaseManager(workspace_root) + db1.create_session("first") + db1.close() + + db2 = DatabaseManager(workspace_root) + rows = db2.fetchall("SELECT COUNT(*) FROM sessions") + assert rows[0][0] == 1 + + +class TestSessionSummary: + def test_summary_returns_empty_for_nonexistent_session(self, db: DatabaseManager): + summary = db.get_session_summary(9999) + assert summary == {} + + def test_summary_with_no_tool_calls(self, db: DatabaseManager): + session_id = db.create_session(name="empty_session") + summary = db.get_session_summary(session_id) + assert summary["name"] == "empty_session" + assert summary["total_calls"] == 0 + assert summary["success_count"] == 0 + assert summary["fail_count"] == 0 + assert summary["success_rate"] == 0.0 + + def test_summary_with_tool_calls(self, db: DatabaseManager): + session_id = db.create_session(name="test") + db.log_tool_call(session_id, "read", "hash1", status="success") + db.log_tool_call(session_id, "write", "hash2", status="success") + db.log_tool_call(session_id, "edit", "hash3", status="error") + + summary = db.get_session_summary(session_id) + assert summary["total_calls"] == 3 + assert summary["success_count"] == 2 + assert summary["fail_count"] == 1 + assert summary["success_rate"] == pytest.approx(66.6667, rel=0.01) + + def test_summary_duration(self, db: DatabaseManager): + session_id = db.create_session(name="dur_test") + time.sleep(0.05) + db.close_session(session_id) + summary = db.get_session_summary(session_id) + assert summary["duration"] > 0 + + +class TestAllSessions: + def test_get_all_sessions_empty(self, db: DatabaseManager): + sessions = db.get_all_sessions() + assert len(sessions) == 0 + + def test_get_all_sessions_ordering(self, db: DatabaseManager): + id1 = db.create_session(name="first") + id2 = db.create_session(name="second") + id3 = db.create_session(name="third") + + sessions = db.get_all_sessions() + assert len(sessions) == 3 + # Most recent first + assert sessions[0][0] == id3 + assert sessions[1][0] == id2 + assert sessions[2][0] == id1 + + def test_get_all_sessions_returns_columns(self, db: DatabaseManager): + db.create_session(name="check_columns") + sessions = db.get_all_sessions() + row = sessions[0] + assert len(row) == 4 # id, name, created_at, duration + + +class TestRenameSession: + def test_rename_session(self, db: DatabaseManager): + session_id = db.create_session(name="original") + db.rename_session(session_id, "renamed") + row = db.fetchone("SELECT name FROM sessions WHERE id = ?", (session_id,)) + assert row is not None + assert row[0] == "renamed" + + def test_rename_nonexistent_session(self, db: DatabaseManager): + db.rename_session(9999, "ghost") # Should not raise + + +class TestDeleteSession: + def test_delete_session_removes_session(self, db: DatabaseManager): + session_id = db.create_session(name="to_delete") + db.delete_session(session_id) + row = db.fetchone("SELECT id FROM sessions WHERE id = ?", (session_id,)) + assert row is None + + def test_delete_session_removes_tool_calls(self, db: DatabaseManager): + session_id = db.create_session() + db.log_tool_call(session_id, "read", "hash1") + db.log_tool_call(session_id, "write", "hash2") + + # Verify tool calls exist + rows = db.fetchall("SELECT COUNT(*) FROM tool_calls WHERE session_id = ?", (session_id,)) + assert rows[0][0] == 2 + + db.delete_session(session_id) + + rows = db.fetchall("SELECT COUNT(*) FROM tool_calls WHERE session_id = ?", (session_id,)) + assert rows[0][0] == 0 + + def test_delete_session_removes_snapshots(self, db: DatabaseManager): + session_id = db.create_session() + db.record_file_snapshot("a.py", "old", "new", "diff", session_id=session_id) + + rows = db.fetchall("SELECT COUNT(*) FROM file_snapshots WHERE session_id = ?", (session_id,)) + assert rows[0][0] == 1 + + db.delete_session(session_id) + + rows = db.fetchall("SELECT COUNT(*) FROM file_snapshots WHERE session_id = ?", (session_id,)) + assert rows[0][0] == 0 + + def test_delete_nonexistent_session(self, db: DatabaseManager): + db.delete_session(9999) # Should not raise + + def test_delete_session_keeps_other_sessions(self, db: DatabaseManager): + sid1 = db.create_session() + sid2 = db.create_session() + db.log_tool_call(sid1, "read", "h1") + db.log_tool_call(sid2, "write", "h2") + + db.delete_session(sid1) + + remaining = db.fetchall("SELECT id FROM sessions") + assert len(remaining) == 1 + assert remaining[0][0] == sid2 + + +class TestToolUsageRanking: + def test_ranking_empty(self, db: DatabaseManager): + ranking = db.get_tool_usage_ranking() + assert len(ranking) == 0 + + def test_ranking_global(self, db: DatabaseManager): + sid1 = db.create_session() + sid2 = db.create_session() + + # Session 1: 3 reads, 2 writes + for _ in range(3): + db.log_tool_call(sid1, "read", "h", duration_ms=10.0) + for _ in range(2): + db.log_tool_call(sid1, "write", "h", duration_ms=20.0) + + # Session 2: 1 read, 1 glob + db.log_tool_call(sid2, "read", "h", duration_ms=5.0) + db.log_tool_call(sid2, "glob", "h", duration_ms=15.0) + + ranking = db.get_tool_usage_ranking(limit=10) + assert len(ranking) == 3 + # read=4, write=2, glob=1 + assert ranking[0][0] == "read" + assert ranking[0][1] == 4 + assert ranking[1][0] == "write" + assert ranking[1][1] == 2 + assert ranking[2][0] == "glob" + assert ranking[2][1] == 1 + + def test_ranking_per_session(self, db: DatabaseManager): + sid = db.create_session() + db.log_tool_call(sid, "read", "h", duration_ms=5.0) + db.log_tool_call(sid, "read", "h", duration_ms=10.0) + db.log_tool_call(sid, "glob", "h", duration_ms=15.0) + + ranking = db.get_tool_usage_ranking(session_id=sid) + assert len(ranking) == 2 + assert ranking[0][0] == "read" + assert ranking[0][1] == 2 + assert ranking[1][0] == "glob" + assert ranking[1][1] == 1 + + def test_ranking_avg_duration(self, db: DatabaseManager): + sid = db.create_session() + db.log_tool_call(sid, "read", "h", duration_ms=10.0) + db.log_tool_call(sid, "read", "h", duration_ms=30.0) + + ranking = db.get_tool_usage_ranking(session_id=sid) + assert len(ranking) == 1 + assert ranking[0][0] == "read" + assert ranking[0][1] == 2 + assert ranking[0][2] == pytest.approx(20.0) + assert ranking[0][3] == pytest.approx(40.0) + + def test_ranking_limit(self, db: DatabaseManager): + sid = db.create_session() + for tool in ["a", "b", "c", "d", "e"]: + db.log_tool_call(sid, tool, "h") + ranking = db.get_tool_usage_ranking(session_id=sid, limit=3) + assert len(ranking) == 3 diff --git a/tests/core/test_tool_call_summaries.py b/tests/core/test_tool_call_summaries.py index da2d9ae..0c8eb80 100644 --- a/tests/core/test_tool_call_summaries.py +++ b/tests/core/test_tool_call_summaries.py @@ -1,150 +1,146 @@ -"""Test tool_call_summaries table functionality.""" - -import os -import sqlite3 -import time - -import pytest - -from src.core.database_manager import DatabaseManager - - -@pytest.fixture -def temp_db(tmp_path): - """Create a temporary database for testing.""" - temp_workspace = tmp_path / "workspace" - os.makedirs(temp_workspace, exist_ok=True) - db = DatabaseManager(str(temp_workspace)) - yield db - db.close() - DatabaseManager.reset_instances() - - -class TestToolCallSummariesTable: - def test_table_exists(self, temp_db): - cursor = temp_db._get_connection() - tables = cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='tool_call_summaries'" - ).fetchall() - assert len(tables) == 1 - assert tables[0][0] == "tool_call_summaries" - - def test_primary_key_constraints(self, temp_db): - session_id = temp_db.create_session() - kwargs_json = '{"file_path": "test.txt"}' - - # Insert first record - temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result1") - - # Insert with same primary key (should update) - time.sleep(0.01) - temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result2") - - # Should only have one record - summaries = temp_db.get_tool_call_summaries(session_id) - assert len(summaries) == 1 - assert summaries[0][3] == "result2" # result updated - assert summaries[0][4] > time.time() - 1 # timestamp updated - - def test_different_kwargs_create_new_records(self, temp_db): - session_id = temp_db.create_session() - kwargs1 = '{"file_path": "test.txt"}' - kwargs2 = '{"file_path": "other.txt"}' - - temp_db.record_tool_call_summary(session_id, "read", kwargs1, "result1") - temp_db.record_tool_call_summary(session_id, "read", kwargs2, "result2") - - summaries = temp_db.get_tool_call_summaries(session_id) - assert len(summaries) == 2 - - def test_different_func_names_create_new_records(self, temp_db): - session_id = temp_db.create_session() - kwargs_json = '{"query": "test"}' - - temp_db.record_tool_call_summary(session_id, "search", kwargs_json, "result1") - temp_db.record_tool_call_summary(session_id, "stat", kwargs_json, "result2") - - summaries = temp_db.get_tool_call_summaries(session_id) - assert len(summaries) == 2 - - def test_get_tool_call_summaries_ordered_by_timestamp(self, temp_db): - session_id = temp_db.create_session() - kwargs_json = '{"file_path": "test.txt"}' - - time.sleep(0.01) - temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result1") - time.sleep(0.01) - temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result2") - time.sleep(0.01) - temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result3") - - summaries = temp_db.get_tool_call_summaries(session_id) - assert len(summaries) == 1 - - def test_get_tool_call_summaries_from_different_sessions(self, temp_db): - session_id1 = temp_db.create_session() - session_id2 = temp_db.create_session() - kwargs_json = '{"file_path": "test.txt"}' - - temp_db.record_tool_call_summary(session_id1, "read", kwargs_json, "result1") - temp_db.record_tool_call_summary(session_id2, "read", kwargs_json, "result2") - - summaries1 = temp_db.get_tool_call_summaries(session_id1) - summaries2 = temp_db.get_tool_call_summaries(session_id2) - - assert len(summaries1) == 1 - assert len(summaries2) == 1 - assert summaries1[0][3] != summaries2[0][3] - - def test_session_foreign_key_constraint(self, temp_db): - kwargs_json = '{"file_path": "test.txt"}' - # Trying to insert with non-existent session should raise FK constraint error - # because PRAGMA foreign_keys=ON is enabled - with pytest.raises(sqlite3.IntegrityError): - temp_db.record_tool_call_summary(999, "read", kwargs_json, "result1") - - # No record should be inserted - summaries = temp_db.get_tool_call_summaries(999) - assert len(summaries) == 0 - - def test_index_exists(self, temp_db): - cursor = temp_db._get_connection() - indexes = cursor.execute( - "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_tool_call_summaries_session'" - ).fetchall() - assert len(indexes) == 1 - - -class TestToolCallSummariesIntegration: - def test_record_tool_call_summary_stores_correct_data(self, temp_db): - session_id = temp_db.create_session() - func_name = "read" - kwargs_json = '{"file_path": "test.txt"}' - result = "file content here" - - temp_db.record_tool_call_summary(session_id, func_name, kwargs_json, result) - - summaries = temp_db.get_tool_call_summaries(session_id) - assert len(summaries) == 1 - assert summaries[0][0] == session_id - assert summaries[0][1] == func_name - assert summaries[0][2] == kwargs_json - assert summaries[0][3] == result - assert summaries[0][4] > 0 # timestamp - - def test_multiple_sessions_isolated(self, temp_db): - session_id1 = temp_db.create_session() - session_id2 = temp_db.create_session() - - kwargs_json = '{"file_path": "test.txt"}' - result1 = "result for session 1" - result2 = "result for session 2" - - temp_db.record_tool_call_summary(session_id1, "read", kwargs_json, result1) - temp_db.record_tool_call_summary(session_id2, "read", kwargs_json, result2) - - summaries1 = temp_db.get_tool_call_summaries(session_id1) - summaries2 = temp_db.get_tool_call_summaries(session_id2) - - assert summaries1[0][3] == result1 - assert summaries2[0][3] == result2 +"""Test tool_call_summaries table functionality.""" + +import os +import sqlite3 +import time + +import pytest + +from src.core.database_manager import DatabaseManager + + +@pytest.fixture +def temp_db(tmp_path): + """Create a temporary database for testing.""" + temp_workspace = tmp_path / "workspace" + os.makedirs(temp_workspace, exist_ok=True) + db = DatabaseManager(str(temp_workspace)) + yield db + db.close() + DatabaseManager.reset_instances() + + +class TestToolCallSummariesTable: + def test_table_exists(self, temp_db): + cursor = temp_db._get_connection() + tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tool_call_summaries'").fetchall() + assert len(tables) == 1 + assert tables[0][0] == "tool_call_summaries" + + def test_primary_key_constraints(self, temp_db): + session_id = temp_db.create_session() + kwargs_json = '{"file_path": "test.txt"}' + + # Insert first record + temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result1") + + # Insert with same primary key (should update) + time.sleep(0.01) + temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result2") + + # Should only have one record + summaries = temp_db.get_tool_call_summaries(session_id) + assert len(summaries) == 1 + assert summaries[0][3] == "result2" # result updated + assert summaries[0][4] > time.time() - 1 # timestamp updated + + def test_different_kwargs_create_new_records(self, temp_db): + session_id = temp_db.create_session() + kwargs1 = '{"file_path": "test.txt"}' + kwargs2 = '{"file_path": "other.txt"}' + + temp_db.record_tool_call_summary(session_id, "read", kwargs1, "result1") + temp_db.record_tool_call_summary(session_id, "read", kwargs2, "result2") + + summaries = temp_db.get_tool_call_summaries(session_id) + assert len(summaries) == 2 + + def test_different_func_names_create_new_records(self, temp_db): + session_id = temp_db.create_session() + kwargs_json = '{"query": "test"}' + + temp_db.record_tool_call_summary(session_id, "search", kwargs_json, "result1") + temp_db.record_tool_call_summary(session_id, "stat", kwargs_json, "result2") + + summaries = temp_db.get_tool_call_summaries(session_id) + assert len(summaries) == 2 + + def test_get_tool_call_summaries_ordered_by_timestamp(self, temp_db): + session_id = temp_db.create_session() + kwargs_json = '{"file_path": "test.txt"}' + + time.sleep(0.01) + temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result1") + time.sleep(0.01) + temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result2") + time.sleep(0.01) + temp_db.record_tool_call_summary(session_id, "read", kwargs_json, "result3") + + summaries = temp_db.get_tool_call_summaries(session_id) + assert len(summaries) == 1 + + def test_get_tool_call_summaries_from_different_sessions(self, temp_db): + session_id1 = temp_db.create_session() + session_id2 = temp_db.create_session() + kwargs_json = '{"file_path": "test.txt"}' + + temp_db.record_tool_call_summary(session_id1, "read", kwargs_json, "result1") + temp_db.record_tool_call_summary(session_id2, "read", kwargs_json, "result2") + + summaries1 = temp_db.get_tool_call_summaries(session_id1) + summaries2 = temp_db.get_tool_call_summaries(session_id2) + + assert len(summaries1) == 1 + assert len(summaries2) == 1 + assert summaries1[0][3] != summaries2[0][3] + + def test_session_foreign_key_constraint(self, temp_db): + kwargs_json = '{"file_path": "test.txt"}' + # Trying to insert with non-existent session should raise FK constraint error + # because PRAGMA foreign_keys=ON is enabled + with pytest.raises(sqlite3.IntegrityError): + temp_db.record_tool_call_summary(999, "read", kwargs_json, "result1") + + # No record should be inserted + summaries = temp_db.get_tool_call_summaries(999) + assert len(summaries) == 0 + + def test_index_exists(self, temp_db): + cursor = temp_db._get_connection() + indexes = cursor.execute("SELECT name FROM sqlite_master WHERE type='index' AND name='idx_tool_call_summaries_session'").fetchall() + assert len(indexes) == 1 + + +class TestToolCallSummariesIntegration: + def test_record_tool_call_summary_stores_correct_data(self, temp_db): + session_id = temp_db.create_session() + func_name = "read" + kwargs_json = '{"file_path": "test.txt"}' + result = "file content here" + + temp_db.record_tool_call_summary(session_id, func_name, kwargs_json, result) + + summaries = temp_db.get_tool_call_summaries(session_id) + assert len(summaries) == 1 + assert summaries[0][0] == session_id + assert summaries[0][1] == func_name + assert summaries[0][2] == kwargs_json + assert summaries[0][3] == result + assert summaries[0][4] > 0 # timestamp + + def test_multiple_sessions_isolated(self, temp_db): + session_id1 = temp_db.create_session() + session_id2 = temp_db.create_session() + + kwargs_json = '{"file_path": "test.txt"}' + result1 = "result for session 1" + result2 = "result for session 2" + + temp_db.record_tool_call_summary(session_id1, "read", kwargs_json, result1) + temp_db.record_tool_call_summary(session_id2, "read", kwargs_json, result2) + + summaries1 = temp_db.get_tool_call_summaries(session_id1) + summaries2 = temp_db.get_tool_call_summaries(session_id2) + + assert summaries1[0][3] == result1 + assert summaries2[0][3] == result2 diff --git a/tests/workspace/tools/test_edit_tool.py b/tests/workspace/tools/test_edit_tool.py index 9f17d51..4e87fa4 100644 --- a/tests/workspace/tools/test_edit_tool.py +++ b/tests/workspace/tools/test_edit_tool.py @@ -1,209 +1,207 @@ -"""Edit 工具测试 — 安全的字符串替换编辑工具.""" - -import time -from pathlib import Path - -import pytest - -from src.core.database_manager import DatabaseManager -from src.workspace.workspace import Workspace - - -@pytest.fixture(autouse=True) -def reset_singletons(): - Workspace._instance = None - DatabaseManager.reset_instances() - yield - Workspace._instance = None - DatabaseManager.reset_instances() - - -@pytest.fixture -def workspace(tmp_path: Path) -> Workspace: - ws = Workspace(str(tmp_path)) - ws._current_session_id = ws.db.create_session(name="test_session") - return ws - - -@pytest.fixture -def edit_tool(workspace: Workspace): - from src.workspace.tools.edit_tool import EditTool - - return EditTool(workspace) - - -@pytest.fixture -def read_tool(workspace: Workspace): - from src.workspace.tools.read_tool import ReadTool - - return ReadTool(workspace) - - -def _create_file(workspace: Workspace, path: str, content: str) -> Path: - """Create a file in the workspace.""" - target = workspace.root_path / path - target.parent.mkdir(parents=True, exist_ok=True) - target.write_text(content, encoding="utf-8") - return target - - -class TestEditBasic: - def test_simple_replacement(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "hello world") - result = edit_tool.edit("test.txt", "world", "there") - - assert "Edit Preview" in result.data - assert "Snapshot ID:" in result.data - - def test_does_not_write_to_disk(self, edit_tool, workspace): - file = _create_file(workspace, "test.txt", "hello world") - edit_tool.edit("test.txt", "world", "there") - - assert file.read_text(encoding="utf-8") == "hello world" - - def test_creates_pending_snapshot(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "hello world") - edit_tool.edit("test.txt", "world", "there") - - rows = workspace.db.fetchall("SELECT audit_status, pending_content FROM file_snapshots") - assert len(rows) == 1 - assert rows[0][0] == "PENDING_AUDIT" - assert rows[0][1] == "hello there" - - def test_diff_in_preview(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "line1\nline2\nline3") - result = edit_tool.edit("test.txt", "line2", "modified") - - assert "-line2" in result.data - assert "+modified" in result.data - - def test_multiple_replacements(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "a a a a a") - result = edit_tool.edit("test.txt", "a", "b", max_replacements=3) - - assert "Replacements: 3" in result.data - # Verify pending content has exactly 3 replacements - snap = workspace.db.fetchone("SELECT pending_content FROM file_snapshots") - assert snap is not None - assert snap[0] == "b b b a a" - - def test_no_match_found(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "hello world") - result = edit_tool.edit("test.txt", "nonexistent", "replacement") - - assert result.success is False - assert "No changes made" in result.error - assert "old_string not found" in result.error - - -class TestEditMaxReplacements: - def test_exceeds_max_replacements(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "a a a a a a a a a a a a") # 12 a's - result = edit_tool.edit("test.txt", "a", "b", max_replacements=5) - - assert "Replacements: 5" in result.data - - def test_max_replacements_default_10(self, edit_tool, workspace): - _create_file(workspace, "test.txt", " ".join(["a"] * 20)) - result = edit_tool.edit("test.txt", "a", "b") - - assert "Replacements: 10" in result.data - - def test_max_replacements_capped_at_100(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "a " * 150) - result = edit_tool.edit("test.txt", "a", "b", max_replacements=200) - - assert "Replacements: 100" in result.data - - -class TestEditContextValidation: - def test_context_before_matches(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "prefix target suffix") - result = edit_tool.edit("test.txt", "target", "replaced", context_before="prefix ") - - assert "Edit Preview" in result.data - - def test_context_before_mismatch(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "prefix target suffix") - result = edit_tool.edit("test.txt", "target", "replaced", context_before="wrong ") - - assert result.success is False - assert "context_before" in result.error - assert "mismatch" in result.error.lower() - - def test_context_after_matches(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "prefix target suffix") - result = edit_tool.edit("test.txt", "target", "replaced", context_after=" suffix") - - assert "Edit Preview" in result.data - - def test_context_after_mismatch(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "prefix target suffix") - result = edit_tool.edit("test.txt", "target", "replaced", context_after=" wrong") - - assert result.success is False - assert "context_after" in result.error - assert "mismatch" in result.error.lower() - - def test_both_contexts_match(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "before target after") - result = edit_tool.edit("test.txt", "target", "replaced", context_before="before ", context_after=" after") - - assert "Edit Preview" in result.data - - def test_context_with_multiple_matches(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "before X after\nignore\nbefore X after") - result = edit_tool.edit( - "test.txt", "X", "Y", max_replacements=2, context_before="before ", context_after=" after" - ) - - assert "Edit Preview" in result.data - assert "Replacements: 2" in result.data - - -class TestEditMtimeValidation: - def test_edit_modified_externally_fails(self, edit_tool, read_tool, workspace): - file = _create_file(workspace, "test.txt", "original content") - read_tool.read("test.txt") - - # Modify externally - time.sleep(0.1) - file.write_text("modified externally", encoding="utf-8") - - result = edit_tool.edit("test.txt", "original", "replaced") - assert result.success is False - assert "FILE_MODIFIED_EXTERNALLY" in result.error - - def test_edit_no_prior_read_succeeds(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "original content") - result = edit_tool.edit("test.txt", "original", "updated") - - assert "Edit Preview" in result.data - - -class TestEditEdgeCases: - def test_empty_old_string(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "content") - result = edit_tool.edit("test.txt", "", "replacement") - - assert result.success is False - assert "不能为空" in result.error or "empty" in result.error.lower() - - def test_nonexistent_file(self, edit_tool, workspace): - result = edit_tool.edit("nonexistent.txt", "old", "new") - assert result.success is False - assert "不存在" in result.error or "not found" in result.error.lower() or "exists" in result.error.lower() - - def test_file_outside_workspace(self, edit_tool, workspace): - result = edit_tool.edit("../outside.txt", "old", "new") - assert "越界" in result.error or "boundary" in result.error.lower() or "outside" in result.error.lower() - - def test_edit_snapshot_has_old_hash(self, edit_tool, workspace): - _create_file(workspace, "test.txt", "original") - edit_tool.edit("test.txt", "original", "updated") - - snap = workspace.db.fetchone("SELECT old_hash, new_hash FROM file_snapshots") - assert snap is not None - assert snap[0] is not None - assert snap[1] is not None - assert snap[0] != snap[1] +"""Edit 工具测试 — 安全的字符串替换编辑工具.""" + +import time +from pathlib import Path + +import pytest + +from src.core.database_manager import DatabaseManager +from src.workspace.workspace import Workspace + + +@pytest.fixture(autouse=True) +def reset_singletons(): + Workspace._instance = None + DatabaseManager.reset_instances() + yield + Workspace._instance = None + DatabaseManager.reset_instances() + + +@pytest.fixture +def workspace(tmp_path: Path) -> Workspace: + ws = Workspace(str(tmp_path)) + ws._current_session_id = ws.db.create_session(name="test_session") + return ws + + +@pytest.fixture +def edit_tool(workspace: Workspace): + from src.workspace.tools.edit_tool import EditTool + + return EditTool(workspace) + + +@pytest.fixture +def read_tool(workspace: Workspace): + from src.workspace.tools.read_tool import ReadTool + + return ReadTool(workspace) + + +def _create_file(workspace: Workspace, path: str, content: str) -> Path: + """Create a file in the workspace.""" + target = workspace.root_path / path + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + return target + + +class TestEditBasic: + def test_simple_replacement(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "hello world") + result = edit_tool.edit("test.txt", "world", "there") + + assert "Edit Preview" in result.data + assert "Snapshot ID:" in result.data + + def test_does_not_write_to_disk(self, edit_tool, workspace): + file = _create_file(workspace, "test.txt", "hello world") + edit_tool.edit("test.txt", "world", "there") + + assert file.read_text(encoding="utf-8") == "hello world" + + def test_creates_pending_snapshot(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "hello world") + edit_tool.edit("test.txt", "world", "there") + + rows = workspace.db.fetchall("SELECT audit_status, pending_content FROM file_snapshots") + assert len(rows) == 1 + assert rows[0][0] == "PENDING_AUDIT" + assert rows[0][1] == "hello there" + + def test_diff_in_preview(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "line1\nline2\nline3") + result = edit_tool.edit("test.txt", "line2", "modified") + + assert "-line2" in result.data + assert "+modified" in result.data + + def test_multiple_replacements(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "a a a a a") + result = edit_tool.edit("test.txt", "a", "b", max_replacements=3) + + assert "Replacements: 3" in result.data + # Verify pending content has exactly 3 replacements + snap = workspace.db.fetchone("SELECT pending_content FROM file_snapshots") + assert snap is not None + assert snap[0] == "b b b a a" + + def test_no_match_found(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "hello world") + result = edit_tool.edit("test.txt", "nonexistent", "replacement") + + assert result.success is False + assert "No changes made" in result.error + assert "old_string not found" in result.error + + +class TestEditMaxReplacements: + def test_exceeds_max_replacements(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "a a a a a a a a a a a a") # 12 a's + result = edit_tool.edit("test.txt", "a", "b", max_replacements=5) + + assert "Replacements: 5" in result.data + + def test_max_replacements_default_10(self, edit_tool, workspace): + _create_file(workspace, "test.txt", " ".join(["a"] * 20)) + result = edit_tool.edit("test.txt", "a", "b") + + assert "Replacements: 10" in result.data + + def test_max_replacements_capped_at_100(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "a " * 150) + result = edit_tool.edit("test.txt", "a", "b", max_replacements=200) + + assert "Replacements: 100" in result.data + + +class TestEditContextValidation: + def test_context_before_matches(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "prefix target suffix") + result = edit_tool.edit("test.txt", "target", "replaced", context_before="prefix ") + + assert "Edit Preview" in result.data + + def test_context_before_mismatch(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "prefix target suffix") + result = edit_tool.edit("test.txt", "target", "replaced", context_before="wrong ") + + assert result.success is False + assert "context_before" in result.error + assert "mismatch" in result.error.lower() + + def test_context_after_matches(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "prefix target suffix") + result = edit_tool.edit("test.txt", "target", "replaced", context_after=" suffix") + + assert "Edit Preview" in result.data + + def test_context_after_mismatch(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "prefix target suffix") + result = edit_tool.edit("test.txt", "target", "replaced", context_after=" wrong") + + assert result.success is False + assert "context_after" in result.error + assert "mismatch" in result.error.lower() + + def test_both_contexts_match(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "before target after") + result = edit_tool.edit("test.txt", "target", "replaced", context_before="before ", context_after=" after") + + assert "Edit Preview" in result.data + + def test_context_with_multiple_matches(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "before X after\nignore\nbefore X after") + result = edit_tool.edit("test.txt", "X", "Y", max_replacements=2, context_before="before ", context_after=" after") + + assert "Edit Preview" in result.data + assert "Replacements: 2" in result.data + + +class TestEditMtimeValidation: + def test_edit_modified_externally_fails(self, edit_tool, read_tool, workspace): + file = _create_file(workspace, "test.txt", "original content") + read_tool.read("test.txt") + + # Modify externally + time.sleep(0.1) + file.write_text("modified externally", encoding="utf-8") + + result = edit_tool.edit("test.txt", "original", "replaced") + assert result.success is False + assert "FILE_MODIFIED_EXTERNALLY" in result.error + + def test_edit_no_prior_read_succeeds(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "original content") + result = edit_tool.edit("test.txt", "original", "updated") + + assert "Edit Preview" in result.data + + +class TestEditEdgeCases: + def test_empty_old_string(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "content") + result = edit_tool.edit("test.txt", "", "replacement") + + assert result.success is False + assert "不能为空" in result.error or "empty" in result.error.lower() + + def test_nonexistent_file(self, edit_tool, workspace): + result = edit_tool.edit("nonexistent.txt", "old", "new") + assert result.success is False + assert "不存在" in result.error or "not found" in result.error.lower() or "exists" in result.error.lower() + + def test_file_outside_workspace(self, edit_tool, workspace): + result = edit_tool.edit("../outside.txt", "old", "new") + assert "越界" in result.error or "boundary" in result.error.lower() or "outside" in result.error.lower() + + def test_edit_snapshot_has_old_hash(self, edit_tool, workspace): + _create_file(workspace, "test.txt", "original") + edit_tool.edit("test.txt", "original", "updated") + + snap = workspace.db.fetchone("SELECT old_hash, new_hash FROM file_snapshots") + assert snap is not None + assert snap[0] is not None + assert snap[1] is not None + assert snap[0] != snap[1] diff --git a/tests/workspace/tools/test_git_tool.py b/tests/workspace/tools/test_git_tool.py index 19450dd..6527729 100644 --- a/tests/workspace/tools/test_git_tool.py +++ b/tests/workspace/tools/test_git_tool.py @@ -1,196 +1,192 @@ -"""Git 工具测试 — 白名单机制与安全封装.""" - -import subprocess -from pathlib import Path - -import pytest - -from src.core.database_manager import DatabaseManager -from src.workspace.workspace import Workspace - - -# noinspection PyTypeChecker -@pytest.fixture(autouse=True) -def reset_singletons(): - Workspace._instance = None - DatabaseManager.reset_instances() - yield - Workspace._instance = None - DatabaseManager.reset_instances() - - -@pytest.fixture -def git_repo(tmp_path: Path) -> Path: - """创建一个带有初始提交的 git 仓库.""" - repo = tmp_path / "repo" - repo.mkdir(parents=True) - subprocess.run(["git", "init"], cwd=repo, capture_output=True) - subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo, capture_output=True) - subprocess.run(["git", "config", "user.name", "Test"], cwd=repo, capture_output=True) - # Initial commit so commands work against a real repo - (repo / "README.md").write_text("# Test\n", encoding="utf-8") - subprocess.run(["git", "add", "README.md"], cwd=repo, capture_output=True) - subprocess.run(["git", "commit", "-m", "initial"], cwd=repo, capture_output=True) - return repo - - -@pytest.fixture -def workspace(git_repo: Path) -> Workspace: - ws = Workspace(str(git_repo)) - return ws - - -@pytest.fixture -def git_tool(workspace: Workspace): - from src.workspace.tools.git_tool import GitTool - - return GitTool(workspace) - - -class TestGitAllowedCommands: - def test_status(self, git_tool): - result = git_tool.git("status") - assert "nothing to commit" in result.response.lower() or "working tree clean" in result.lower() - - def test_diff(self, git_tool): - result = git_tool.git("diff") - assert "(no output)" in result.response or result.response == "" or result.response == "(no output)" - - def test_log(self, git_tool): - result = git_tool.git("log --oneline -1") - assert "initial" in result.response.lower() - - def test_show(self, git_tool): - result = git_tool.git("show --stat") - assert result.success is True - assert "README" in result.data or "initial" in result.data.lower() - - def test_add_and_commit(self, git_tool, git_repo: Path): - (git_repo / "new_file.txt").write_text("content", encoding="utf-8") - add_result = git_tool.git("add new_file.txt") - assert add_result.success is True - - commit_result = git_tool.git('commit -m "test commit"') - assert commit_result.success is True - assert "commi" in commit_result.data.lower() or "file changed" in commit_result.data.lower() - - def test_branch(self, git_tool): - result = git_tool.git("branch") - assert result.success is True - assert "*" in result.data or "main" in result.data or "master" in result.data - - def test_restore_specific_file(self, git_tool, git_repo: Path): - (git_repo / "README.md").write_text("modified\n", encoding="utf-8") - result = git_tool.git("restore README.md") - assert "failed" not in result.data.lower() - - -class TestGitBlockedCommands: - def test_push_blocked(self, git_tool): - result = git_tool.git("push") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - def test_remote_blocked(self, git_tool): - result = git_tool.git("remote add origin https://example.com/repo.git") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - def test_reset_hard_blocked(self, git_tool): - result = git_tool.git("reset --hard HEAD") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - def test_branch_d_blocked(self, git_tool): - result = git_tool.git("branch -D test") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - def test_merge_blocked(self, git_tool): - result = git_tool.git("merge test-branch") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - def test_rebase_blocked(self, git_tool): - result = git_tool.git("rebase main") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - def test_clean_blocked(self, git_tool): - result = git_tool.git("clean -fd") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - def test_fetch_blocked(self, git_tool): - result = git_tool.git("fetch origin") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - def test_pull_blocked(self, git_tool): - result = git_tool.git("pull origin main") - assert result.success is False - assert "blocked" in result.error.lower() or "ERROR" in result.error - - -class TestGitRestoreSafety: - def test_bare_restore_rejected(self, git_tool): - result = git_tool.git("restore") - assert result.success is False - assert "需要指定文件路径" in result.error or "ERROR" in result.error or "restore" in result.error.lower() - - def test_restore_dot_rejected(self, git_tool): - result = git_tool.git("restore .") - assert result.success is False - assert "通配符" in result.error or "ERROR" in result.error or "restore" in result.error.lower() - - -class TestGitUnknownCommand: - def test_unknown_command_rejected(self, git_tool): - result = git_tool.git("unknown-command") - assert result.success is False - assert "not in the allowed whitelist" in result.error.lower() or "ERROR" in result.error - - -class TestGitIsSafeCommand: - def test_safe_commands(self): - from src.workspace.tools.git_tool import GitTool - - assert GitTool.is_safe_command("status") - assert GitTool.is_safe_command("diff") - assert GitTool.is_safe_command("log --oneline -5") - assert GitTool.is_safe_command("show HEAD") - - def test_modifying_commands_not_safe(self): - from src.workspace.tools.git_tool import GitTool - - assert not GitTool.is_safe_command("add file.txt") - assert not GitTool.is_safe_command('commit -m "msg"') - assert not GitTool.is_safe_command("restore file.txt") - - def test_blocked_commands_not_safe(self): - from src.workspace.tools.git_tool import GitTool - - assert not GitTool.is_safe_command("push") - assert not GitTool.is_safe_command("merge main") - - def test_empty_string(self): - from src.workspace.tools.git_tool import GitTool - - assert not GitTool.is_safe_command("") - assert not GitTool.is_safe_command(" ") - - -class TestGitInjection: - def test_command_injection_via_semicolon(self, git_tool): - result = git_tool.git("status; echo pwned") - assert result.success is False - assert ( - "blocked" in result.error.lower() - or "ERROR" in result.error - or "not in the allowed whitelist" in result.error.lower() - ) - - def test_invalid_shell_syntax(self, git_tool): - result = git_tool.git("status $(whoami)") - assert result.success is True - assert "working tree clean" in result.data.lower() or "nothing to commit" in result.data.lower() +"""Git 工具测试 — 白名单机制与安全封装.""" + +import subprocess +from pathlib import Path + +import pytest + +from src.core.database_manager import DatabaseManager +from src.workspace.workspace import Workspace + + +# noinspection PyTypeChecker +@pytest.fixture(autouse=True) +def reset_singletons(): + Workspace._instance = None + DatabaseManager.reset_instances() + yield + Workspace._instance = None + DatabaseManager.reset_instances() + + +@pytest.fixture +def git_repo(tmp_path: Path) -> Path: + """创建一个带有初始提交的 git 仓库.""" + repo = tmp_path / "repo" + repo.mkdir(parents=True) + subprocess.run(["git", "init"], cwd=repo, capture_output=True) + subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo, capture_output=True) + subprocess.run(["git", "config", "user.name", "Test"], cwd=repo, capture_output=True) + # Initial commit so commands work against a real repo + (repo / "README.md").write_text("# Test\n", encoding="utf-8") + subprocess.run(["git", "add", "README.md"], cwd=repo, capture_output=True) + subprocess.run(["git", "commit", "-m", "initial"], cwd=repo, capture_output=True) + return repo + + +@pytest.fixture +def workspace(git_repo: Path) -> Workspace: + ws = Workspace(str(git_repo)) + return ws + + +@pytest.fixture +def git_tool(workspace: Workspace): + from src.workspace.tools.git_tool import GitTool + + return GitTool(workspace) + + +class TestGitAllowedCommands: + def test_status(self, git_tool): + result = git_tool.git("status") + assert "nothing to commit" in result.response.lower() or "working tree clean" in result.lower() + + def test_diff(self, git_tool): + result = git_tool.git("diff") + assert "(no output)" in result.response or result.response == "" or result.response == "(no output)" + + def test_log(self, git_tool): + result = git_tool.git("log --oneline -1") + assert "initial" in result.response.lower() + + def test_show(self, git_tool): + result = git_tool.git("show --stat") + assert result.success is True + assert "README" in result.data or "initial" in result.data.lower() + + def test_add_and_commit(self, git_tool, git_repo: Path): + (git_repo / "new_file.txt").write_text("content", encoding="utf-8") + add_result = git_tool.git("add new_file.txt") + assert add_result.success is True + + commit_result = git_tool.git('commit -m "test commit"') + assert commit_result.success is True + assert "commi" in commit_result.data.lower() or "file changed" in commit_result.data.lower() + + def test_branch(self, git_tool): + result = git_tool.git("branch") + assert result.success is True + assert "*" in result.data or "main" in result.data or "master" in result.data + + def test_restore_specific_file(self, git_tool, git_repo: Path): + (git_repo / "README.md").write_text("modified\n", encoding="utf-8") + result = git_tool.git("restore README.md") + assert "failed" not in result.data.lower() + + +class TestGitBlockedCommands: + def test_push_blocked(self, git_tool): + result = git_tool.git("push") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + def test_remote_blocked(self, git_tool): + result = git_tool.git("remote add origin https://example.com/repo.git") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + def test_reset_hard_blocked(self, git_tool): + result = git_tool.git("reset --hard HEAD") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + def test_branch_d_blocked(self, git_tool): + result = git_tool.git("branch -D test") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + def test_merge_blocked(self, git_tool): + result = git_tool.git("merge test-branch") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + def test_rebase_blocked(self, git_tool): + result = git_tool.git("rebase main") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + def test_clean_blocked(self, git_tool): + result = git_tool.git("clean -fd") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + def test_fetch_blocked(self, git_tool): + result = git_tool.git("fetch origin") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + def test_pull_blocked(self, git_tool): + result = git_tool.git("pull origin main") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error + + +class TestGitRestoreSafety: + def test_bare_restore_rejected(self, git_tool): + result = git_tool.git("restore") + assert result.success is False + assert "需要指定文件路径" in result.error or "ERROR" in result.error or "restore" in result.error.lower() + + def test_restore_dot_rejected(self, git_tool): + result = git_tool.git("restore .") + assert result.success is False + assert "通配符" in result.error or "ERROR" in result.error or "restore" in result.error.lower() + + +class TestGitUnknownCommand: + def test_unknown_command_rejected(self, git_tool): + result = git_tool.git("unknown-command") + assert result.success is False + assert "not in the allowed whitelist" in result.error.lower() or "ERROR" in result.error + + +class TestGitIsSafeCommand: + def test_safe_commands(self): + from src.workspace.tools.git_tool import GitTool + + assert GitTool.is_safe_command("status") + assert GitTool.is_safe_command("diff") + assert GitTool.is_safe_command("log --oneline -5") + assert GitTool.is_safe_command("show HEAD") + + def test_modifying_commands_not_safe(self): + from src.workspace.tools.git_tool import GitTool + + assert not GitTool.is_safe_command("add file.txt") + assert not GitTool.is_safe_command('commit -m "msg"') + assert not GitTool.is_safe_command("restore file.txt") + + def test_blocked_commands_not_safe(self): + from src.workspace.tools.git_tool import GitTool + + assert not GitTool.is_safe_command("push") + assert not GitTool.is_safe_command("merge main") + + def test_empty_string(self): + from src.workspace.tools.git_tool import GitTool + + assert not GitTool.is_safe_command("") + assert not GitTool.is_safe_command(" ") + + +class TestGitInjection: + def test_command_injection_via_semicolon(self, git_tool): + result = git_tool.git("status; echo pwned") + assert result.success is False + assert "blocked" in result.error.lower() or "ERROR" in result.error or "not in the allowed whitelist" in result.error.lower() + + def test_invalid_shell_syntax(self, git_tool): + result = git_tool.git("status $(whoami)") + assert result.success is True + assert "working tree clean" in result.data.lower() or "nothing to commit" in result.data.lower() diff --git a/tests/workspace/tools/test_write_tool.py b/tests/workspace/tools/test_write_tool.py index 49f0d44..6ab1a82 100644 --- a/tests/workspace/tools/test_write_tool.py +++ b/tests/workspace/tools/test_write_tool.py @@ -1,154 +1,150 @@ -import time -from pathlib import Path - -import pytest - -from src.core.database_manager import DatabaseManager -from src.workspace.workspace import Workspace - - -@pytest.fixture(autouse=True) -def reset_singletons(): - Workspace._instance = None - DatabaseManager.reset_instances() - yield - Workspace._instance = None - DatabaseManager.reset_instances() - - -@pytest.fixture -def workspace(tmp_path: Path) -> Workspace: - ws = Workspace(str(tmp_path)) - ws._current_session_id = ws.db.create_session(name="test_session") - return ws - - -@pytest.fixture -def write_tool(workspace: Workspace): - from src.workspace.tools.write_tool import WriteTool - - return WriteTool(workspace) - - -@pytest.fixture -def read_tool(workspace: Workspace): - from src.workspace.tools.read_tool import ReadTool - - return ReadTool(workspace) - - -class TestWritePreview: - def test_write_returns_preview(self, write_tool, tmp_path: Path): - result = write_tool.write("new_file.txt", "hello") - assert "Write Preview" in result.data - assert "Snapshot ID:" in result.data - - def test_write_does_not_write_to_disk(self, write_tool, tmp_path: Path): - write_tool.write("new_file.txt", "hello") - assert not (tmp_path / "new_file.txt").is_file() - - def test_write_creates_snapshot(self, write_tool, tmp_path: Path): - write_tool.write("new_file.txt", "hello") - - rows = write_tool.workspace.db.fetchall( - "SELECT old_hash, new_hash, audit_status, pending_content FROM file_snapshots" - ) - assert len(rows) == 1 - assert rows[0][0] is None # old_hash for new file - assert rows[0][2] == "PENDING_AUDIT" - assert rows[0][3] == "hello" - - def test_write_preview_contains_diff(self, write_tool, tmp_path: Path): - write_tool.write("new_file.txt", "line1\nline2\nline3") - rows = write_tool.workspace.db.fetchall("SELECT diff_content FROM file_snapshots") - assert len(rows) == 1 - assert "+line1" in rows[0][0] - - -class TestWriteAfterRead: - def test_write_after_read_contains_old_hash(self, write_tool, read_tool, tmp_path: Path): - file = tmp_path / "test.txt" - file.write_text("original", encoding="utf-8") - - read_tool.read("test.txt") - write_tool.write("test.txt", "updated") - - rows = write_tool.workspace.db.fetchall("SELECT old_hash, new_hash FROM file_snapshots") - assert len(rows) == 1 - assert rows[0][0] is not None - assert rows[0][1] != rows[0][0] - - def test_write_after_read_shows_diff(self, write_tool, read_tool, tmp_path: Path): - file = tmp_path / "test.txt" - file.write_text("line1\nline2\nline3", encoding="utf-8") - - read_tool.read("test.txt") - result = write_tool.write("test.txt", "line1\nmodified\nline3") - - assert "Write Preview" in result.data - rows = write_tool.workspace.db.fetchall("SELECT diff_content FROM file_snapshots") - assert len(rows) == 1 - assert "-line2" in rows[0][0] - assert "+modified" in rows[0][0] - - -class TestWriteModifiedExternally: - def test_write_modified_externally_fails(self, write_tool, read_tool, tmp_path: Path): - file = tmp_path / "test.txt" - file.write_text("original", encoding="utf-8") - - read_tool.read("test.txt") - - file.write_text("modified externally", encoding="utf-8") - time.sleep(0.1) - - new_mtime = file.stat().st_mtime - read_record = read_tool.workspace.db.get_file_read_record(read_tool.workspace._current_session_id, "test.txt") - stored_mtime = read_record[3] if read_record else None - - if stored_mtime and abs(new_mtime - stored_mtime) < 0.001: - file.write_text("modified externally again", encoding="utf-8") - time.sleep(0.1) - - result = write_tool.write("test.txt", "should fail") - assert result.success is False - assert "FILE_MODIFIED_EXTERNALLY" in result.error - - def test_write_no_prior_read_succeeds(self, write_tool, tmp_path: Path): - file = tmp_path / "test.txt" - file.write_text("existing content", encoding="utf-8") - - result = write_tool.write("test.txt", "new content") - assert "Write Preview" in result.data - - -class TestWriteSnapshotPendingContent: - def test_pending_content_stored(self, write_tool, tmp_path: Path): - write_tool.write("new_file.txt", "pending content here") - - rows = write_tool.workspace.db.fetchall("SELECT pending_content FROM file_snapshots") - assert len(rows) == 1 - assert rows[0][0] == "pending content here" - - def test_pending_content_for_existing_file(self, write_tool, read_tool, tmp_path: Path): - file = tmp_path / "test.txt" - file.write_text("original", encoding="utf-8") - - read_tool.read("test.txt") - write_tool.write("test.txt", "updated content") - - snap = write_tool.workspace.db.fetchone( - "SELECT pending_content, audit_status FROM file_snapshots WHERE file_path = 'test.txt'" - ) - assert snap is not None - assert snap[0] == "updated content" - assert snap[1] == "PENDING_AUDIT" - - def test_disk_file_unchanged_after_write(self, write_tool, read_tool, tmp_path: Path): - file = tmp_path / "test.txt" - file.write_text("original", encoding="utf-8") - - read_tool.read("test.txt") - write_tool.write("test.txt", "should not change disk") - - assert file.read_text(encoding="utf-8") == "original" +import time +from pathlib import Path + +import pytest + +from src.core.database_manager import DatabaseManager +from src.workspace.workspace import Workspace + + +@pytest.fixture(autouse=True) +def reset_singletons(): + Workspace._instance = None + DatabaseManager.reset_instances() + yield + Workspace._instance = None + DatabaseManager.reset_instances() + + +@pytest.fixture +def workspace(tmp_path: Path) -> Workspace: + ws = Workspace(str(tmp_path)) + ws._current_session_id = ws.db.create_session(name="test_session") + return ws + + +@pytest.fixture +def write_tool(workspace: Workspace): + from src.workspace.tools.write_tool import WriteTool + + return WriteTool(workspace) + + +@pytest.fixture +def read_tool(workspace: Workspace): + from src.workspace.tools.read_tool import ReadTool + + return ReadTool(workspace) + + +class TestWritePreview: + def test_write_returns_preview(self, write_tool, tmp_path: Path): + result = write_tool.write("new_file.txt", "hello") + assert "Write Preview" in result.data + assert "Snapshot ID:" in result.data + + def test_write_does_not_write_to_disk(self, write_tool, tmp_path: Path): + write_tool.write("new_file.txt", "hello") + assert not (tmp_path / "new_file.txt").is_file() + + def test_write_creates_snapshot(self, write_tool, tmp_path: Path): + write_tool.write("new_file.txt", "hello") + + rows = write_tool.workspace.db.fetchall("SELECT old_hash, new_hash, audit_status, pending_content FROM file_snapshots") + assert len(rows) == 1 + assert rows[0][0] is None # old_hash for new file + assert rows[0][2] == "PENDING_AUDIT" + assert rows[0][3] == "hello" + + def test_write_preview_contains_diff(self, write_tool, tmp_path: Path): + write_tool.write("new_file.txt", "line1\nline2\nline3") + rows = write_tool.workspace.db.fetchall("SELECT diff_content FROM file_snapshots") + assert len(rows) == 1 + assert "+line1" in rows[0][0] + + +class TestWriteAfterRead: + def test_write_after_read_contains_old_hash(self, write_tool, read_tool, tmp_path: Path): + file = tmp_path / "test.txt" + file.write_text("original", encoding="utf-8") + + read_tool.read("test.txt") + write_tool.write("test.txt", "updated") + + rows = write_tool.workspace.db.fetchall("SELECT old_hash, new_hash FROM file_snapshots") + assert len(rows) == 1 + assert rows[0][0] is not None + assert rows[0][1] != rows[0][0] + + def test_write_after_read_shows_diff(self, write_tool, read_tool, tmp_path: Path): + file = tmp_path / "test.txt" + file.write_text("line1\nline2\nline3", encoding="utf-8") + + read_tool.read("test.txt") + result = write_tool.write("test.txt", "line1\nmodified\nline3") + + assert "Write Preview" in result.data + rows = write_tool.workspace.db.fetchall("SELECT diff_content FROM file_snapshots") + assert len(rows) == 1 + assert "-line2" in rows[0][0] + assert "+modified" in rows[0][0] + + +class TestWriteModifiedExternally: + def test_write_modified_externally_fails(self, write_tool, read_tool, tmp_path: Path): + file = tmp_path / "test.txt" + file.write_text("original", encoding="utf-8") + + read_tool.read("test.txt") + + file.write_text("modified externally", encoding="utf-8") + time.sleep(0.1) + + new_mtime = file.stat().st_mtime + read_record = read_tool.workspace.db.get_file_read_record(read_tool.workspace._current_session_id, "test.txt") + stored_mtime = read_record[3] if read_record else None + + if stored_mtime and abs(new_mtime - stored_mtime) < 0.001: + file.write_text("modified externally again", encoding="utf-8") + time.sleep(0.1) + + result = write_tool.write("test.txt", "should fail") + assert result.success is False + assert "FILE_MODIFIED_EXTERNALLY" in result.error + + def test_write_no_prior_read_succeeds(self, write_tool, tmp_path: Path): + file = tmp_path / "test.txt" + file.write_text("existing content", encoding="utf-8") + + result = write_tool.write("test.txt", "new content") + assert "Write Preview" in result.data + + +class TestWriteSnapshotPendingContent: + def test_pending_content_stored(self, write_tool, tmp_path: Path): + write_tool.write("new_file.txt", "pending content here") + + rows = write_tool.workspace.db.fetchall("SELECT pending_content FROM file_snapshots") + assert len(rows) == 1 + assert rows[0][0] == "pending content here" + + def test_pending_content_for_existing_file(self, write_tool, read_tool, tmp_path: Path): + file = tmp_path / "test.txt" + file.write_text("original", encoding="utf-8") + + read_tool.read("test.txt") + write_tool.write("test.txt", "updated content") + + snap = write_tool.workspace.db.fetchone("SELECT pending_content, audit_status FROM file_snapshots WHERE file_path = 'test.txt'") + assert snap is not None + assert snap[0] == "updated content" + assert snap[1] == "PENDING_AUDIT" + + def test_disk_file_unchanged_after_write(self, write_tool, read_tool, tmp_path: Path): + file = tmp_path / "test.txt" + file.write_text("original", encoding="utf-8") + + read_tool.read("test.txt") + write_tool.write("test.txt", "should not change disk") + + assert file.read_text(encoding="utf-8") == "original"