Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions config/check_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3
"""Check if Test-Agent update is available. Called by Claude Code Stop hook.

Reads .version from project root, fetches remote VERSION via HTTP.
Prints notification only when newer version available.
Rate-limited: checks at most once per 24h via .version_last_check timestamp.
"""
import os
import time
import urllib.request
import urllib.error

VERSION_URL = "https://raw.githubusercontent.com/Wool-xing/Test-Agent/main/VERSION"
CHECK_INTERVAL = 86400 # 24 hours


def main():
project_root = os.getcwd()
version_file = os.path.join(project_root, ".version")

if not os.path.isfile(version_file):
return # Not a Test-Agent project, skip silently

# Rate limit: check at most once per CHECK_INTERVAL
last_check_file = os.path.join(project_root, ".version_last_check")
now = time.time()
if os.path.isfile(last_check_file):
try:
with open(last_check_file, encoding="utf-8") as f:
last = float(f.read().strip())
if now - last < CHECK_INTERVAL:
return
except (ValueError, OSError):
pass

with open(version_file, encoding="utf-8") as f:
local = f.read().strip()

try:
req = urllib.request.Request(VERSION_URL)
req.add_header("User-Agent", "Test-Agent-version-check/1.0")
resp = urllib.request.urlopen(req, timeout=10)
remote = resp.read().decode().strip()
except (urllib.error.URLError, OSError, ValueError):
return # Network error, skip silently

# Write last check timestamp
try:
with open(last_check_file, "w", encoding="utf-8") as f:
f.write(str(now))
except OSError:
pass

if local != remote:
print(
f"\n📦 Test-Agent {remote} 可用(当前 {local})。"
f"运行 python install.py --update 更新。\n"
)


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions config/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"hooks": {
"Stop": [
{
"command": "python check_version.py",
"description": "Check for Test-Agent updates (24h cooldown)"
}
]
}
}
189 changes: 184 additions & 5 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
Git / Node.js — 脚本自动检测并安装(winget / brew / apt)

用法:
python install.py /path/to/your-test-project # 指定目录
python install.py # 默认 ./Test-Agent
python install.py /path/to/your-test-project # 完整安装到指定目录
python install.py # 完整安装,默认 ./Test-Agent
python install.py --update # 轻量更新当前目录
python install.py /path/to/project --update # 轻量更新指定目录

安全提示:不要 pipe-to-python。先下载再审查后执行:
curl -fsSL -o install.py https://raw.githubusercontent.com/Wool-xing/Test-Agent/main/install.py
Expand All @@ -32,9 +34,36 @@
import tempfile
import glob
import platform
import argparse


PROJECT_ROOT = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(), "Test-Agent")
def _parse_args():
"""解析命令行参数。"""
parser = argparse.ArgumentParser(
description="Test-Agent 工作流一键部署脚本",
epilog="安全提示:不要 pipe-to-python。先下载再审查后执行。",
)
parser.add_argument(
"path", nargs="?", default=None,
help="项目目录路径(默认: ./Test-Agent;--update 模式下默认当前目录)",
)
parser.add_argument(
"--update", action="store_true",
help="轻量更新:仅同步新文件 + 依赖,保留用户数据和 .venv",
)
_args = parser.parse_args()
return _args


_ARGS = _parse_args()
UPDATE_MODE = _ARGS.update

if _ARGS.path:
PROJECT_ROOT = _ARGS.path
elif UPDATE_MODE:
PROJECT_ROOT = os.getcwd()
else:
PROJECT_ROOT = os.path.join(os.getcwd(), "Test-Agent")
REPO_URL = os.environ.get("TEST_AGENT_REPO_URL", "https://github.com/Wool-xing/Test-Agent.git")
REPO_BRANCH = os.environ.get("TEST_AGENT_REPO_BRANCH", "main")

Expand All @@ -49,8 +78,9 @@


def banner():
mode = "轻量更新" if UPDATE_MODE else "一键部署"
print("=" * 50)
print(" Test-Agent 工作流一键部署 V1.42.0")
print(f" Test-Agent 工作流{mode}")
print(f" 仓库: {REPO_URL} ({REPO_BRANCH})")
print(f" 项目目录: {PROJECT_ROOT}")
print("=" * 50)
Expand Down Expand Up @@ -318,18 +348,28 @@ def copy_config(template_dir, project_root):
"""拷贝配置文件。"""
print("→ 拷贝配置文件...")
config_dir = os.path.join(template_dir, "config")
files = ["conftest.py", "pytest.ini", ".mcp.json", "requirements.txt"]
files = ["conftest.py", "pytest.ini", ".mcp.json", "requirements.txt", "check_version.py"]
for f in files:
src = os.path.join(config_dir, f)
if os.path.isfile(src):
shutil.copy2(src, project_root)

# .env — 仅在不存在时创建
env_dst = os.path.join(project_root, ".env")
if not os.path.isfile(env_dst):
env_src = os.path.join(config_dir, ".env.example")
if os.path.isfile(env_src):
shutil.copy2(env_src, env_dst)

# .claude/settings.json — 部署版本检查 hook,仅在不存在时创建
claude_dir = os.path.join(project_root, ".claude")
settings_dst = os.path.join(claude_dir, "settings.json")
if not os.path.isfile(settings_dst):
settings_src = os.path.join(config_dir, "settings.json")
if os.path.isfile(settings_src):
os.makedirs(claude_dir, exist_ok=True)
shutil.copy2(settings_src, settings_dst)


def copy_utils(template_dir, project_root):
"""拷贝 utils 目录下所有 .py 文件。"""
Expand Down Expand Up @@ -476,9 +516,143 @@ def _rmtree_onerror(func, path, _exc_info):
func(path)


def _read_template_version(template_dir):
"""读取模板 VERSION 文件。"""
vf = os.path.join(template_dir, "VERSION")
if os.path.isfile(vf):
with open(vf, encoding="utf-8") as f:
return f.read().strip()
return None


def _write_local_version(project_root, version):
"""写入 .version 文件供后续更新检测。"""
vf = os.path.join(project_root, ".version")
with open(vf, "w", encoding="utf-8") as f:
f.write(version + "\n")


def _update_deps(project_root):
"""使用已有 venv 安装/更新 Python 依赖(不重建 venv)。"""
if IS_WINDOWS:
python_exe = os.path.join(project_root, ".venv", "Scripts", "python.exe")
pip_cmd = os.path.join(project_root, ".venv", "Scripts", "pip")
else:
python_exe = os.path.join(project_root, ".venv", "bin", "python")
pip_cmd = os.path.join(project_root, ".venv", "bin", "pip")

if not os.path.isfile(python_exe):
print("⚠️ 未找到虚拟环境,跳过依赖更新")
return

subprocess.run([python_exe, "-m", "pip", "install", "--upgrade", "pip", "-q"], check=True)

# CN 镜像检测
pip_env = os.environ.copy()
if os.environ.get("TEST_AGENT_NO_CN_MIRROR", "0") != "1":
if any([
os.environ.get("LANG", "").startswith(("zh", "CN", "GB")),
timezone_is_cn(),
]):
pip_env["PIP_INDEX_URL"] = "https://pypi.tuna.tsinghua.edu.cn/simple"
pip_env["PIP_TRUSTED_HOST"] = "pypi.tuna.tsinghua.edu.cn"

req_file = os.path.join(project_root, "requirements.txt")
print("→ 更新 Python 依赖...")
if IS_WINDOWS:
with open(req_file, encoding="utf-8") as f:
lines = f.readlines()
filtered = [l for l in lines if not l.startswith(("scikit-image", "scikit-learn", "opencv-python", "opencv-contrib-python"))]
fd, tmp = tempfile.mkstemp(suffix=".txt", prefix="tagent-update-req-")
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.writelines(filtered)
subprocess.run([pip_cmd, "install", "-r", tmp], env=pip_env, check=True)
os.unlink(tmp)
else:
subprocess.run([pip_cmd, "install", "-r", req_file], env=pip_env, check=True)


def do_update():
"""轻量更新:克隆最新模板 → 比较版本 → 拷贝文件 → 更新依赖 → 保留用户数据。"""
version_file = os.path.join(PROJECT_ROOT, ".version")
if not os.path.isfile(version_file):
print(f"❌ 未找到 .version 文件")
print(f" 当前目录: {os.getcwd()}")
print(f" 查找路径: {version_file}")
print(f" 请先执行完整安装:python install.py <目录>")
print(f" 或切换到项目目录后执行:cd <项目目录> && python install.py --update")
sys.exit(1)

with open(version_file, encoding="utf-8") as f:
local_version = f.read().strip()

print(f"→ 当前版本: {local_version}")

template_dir_parent = tempfile.mkdtemp()
template_dir = os.path.join(template_dir_parent, "Test-Agent工作流搭建")

try:
local_src = os.environ.get("TEST_AGENT_LOCAL_SRC")
if local_src:
print(f"→ [dev mode] 复制本地源代码: {local_src} → {template_dir}")
shutil.copytree(local_src, template_dir)
else:
print("→ 检查更新...")
subprocess.run(
["git", "clone", "--depth", "1", "--branch", REPO_BRANCH, REPO_URL, template_dir],
check=True,
)

remote_version = _read_template_version(template_dir)
if remote_version is None:
print("❌ 无法读取远程版本信息")
sys.exit(1)

if local_version == remote_version:
print(f"✓ 已是最新版本 ({local_version})")
return

print(f"→ 新版本可用: {local_version} → {remote_version}")
print("→ 开始轻量更新(保留用户数据和 .venv)...")

# 备份用户数据
backed = backup_user_data(PROJECT_ROOT)

# 拷贝新文件(跳过 create_dirs / setup_venv / claude code 安装)
copy_agents(template_dir, PROJECT_ROOT)
copy_skills(template_dir, PROJECT_ROOT)
copy_config(template_dir, PROJECT_ROOT)
copy_utils(template_dir, PROJECT_ROOT)
copy_ci(template_dir, PROJECT_ROOT)
copy_top_level_docs(template_dir, PROJECT_ROOT)

# 恢复用户数据
restore_user_data(PROJECT_ROOT, backed)

# 更新依赖
_update_deps(PROJECT_ROOT)

# 写回新版本号
_write_local_version(PROJECT_ROOT, remote_version)

print("=" * 50)
print(f" ✅ 已更新到 {remote_version}")
print("=" * 50)

finally:
if os.path.isdir(template_dir_parent):
shutil.rmtree(template_dir_parent, onerror=_rmtree_onerror)
# cleanup backup tmp if any leftover (restore_user_data usually handles this)
# handled in finally block of main, but do_update has its own finally


def main():
banner()

if UPDATE_MODE:
do_update()
return

# 1. 检查 + 自动安装前置工具
ensure_prerequisites()
python_bin = find_python()
Expand Down Expand Up @@ -525,6 +699,11 @@ def main():
# 8. 恢复用户数据
restore_user_data(PROJECT_ROOT, backed)

# 9. 写入 .version 供后续更新检测
version = _read_template_version(template_dir)
if version:
_write_local_version(PROJECT_ROOT, version)

finish(PROJECT_ROOT)

finally:
Expand Down
Loading