From c24127ade5e205ba7bef042f627222edc8d10550 Mon Sep 17 00:00:00 2001 From: townwish Date: Wed, 11 Mar 2026 16:38:11 +0800 Subject: [PATCH 1/2] =?UTF-8?q?refactor:=20=E6=95=B4=E7=90=86config?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/executor.yaml | 9 ---- configs/loader.go | 12 ++--- configs/mscli.example.yaml | 99 ++++++++++++++++++++++++++++++++++++++ configs/mscli.yaml | 24 --------- configs/skills.yaml | 6 --- 5 files changed, 104 insertions(+), 46 deletions(-) delete mode 100644 configs/executor.yaml create mode 100644 configs/mscli.example.yaml delete mode 100644 configs/mscli.yaml delete mode 100644 configs/skills.yaml diff --git a/configs/executor.yaml b/configs/executor.yaml deleted file mode 100644 index 8de37a4..0000000 --- a/configs/executor.yaml +++ /dev/null @@ -1,9 +0,0 @@ -execution: - mode: local - timeout_sec: 1800 - max_concurrency: 2 -docker: - image: ubuntu:22.04 - cpu: "2" - memory: "4g" - network: none diff --git a/configs/loader.go b/configs/loader.go index 59f1e21..67af919 100644 --- a/configs/loader.go +++ b/configs/loader.go @@ -69,21 +69,19 @@ func LoadWithEnv(path string) (*Config, error) { } // FindConfigFile searches for config file in standard locations. +// Priority: MSCLI_CONFIG > ./.mscli/config.yaml > ~/.config/mscli/config.yaml > ~/.mscli/config.yaml func FindConfigFile() string { // Check environment variable if path := os.Getenv("MSCLI_CONFIG"); path != "" { return path } - // Check current directory - if _, err := os.Stat("mscli.yaml"); err == nil { - return "mscli.yaml" - } - if _, err := os.Stat("configs/mscli.yaml"); err == nil { - return "configs/mscli.yaml" + // Check project-level config + if _, err := os.Stat(".mscli/config.yaml"); err == nil { + return ".mscli/config.yaml" } - // Check config directories + // Check user-level config home, err := os.UserHomeDir() if err == nil { paths := []string{ diff --git a/configs/mscli.example.yaml b/configs/mscli.example.yaml new file mode 100644 index 0000000..deae48c --- /dev/null +++ b/configs/mscli.example.yaml @@ -0,0 +1,99 @@ +# ms-cli 配置示例文件 +# +# 使用方法: +# 1. 复制此文件到 .mscli/config.yaml(推荐,已被 .gitignore 忽略,不会提交到 git) +# cp configs/mscli.example.yaml .mscli/config.yaml +# +# 2. 或者复制到用户级配置目录(全局生效) +# cp configs/mscli.example.yaml ~/.mscli/config.yaml +# # 或遵循 XDG 规范 +# mkdir -p ~/.config/mscli +# cp configs/mscli.example.yaml ~/.config/mscli/config.yaml +# +# 3. 根据你的实际情况修改配置项 +# +# 配置加载优先级(从高到低): +# 1. 命令行参数 (--url, --model, --api-key) +# 2. 环境变量 (MSCLI_* / OPENAI_*) +# 3. 配置文件(依次搜索直至找到第一个存在的文件): +# - MSCLI_CONFIG 环境变量指定的文件 +# - ./.mscli/config.yaml(项目级) +# - ~/.config/mscli/config.yaml(用户级) +# - ~/.mscli/config.yaml(用户级备选) +# 4. 内置默认值 +# +# 注意:API key 建议通过环境变量设置,避免提交到版本控制 +# export MSCLI_API_KEY="your-api-key" + +model: + # OpenAI-compatible API base URL + # 环境变量覆盖:MSCLI_BASE_URL(优先)或 OPENAI_BASE_URL + url: https://api.openai.com/v1 + + # 模型名称 + # 环境变量覆盖:MSCLI_MODEL(优先)或 OPENAI_MODEL + model: gpt-4o-mini + + # API 密钥 + # 强烈建议通过环境变量设置:MSCLI_API_KEY 或 OPENAI_API_KEY + # 如果在此处设置,请确保 .mscli/ 目录已被 .gitignore 忽略 + key: "" + + # 温度参数 (0-2),控制输出的随机性 + # 环境变量覆盖:MSCLI_TEMPERATURE + temperature: 0.7 + + # 最大生成 token 数 + # 环境变量覆盖:MSCLI_MAX_TOKENS + max_tokens: 4096 + + # 请求超时时间(秒) + # 环境变量覆盖:MSCLI_TIMEOUT + timeout_sec: 180 + +budget: + # 单次会话最大 token 数 + max_tokens: 32768 + + # 单次会话最大成本(美元) + max_cost_usd: 10 + +ui: + # 是否启用 TUI 界面 + enabled: true + + # 是否显示 token 使用条 + show_token_bar: true + +permissions: + # 是否跳过权限请求(自动允许所有操作,谨慎使用) + skip_requests: false + + # 默认权限级别:ask(询问)/ allow(允许)/ deny(拒绝) + default_level: ask + + # 允许使用的工具列表(为空表示允许所有) + allowed_tools: [] + +context: + # 上下文最大 token 数 + max_tokens: 24000 + + # 预留 token 数(用于生成回复) + reserve_tokens: 4000 + + # 上下文压缩阈值(当使用量超过此比例时触发压缩) + compaction_threshold: 0.85 + +memory: + # 是否启用记忆系统 + enabled: true + + # 最大记忆条目数 + max_items: 200 + + # 最大记忆大小(字节) + max_bytes: 2097152 + + # 记忆 TTL(小时) + ttl_hours: 168 diff --git a/configs/mscli.yaml b/configs/mscli.yaml deleted file mode 100644 index dea595f..0000000 --- a/configs/mscli.yaml +++ /dev/null @@ -1,24 +0,0 @@ -model: - # OpenAI-compatible API base URL - # Environment variables (higher priority): MSCLI_BASE_URL (fallback: OPENAI_BASE_URL) - url: https://api.openai.com/v1 - # Model name - # Environment variables (higher priority): MSCLI_MODEL (fallback: OPENAI_MODEL) - model: gpt-4o-mini - # API key (recommended to use env: MSCLI_API_KEY / OPENAI_API_KEY) - key: "" -budget: - max_tokens: 32768 - max_cost_usd: 10 -ui: - enabled: true -permissions: - skip_requests: false - allowed_tools: [] -context: - max_tokens: 24000 - compaction_threshold: 0.85 -memory: - max_items: 200 - max_bytes: 2097152 - ttl_hours: 168 diff --git a/configs/skills.yaml b/configs/skills.yaml deleted file mode 100644 index 2b85e96..0000000 --- a/configs/skills.yaml +++ /dev/null @@ -1,6 +0,0 @@ -skills: - repo: https://github.com/vigo/mindspore-skills.git - revision: main - cache_dir: .cache/skills - workflows: - - add-algo-feature From 9dadcbe72c8136b94eca2605f2e0c40efc6bdf26 Mon Sep 17 00:00:00 2001 From: townwish Date: Wed, 11 Mar 2026 17:43:47 +0800 Subject: [PATCH 2/2] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81ssh=E6=93=8D?= =?UTF-8?q?=E4=BD=9C=E8=BF=9C=E7=A8=8B=E6=9C=BA=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/types.go | 16 ++ go.mod | 7 +- go.sum | 10 + internal/app/wire.go | 21 +- runtime/ssh/pool.go | 304 +++++++++++++++++++++++++++++ runtime/ssh/pool_test.go | 246 ++++++++++++++++++++++++ runtime/ssh/result.go | 28 +++ runtime/ssh/session.go | 252 ++++++++++++++++++++++++ runtime/ssh/sftp.go | 308 +++++++++++++++++++++++++++++ tools/remote/fs.go | 384 +++++++++++++++++++++++++++++++++++++ tools/remote/shell.go | 148 ++++++++++++++ tools/remote/shell_test.go | 326 +++++++++++++++++++++++++++++++ 12 files changed, 2046 insertions(+), 4 deletions(-) create mode 100644 runtime/ssh/pool.go create mode 100644 runtime/ssh/pool_test.go create mode 100644 runtime/ssh/result.go create mode 100644 runtime/ssh/session.go create mode 100644 runtime/ssh/sftp.go create mode 100644 tools/remote/fs.go create mode 100644 tools/remote/shell.go create mode 100644 tools/remote/shell_test.go diff --git a/configs/types.go b/configs/types.go index 217d7a3..52727eb 100644 --- a/configs/types.go +++ b/configs/types.go @@ -15,6 +15,7 @@ type Config struct { Memory MemoryConfig `yaml:"memory"` Skills SkillsConfig `yaml:"skills"` Execution ExecutionConfig `yaml:"execution"` + SSH SSHConfig `yaml:"ssh,omitempty"` } // ModelConfig holds the LLM model configuration. @@ -94,6 +95,21 @@ type DockerConfig struct { Env map[string]string `yaml:"env,omitempty"` } +// SSHConfig holds the SSH remote execution configuration. +type SSHConfig struct { + Hosts map[string]HostConfig `yaml:"hosts,omitempty"` // 预配置的主机别名 + DefaultTimeout int `yaml:"default_timeout"` // 默认超时秒数 +} + +// HostConfig holds the configuration for a single SSH host. +type HostConfig struct { + Address string `yaml:"address"` // 主机地址(IP 或域名) + User string `yaml:"user"` // 用户名 + KeyPath string `yaml:"key_path,omitempty"` // SSH 私钥路径(优先) + Password string `yaml:"password,omitempty"` // SSH 密码(备选) + Port int `yaml:"port,omitempty"` // 端口(默认 22) +} + // DefaultConfig returns a configuration with default values. func DefaultConfig() *Config { return &Config{ diff --git a/go.mod b/go.mod index 730e5e5..b767999 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.5.0 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/kr/fs v0.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -30,8 +31,10 @@ require ( github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect + github.com/pkg/sftp v1.13.10 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - golang.org/x/sys v0.38.0 // indirect - golang.org/x/text v0.3.8 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect ) diff --git a/go.sum b/go.sum index 70d1474..f241d52 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -40,18 +42,26 @@ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELU github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU= +github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/app/wire.go b/internal/app/wire.go index 0b9da8f..0c26095 100644 --- a/internal/app/wire.go +++ b/internal/app/wire.go @@ -20,8 +20,10 @@ import ( itrain "github.com/vigo999/ms-cli/internal/train" "github.com/vigo999/ms-cli/permission" rshell "github.com/vigo999/ms-cli/runtime/shell" + rssh "github.com/vigo999/ms-cli/runtime/ssh" "github.com/vigo999/ms-cli/tools" "github.com/vigo999/ms-cli/tools/fs" + "github.com/vigo999/ms-cli/tools/remote" "github.com/vigo999/ms-cli/tools/shell" "github.com/vigo999/ms-cli/trace" "github.com/vigo999/ms-cli/ui/model" @@ -48,6 +50,7 @@ type Application struct { permService permission.PermissionService stateManager *configs.StateManager traceWriter trace.Writer + sshPool *rssh.Pool // Train mode state trainMode bool @@ -130,7 +133,10 @@ func Wire(cfg BootstrapConfig) (*Application, error) { } } - toolRegistry := initTools(config, workDir) + // Initialize SSH pool for remote execution + sshPool := rssh.NewPool(config.SSH) + + toolRegistry := initTools(config, workDir, sshPool) ctxManager := agentctx.NewManager(agentctx.ManagerConfig{ MaxTokens: config.Context.MaxTokens, @@ -184,6 +190,7 @@ func Wire(cfg BootstrapConfig) (*Application, error) { permService: permService, stateManager: stateManager, traceWriter: traceWriter, + sshPool: sshPool, llmReady: llmReady, }, nil } @@ -287,7 +294,7 @@ func initProvider(cfg configs.ModelConfig) (llm.Provider, error) { return client, nil } -func initTools(cfg *configs.Config, workDir string) *tools.Registry { +func initTools(cfg *configs.Config, workDir string, sshPool *rssh.Pool) *tools.Registry { registry := tools.NewRegistry() registry.MustRegister(fs.NewReadTool(workDir)) @@ -305,5 +312,15 @@ func initTools(cfg *configs.Config, workDir string) *tools.Registry { }) registry.MustRegister(shell.NewShellTool(shellRunner)) + // Register remote SSH tools + if sshPool != nil { + registry.MustRegister(remote.NewShellTool(sshPool)) + registry.MustRegister(remote.NewReadTool(sshPool)) + registry.MustRegister(remote.NewWriteTool(sshPool)) + registry.MustRegister(remote.NewEditTool(sshPool)) + registry.MustRegister(remote.NewGlobTool(sshPool)) + registry.MustRegister(remote.NewGrepTool(sshPool)) + } + return registry } diff --git a/runtime/ssh/pool.go b/runtime/ssh/pool.go new file mode 100644 index 0000000..5878859 --- /dev/null +++ b/runtime/ssh/pool.go @@ -0,0 +1,304 @@ +// Package ssh provides SSH remote execution runtime with stateful session management. +package ssh + +import ( + "fmt" + "net" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/vigo999/ms-cli/configs" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +// ConnectOptions contains dynamic connection parameters from tool calls. +type ConnectOptions struct { + Host string // 目标地址(或别名) + User string // 用户名(可选,覆盖配置) + Password string // 密码(可选,覆盖配置) + KeyPath string // 私钥路径(可选,覆盖配置) + Port int // 端口(可选,默认22) + Timeout time.Duration // 超时(可选) +} + +// Pool manages SSH sessions with connection reuse. +type Pool struct { + mu sync.RWMutex + sessions map[string]*Session // key: normalized host+user + config configs.SSHConfig + agentConn net.Conn // SSH agent connection (kept open for auth) +} + +// NewPool creates a new SSH connection pool. +func NewPool(cfg configs.SSHConfig) *Pool { + return &Pool{ + sessions: make(map[string]*Session), + config: cfg, + } +} + +// Close closes the pool and all its sessions. +func (p *Pool) Close() error { + if err := p.CloseAll(); err != nil { + return err + } + if p.agentConn != nil { + p.agentConn.Close() + p.agentConn = nil + } + return nil +} + +// Get gets or creates a session with the given options. +// Parameter merge priority: opts > pre-configured HostConfig > defaults +func (p *Pool) Get(opts ConnectOptions) (*Session, error) { + // Resolve host alias to actual config + resolved := p.resolveHost(opts.Host) + + // Merge options (dynamic opts override pre-configured values) + merged := p.mergeOptions(opts, resolved) + + // Validate required fields + if merged.Host == "" { + return nil, fmt.Errorf("host is required") + } + if merged.User == "" { + return nil, fmt.Errorf("user is required for host %s", merged.Host) + } + if merged.Port == 0 { + merged.Port = 22 + } + if merged.Timeout == 0 { + merged.Timeout = time.Duration(p.config.DefaultTimeout) * time.Second + if merged.Timeout == 0 { + merged.Timeout = 60 * time.Second + } + } + + // Generate session key (normalized host + user) + key := fmt.Sprintf("%s@%s:%d", merged.User, merged.Host, merged.Port) + + // Check for existing session + p.mu.RLock() + if sess, ok := p.sessions[key]; ok && sess.IsAlive() { + p.mu.RUnlock() + return sess, nil + } + p.mu.RUnlock() + + // Create new session + p.mu.Lock() + defer p.mu.Unlock() + + // Double-check after acquiring write lock + if sess, ok := p.sessions[key]; ok && sess.IsAlive() { + return sess, nil + } + + sess, err := p.createSession(merged) + if err != nil { + return nil, fmt.Errorf("failed to create SSH session for %s: %w", key, err) + } + + p.sessions[key] = sess + return sess, nil +} + +// CloseAll closes all sessions in the pool. +func (p *Pool) CloseAll() error { + p.mu.Lock() + defer p.mu.Unlock() + + var errs []error + for key, sess := range p.sessions { + if err := sess.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close session %s: %w", key, err)) + } + delete(p.sessions, key) + } + + if len(errs) > 0 { + return fmt.Errorf("errors closing sessions: %v", errs) + } + return nil +} + +// resolveHost resolves a host alias to its configuration. +// Also supports reverse lookup by IP address. +func (p *Pool) resolveHost(hostOrAlias string) configs.HostConfig { + // Check if it's an alias in the config + if cfg, ok := p.config.Hosts[hostOrAlias]; ok { + return cfg + } + + // Try reverse lookup: check if any configured host has this address + for _, cfg := range p.config.Hosts { + if cfg.Address == hostOrAlias { + return cfg + } + } + + // Treat as direct host address + return configs.HostConfig{ + Address: hostOrAlias, + } +} + +// mergeOptions merges dynamic options with pre-configured values. +// Priority: opts > cfg > defaults +func (p *Pool) mergeOptions(opts ConnectOptions, cfg configs.HostConfig) ConnectOptions { + result := ConnectOptions{ + Host: cfg.Address, + User: cfg.User, + Password: cfg.Password, + KeyPath: cfg.KeyPath, + Port: cfg.Port, + } + + // Override with dynamic options if provided + if opts.Host != "" { + result.Host = opts.Host // Keep original if it was an alias (already resolved) + } + if opts.User != "" { + result.User = opts.User + } + if opts.Password != "" { + result.Password = opts.Password + } + if opts.KeyPath != "" { + result.KeyPath = opts.KeyPath + } + if opts.Port != 0 { + result.Port = opts.Port + } + if opts.Timeout != 0 { + result.Timeout = opts.Timeout + } + + return result +} + +// createSession creates a new SSH session with the given options. +func (p *Pool) createSession(opts ConnectOptions) (*Session, error) { + sshConfig := &ssh.ClientConfig{ + User: opts.User, + Timeout: opts.Timeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: Support known_hosts + } + + // Setup authentication methods + authMethods := []ssh.AuthMethod{} + + // 1. Try key file if explicitly specified + keyPath := expandHome(opts.KeyPath) + if keyPath != "" { + if keyAuth, err := p.getKeyAuth(keyPath); err == nil { + authMethods = append(authMethods, keyAuth) + } else { + // Log key error for debugging (but don't fail yet, try other methods) + fmt.Fprintf(os.Stderr, "ssh: failed to load key %s: %v\n", keyPath, err) + } + } + + // 2. Try SSH agent + if agentAuth, ok := p.getAgentAuth(); ok { + authMethods = append(authMethods, agentAuth) + } + + // 3. Try default key paths (only if no explicit key and no agent) + if keyPath == "" && len(authMethods) == 0 { + for _, defaultPath := range []string{ + filepath.Join(os.Getenv("HOME"), ".ssh", "id_rsa"), + filepath.Join(os.Getenv("HOME"), ".ssh", "id_ed25519"), + filepath.Join(os.Getenv("HOME"), ".ssh", "id_ecdsa"), + } { + if _, err := os.Stat(defaultPath); err == nil { + if keyAuth, err := p.getKeyAuth(defaultPath); err == nil { + authMethods = append(authMethods, keyAuth) + break + } + } + } + } + + // 5. Try password + if opts.Password != "" { + authMethods = append(authMethods, ssh.Password(opts.Password)) + } + + if len(authMethods) == 0 { + return nil, fmt.Errorf("no authentication method available") + } + + sshConfig.Auth = authMethods + + // Connect + addr := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + client, err := ssh.Dial("tcp", addr, sshConfig) + if err != nil { + return nil, fmt.Errorf("failed to dial SSH: %w", err) + } + + return &Session{ + client: client, + host: opts.Host, + user: opts.User, + port: opts.Port, + env: make(map[string]string), + workDir: "", + }, nil +} + +// getAgentAuth returns SSH agent authentication if available. +// The agent connection is kept open in the pool and closed when the pool is closed. +func (p *Pool) getAgentAuth() (ssh.AuthMethod, bool) { + // Reuse existing agent connection if available + if p.agentConn != nil { + return ssh.PublicKeysCallback(agent.NewClient(p.agentConn).Signers), true + } + + socket := os.Getenv("SSH_AUTH_SOCK") + if socket == "" { + return nil, false + } + + conn, err := net.Dial("unix", socket) + if err != nil { + return nil, false + } + + // Store connection in pool for reuse (will be closed when pool is closed) + p.agentConn = conn + return ssh.PublicKeysCallback(agent.NewClient(p.agentConn).Signers), true +} + +// getKeyAuth returns public key authentication from the given key file. +func (p *Pool) getKeyAuth(keyPath string) (ssh.AuthMethod, error) { + key, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read key file: %w", err) + } + + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + // Try with empty passphrase (for encrypted keys, user should use SSH agent) + return nil, fmt.Errorf("failed to parse key (encrypted keys require SSH agent): %w", err) + } + + return ssh.PublicKeys(signer), nil +} + +// expandHome expands the ~ in the given path to the user's home directory. +func expandHome(path string) string { + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err == nil { + return filepath.Join(home, path[2:]) + } + } + return path +} diff --git a/runtime/ssh/pool_test.go b/runtime/ssh/pool_test.go new file mode 100644 index 0000000..730e99a --- /dev/null +++ b/runtime/ssh/pool_test.go @@ -0,0 +1,246 @@ +// Package ssh provides SSH remote execution runtime with stateful session management. +package ssh + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/vigo999/ms-cli/configs" +) + +// TestCreateSession_WithConfigHosts 使用 .mscli/config.yaml 中配置的主机测试 createSession 功能 +// 这是一个集成测试,需要实际的 SSH 服务器和密钥 +func TestCreateSession_WithConfigHosts(t *testing.T) { + // 诊断信息 + t.Logf("SSH_AUTH_SOCK: %s", os.Getenv("SSH_AUTH_SOCK")) + t.Logf("HOME: %s", os.Getenv("HOME")) + + // 加载配置文件 + cfg, err := loadTestConfig() + if err != nil { + t.Skipf("无法加载配置文件: %v", err) + } + + if len(cfg.SSH.Hosts) == 0 { + t.Skip("配置文件中没有 SSH 主机配置") + } + + pool := NewPool(cfg.SSH) + defer pool.Close() + + // 测试每个配置的主机 + for alias, hostCfg := range cfg.SSH.Hosts { + t.Run(alias, func(t *testing.T) { + // 使用实际 IP 地址而非别名,因为 mergeOptions 会保留原始 Host 值 + opts := ConnectOptions{ + Host: hostCfg.Address, + User: hostCfg.User, + KeyPath: expandHome(hostCfg.KeyPath), + Port: hostCfg.Port, + Timeout: 10 * time.Second, + } + + session, err := pool.Get(opts) + if err != nil { + t.Fatalf("无法连接到主机 %s (%s): %v", alias, hostCfg.Address, err) + } + + if session == nil { + t.Fatal("session 不应为 nil") + } + + // 验证会话状态 + if !session.IsAlive() { + t.Error("新创建的会话应该是活跃的") + } + + t.Logf("✓ 成功连接到 %s (%s@%s:%d)", alias, hostCfg.User, hostCfg.Address, hostCfg.Port) + }) + } +} + +// TestCreateSession_ConnectionReuse 测试连接复用功能 +func TestCreateSession_ConnectionReuse(t *testing.T) { + cfg, err := loadTestConfig() + if err != nil { + t.Skipf("无法加载配置文件: %v", err) + } + + if len(cfg.SSH.Hosts) == 0 { + t.Skip("配置文件中没有 SSH 主机配置") + } + + pool := NewPool(cfg.SSH) + defer pool.Close() + + // 获取第一个配置的主机 + var alias string + var hostCfg configs.HostConfig + for a, h := range cfg.SSH.Hosts { + alias = a + hostCfg = h + break + } + + t.Logf("测试主机 %s: %s@%s:%d", alias, hostCfg.User, hostCfg.Address, hostCfg.Port) + + opts := ConnectOptions{ + Host: hostCfg.Address, + User: hostCfg.User, + KeyPath: expandHome(hostCfg.KeyPath), + Port: hostCfg.Port, + Timeout: 10 * time.Second, + } + + // 第一次连接 + session1, err := pool.Get(opts) + if err != nil { + t.Fatalf("无法连接到主机 %s (%s): %v", alias, hostCfg.Address, err) + } + + // 第二次连接(应该复用) + session2, err := pool.Get(opts) + if err != nil { + t.Fatalf("第二次连接失败: %v", err) + } + + // 验证是同一个会话 + if session1 != session2 { + t.Error("相同主机的连接应该被复用") + } + + t.Logf("✓ 连接复用成功: %s@%s:%d", hostCfg.User, hostCfg.Address, hostCfg.Port) +} + +// TestCreateSession_InvalidHost 测试无效主机连接失败 +func TestCreateSession_InvalidHost(t *testing.T) { + cfg := configs.SSHConfig{ + DefaultTimeout: 5, + } + + pool := NewPool(cfg) + defer pool.Close() + + opts := ConnectOptions{ + Host: "invalid.host.that.does.not.exist.example.com", + User: "test", + Port: 22, + Timeout: 2 * time.Second, + } + + _, err := pool.Get(opts) + if err == nil { + t.Error("无效主机应该返回错误") + } else { + t.Logf("✓ 无效主机正确返回错误: %v", err) + } +} + +// TestCreateSession_MissingAuth 测试缺少认证信息时失败 +func TestCreateSession_MissingAuth(t *testing.T) { + // 创建一个临时目录,确保没有 SSH 密钥 + tmpDir := t.TempDir() + + // 设置临时的 HOME 目录,避免读取用户默认的 SSH 密钥 + origHome := os.Getenv("HOME") + os.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + // 清除 SSH agent + origAgent := os.Getenv("SSH_AUTH_SOCK") + os.Unsetenv("SSH_AUTH_SOCK") + defer os.Setenv("SSH_AUTH_SOCK", origAgent) + + cfg := configs.SSHConfig{ + DefaultTimeout: 5, + } + + pool := NewPool(cfg) + defer pool.Close() + + opts := ConnectOptions{ + Host: "localhost", + User: "testuser", + Port: 22, + Timeout: 2 * time.Second, + } + + _, err := pool.Get(opts) + if err == nil { + t.Error("缺少认证信息应该返回错误") + } else { + t.Logf("✓ 缺少认证正确返回错误: %v", err) + } +} + +// TestCreateSession_DirectAddress 测试直接使用 IP 地址而非别名 +func TestCreateSession_DirectAddress(t *testing.T) { + cfg, err := loadTestConfig() + if err != nil { + t.Skipf("无法加载配置文件: %v", err) + } + + if len(cfg.SSH.Hosts) == 0 { + t.Skip("配置文件中没有 SSH 主机配置") + } + + // 获取第一个配置的主机,使用其地址直接连接 + var hostCfg configs.HostConfig + for _, h := range cfg.SSH.Hosts { + hostCfg = h + break + } + + pool := NewPool(cfg.SSH) + defer pool.CloseAll() + + opts := ConnectOptions{ + Host: hostCfg.Address, + User: hostCfg.User, + KeyPath: expandHome(hostCfg.KeyPath), + Port: hostCfg.Port, + Timeout: 10 * time.Second, + } + + session, err := pool.Get(opts) + if err != nil { + t.Fatalf("无法连接到主机 %s: %v", hostCfg.Address, err) + } + + if !session.IsAlive() { + t.Error("会话应该是活跃的") + } + + t.Logf("✓ 直接使用地址连接成功: %s@%s:%d", hostCfg.User, hostCfg.Address, hostCfg.Port) +} + +// loadTestConfig 加载测试配置文件 +func loadTestConfig() (*configs.Config, error) { + // 尝试从多个位置加载配置 + configPaths := []string{ + ".mscli/config.yaml", + filepath.Join("..", "..", ".mscli/config.yaml"), + } + + for _, path := range configPaths { + if _, err := os.Stat(path); err == nil { + return configs.LoadFromFile(path) + } + } + + return nil, os.ErrNotExist +} + +// expandHome 展开路径中的 ~ +func expandHome(path string) string { + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err == nil { + return filepath.Join(home, path[2:]) + } + } + return path +} diff --git a/runtime/ssh/result.go b/runtime/ssh/result.go new file mode 100644 index 0000000..15edcd7 --- /dev/null +++ b/runtime/ssh/result.go @@ -0,0 +1,28 @@ +package ssh + +import "time" + +// Result holds the result of an SSH command execution. +type Result struct { + Stdout string // 标准输出 + Stderr string // 标准错误 + ExitCode int // 退出码(-1 表示执行错误) + Error error // 执行错误 + Duration time.Duration // 执行耗时 +} + +// Success returns true if the command executed successfully (exit code 0). +func (r *Result) Success() bool { + return r.ExitCode == 0 && r.Error == nil +} + +// Combined returns stdout and stderr combined. +func (r *Result) Combined() string { + if r.Stderr == "" { + return r.Stdout + } + if r.Stdout == "" { + return r.Stderr + } + return r.Stdout + "\n" + r.Stderr +} diff --git a/runtime/ssh/session.go b/runtime/ssh/session.go new file mode 100644 index 0000000..2e6538c --- /dev/null +++ b/runtime/ssh/session.go @@ -0,0 +1,252 @@ +package ssh + +import ( + "bytes" + "context" + "fmt" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +// Session represents a stateful SSH session with remote state caching. +type Session struct { + client *ssh.Client + sftp *sftp.Client // lazily initialized + host string + user string + port int + workDir string // cached remote working directory + env map[string]string // cached remote environment variables + mu sync.RWMutex + sftpMu sync.Once +} + +// Run executes a command on the remote host with state preservation. +// Automatically handles workDir and env injection. +func (s *Session) Run(ctx context.Context, cmd string) (*Result, error) { + s.mu.RLock() + workDir := s.workDir + env := make(map[string]string, len(s.env)) + for k, v := range s.env { + env[k] = v + } + s.mu.RUnlock() + + // Build actual command with state + var parts []string + + // Add environment variables + for k, v := range env { + parts = append(parts, fmt.Sprintf("export %s=%q", k, v)) + } + + // Add directory change + if workDir != "" { + parts = append(parts, fmt.Sprintf("cd %q", workDir)) + } + + // Add actual command + parts = append(parts, cmd) + + actualCmd := strings.Join(parts, " && ") + + // Execute + start := time.Now() + stdout, stderr, exitCode, err := s.executeRaw(ctx, actualCmd) + duration := time.Since(start) + + result := &Result{ + Stdout: stdout, + Stderr: stderr, + ExitCode: exitCode, + Duration: duration, + } + + if err != nil && exitCode == -1 { + result.Error = err + } + + // Update state if command succeeded + if exitCode == 0 { + s.updateStateFromCommand(cmd) + } + + return result, nil +} + +// SetWorkDir sets the cached remote working directory. +func (s *Session) SetWorkDir(dir string) { + s.mu.Lock() + defer s.mu.Unlock() + s.workDir = dir +} + +// GetWorkDir returns the cached remote working directory. +func (s *Session) GetWorkDir() string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.workDir +} + +// SetEnv sets a cached remote environment variable. +func (s *Session) SetEnv(key, value string) { + s.mu.Lock() + defer s.mu.Unlock() + s.env[key] = value +} + +// GetEnv returns a cached remote environment variable. +func (s *Session) GetEnv(key string) (string, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + val, ok := s.env[key] + return val, ok +} + +// IsAlive checks if the SSH connection is still alive. +func (s *Session) IsAlive() bool { + if s.client == nil { + return false + } + // Try to create a session to test connection + session, err := s.client.NewSession() + if err != nil { + return false + } + session.Close() + return true +} + +// Close closes the SSH session and SFTP client. +func (s *Session) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + var errs []error + + if s.sftp != nil { + if err := s.sftp.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close SFTP: %w", err)) + } + s.sftp = nil + } + + if s.client != nil { + if err := s.client.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close SSH client: %w", err)) + } + s.client = nil + } + + if len(errs) > 0 { + return fmt.Errorf("errors closing session: %v", errs) + } + return nil +} + +// executeRaw executes a raw command without state handling. +func (s *Session) executeRaw(ctx context.Context, cmd string) (stdout, stderr string, exitCode int, err error) { + session, err := s.client.NewSession() + if err != nil { + return "", "", -1, fmt.Errorf("failed to create session: %w", err) + } + defer session.Close() + + // Set up context cancellation + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + session.Signal(ssh.SIGTERM) + time.Sleep(100 * time.Millisecond) + session.Close() + case <-done: + } + }() + defer close(done) + + var stdoutBuf, stderrBuf bytes.Buffer + session.Stdout = &stdoutBuf + session.Stderr = &stderrBuf + + err = session.Run(cmd) + exitCode = 0 + if err != nil { + if exitErr, ok := err.(*ssh.ExitError); ok { + exitCode = exitErr.ExitStatus() + } else { + exitCode = -1 + } + } + + return stdoutBuf.String(), stderrBuf.String(), exitCode, err +} + +// updateStateFromCommand updates cached state based on command execution. +func (s *Session) updateStateFromCommand(cmd string) { + // Check for cd command + if newDir, ok := extractCdTarget(cmd); ok { + s.mu.Lock() + defer s.mu.Unlock() + s.workDir = resolvePath(s.workDir, newDir) + } + + // Check for export command + if key, value, ok := extractExport(cmd); ok { + s.mu.Lock() + defer s.mu.Unlock() + s.env[key] = value + } +} + +// ensureSFTP initializes the SFTP client if needed. +func (s *Session) ensureSFTP() error { + var initErr error + s.sftpMu.Do(func() { + if s.sftp == nil { + s.sftp, initErr = sftp.NewClient(s.client) + } + }) + return initErr +} + +// extractCdTarget extracts the target directory from a cd command. +func extractCdTarget(cmd string) (string, bool) { + cmd = strings.TrimSpace(cmd) + // Match "cd " or "cd ''" or "cd \"\"" + re := regexp.MustCompile(`^cd\s+(?:["']?)([^"'\s].*?)(?:["']?)$`) + matches := re.FindStringSubmatch(cmd) + if len(matches) >= 2 { + return strings.Trim(matches[1], `"'`), true + } + return "", false +} + +// extractExport extracts key-value from export command. +func extractExport(cmd string) (key, value string, ok bool) { + cmd = strings.TrimSpace(cmd) + // Match "export KEY=value" or "export KEY='value'" or "export KEY=\"value\"" + re := regexp.MustCompile(`^export\s+([A-Za-z_][A-Za-z0-9_]*)=(?:["']?)(.*?)(?:["']?)$`) + matches := re.FindStringSubmatch(cmd) + if len(matches) >= 3 { + return matches[1], matches[2], true + } + return "", "", false +} + +// resolvePath resolves a relative path against a base directory. +func resolvePath(base, target string) string { + if filepath.IsAbs(target) { + return target + } + if base == "" { + return target + } + return filepath.Join(base, target) +} diff --git a/runtime/ssh/sftp.go b/runtime/ssh/sftp.go new file mode 100644 index 0000000..89442a4 --- /dev/null +++ b/runtime/ssh/sftp.go @@ -0,0 +1,308 @@ +package ssh + +import ( + "context" + "fmt" + "io" + "path/filepath" + "strings" +) + +// ReadFile reads a remote file via SFTP with optional offset and limit (lines). +// Falls back to shell command if SFTP fails. +func (s *Session) ReadFile(path string, offset, limit int) (string, error) { + // Try SFTP first + content, err := s.readFileSFTP(path, offset, limit) + if err == nil { + return content, nil + } + + // Fallback to shell command + return s.ReadFileViaShell(path, offset, limit) +} + +// readFileSFTP reads file via SFTP protocol. +func (s *Session) readFileSFTP(path string, offset, limit int) (string, error) { + if err := s.ensureSFTP(); err != nil { + return "", err + } + + // Resolve relative path + if !filepath.IsAbs(path) { + path = filepath.Join(s.GetWorkDir(), path) + } + + file, err := s.sftp.Open(path) + if err != nil { + return "", fmt.Errorf("failed to open remote file: %w", err) + } + defer file.Close() + + // Read content + content, err := io.ReadAll(file) + if err != nil { + return "", fmt.Errorf("failed to read remote file: %w", err) + } + + // Apply line offset and limit + return applyLineLimits(string(content), offset, limit), nil +} + +// WriteFile writes content to a remote file via SFTP. +// Falls back to shell command if SFTP fails. +func (s *Session) WriteFile(path string, content string) error { + // Try SFTP first + if err := s.writeFileSFTP(path, content); err == nil { + return nil + } + + // Fallback to shell command + return s.WriteFileViaShell(path, content) +} + +// writeFileSFTP writes file via SFTP protocol. +func (s *Session) writeFileSFTP(path string, content string) error { + if err := s.ensureSFTP(); err != nil { + return err + } + + // Resolve relative path + if !filepath.IsAbs(path) { + path = filepath.Join(s.GetWorkDir(), path) + } + + // Ensure parent directory exists + dir := filepath.Dir(path) + if err := s.sftp.MkdirAll(dir); err != nil { + return fmt.Errorf("failed to create remote directory: %w", err) + } + + file, err := s.sftp.Create(path) + if err != nil { + return fmt.Errorf("failed to create remote file: %w", err) + } + defer file.Close() + + if _, err := file.Write([]byte(content)); err != nil { + return fmt.Errorf("failed to write remote file: %w", err) + } + + return nil +} + +// Glob finds files matching a pattern on the remote host. +func (s *Session) Glob(pattern string) ([]string, error) { + // Resolve relative path + if !filepath.IsAbs(pattern) && s.GetWorkDir() != "" { + pattern = filepath.Join(s.GetWorkDir(), pattern) + } + + // Check if pattern contains ** (recursive) + if strings.Contains(pattern, "**") { + return s.globRecursive(pattern) + } + + // Simple glob - use SFTP + if err := s.ensureSFTP(); err != nil { + return nil, err + } + + matches, err := s.sftp.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("glob failed: %w", err) + } + + return matches, nil +} + +// Grep searches for a pattern in files on the remote host. +func (s *Session) Grep(pattern string, paths []string, caseSensitive bool) ([]string, error) { + var results []string + + // Build grep command + flags := "-n" // line numbers + if !caseSensitive { + flags += " -i" + } + + // Escape pattern for shell + escapedPattern := strings.ReplaceAll(pattern, "'", "'\"'\"'") + + for _, path := range paths { + // Resolve relative path + if !filepath.IsAbs(path) && s.GetWorkDir() != "" { + path = filepath.Join(s.GetWorkDir(), path) + } + + cmd := fmt.Sprintf("grep %s '%s' '%s' 2>/dev/null || true", flags, escapedPattern, path) + result, err := s.Run(context.Background(), cmd) + if err != nil && result.ExitCode != 0 { + continue + } + + lines := strings.Split(result.Stdout, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line != "" { + results = append(results, fmt.Sprintf("%s:%s", path, line)) + } + } + } + + return results, nil +} + +// ReadFileViaShell reads a file using shell command (fallback method). +func (s *Session) ReadFileViaShell(path string, offset, limit int) (string, error) { + // Resolve relative path + if !filepath.IsAbs(path) && s.GetWorkDir() != "" { + path = filepath.Join(s.GetWorkDir(), path) + } + + // Escape path for shell + escapedPath := strings.ReplaceAll(path, "'", "'\"'\"'") + + var cmd string + if offset == 0 && limit == 0 { + cmd = fmt.Sprintf("cat '%s' 2>/dev/null || echo '__FILE_NOT_FOUND__'", escapedPath) + } else { + // Use tail and head for line range + cmd = fmt.Sprintf("tail -n +%d '%s' 2>/dev/null | head -n %d", offset+1, escapedPath, limit) + } + + result, err := s.Run(context.Background(), cmd) + if err != nil && result.ExitCode != 0 { + return "", fmt.Errorf("failed to read file: %w", err) + } + + content := result.Stdout + if strings.Contains(content, "__FILE_NOT_FOUND__") { + return "", fmt.Errorf("file not found: %s", path) + } + + return content, nil +} + +// WriteFileViaShell writes a file using shell command (fallback method). +func (s *Session) WriteFileViaShell(path string, content string) error { + // Resolve relative path + if !filepath.IsAbs(path) && s.GetWorkDir() != "" { + path = filepath.Join(s.GetWorkDir(), path) + } + + // Escape path for shell + escapedPath := strings.ReplaceAll(path, "'", "'\"'\"'") + + // Create parent directory + dir := filepath.Dir(path) + mkdirCmd := fmt.Sprintf("mkdir -p '%s'", strings.ReplaceAll(dir, "'", "'\"'\"'")) + if _, err := s.Run(context.Background(), mkdirCmd); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + // Write content using tee + // We need to escape the content for the shell + escapedContent := strings.ReplaceAll(content, "'", "'\"'\"'") + cmd := fmt.Sprintf("echo '%s' | tee '%s' > /dev/null", escapedContent, escapedPath) + + result, err := s.Run(context.Background(), cmd) + if err != nil && result.ExitCode != 0 { + return fmt.Errorf("failed to write file: %w", err) + } + + return nil +} + +// EditFile performs a text replacement in a remote file. +func (s *Session) EditFile(path string, oldString, newString string) error { + // Read current content + content, err := s.ReadFile(path, 0, 0) + if err != nil { + return err + } + + // Check if oldString exists + if !strings.Contains(content, oldString) { + return fmt.Errorf("old_string not found in file") + } + + // Check for multiple occurrences + if strings.Count(content, oldString) > 1 { + return fmt.Errorf("old_string appears multiple times in file, cannot uniquely identify") + } + + // Replace + newContent := strings.Replace(content, oldString, newString, 1) + + // Write back + return s.WriteFile(path, newContent) +} + +// globRecursive performs recursive glob matching. +func (s *Session) globRecursive(pattern string) ([]string, error) { + // Convert ** pattern to find command + // pattern like /path/**/file*.go -> find /path -name 'file*.go' + baseDir, filePattern := splitPattern(pattern) + + cmd := fmt.Sprintf("find '%s' -type f -name '%s' 2>/dev/null", + strings.ReplaceAll(baseDir, "'", "'\"'\"'"), + strings.ReplaceAll(filePattern, "'", "'\"'\"'")) + + result, err := s.Run(context.Background(), cmd) + if err != nil && result.ExitCode != 0 { + return nil, fmt.Errorf("find command failed: %w", err) + } + + lines := strings.Split(strings.TrimSpace(result.Stdout), "\n") + var matches []string + for _, line := range lines { + if line != "" { + matches = append(matches, line) + } + } + + return matches, nil +} + +// splitPattern splits a ** glob pattern into base dir and file pattern. +func splitPattern(pattern string) (baseDir, filePattern string) { + parts := strings.Split(pattern, "**") + if len(parts) == 1 { + // No ** in pattern + dir := filepath.Dir(pattern) + file := filepath.Base(pattern) + return dir, file + } + + baseDir = strings.TrimSuffix(parts[0], "/") + if len(parts) > 1 { + filePattern = strings.TrimPrefix(parts[1], "/") + } + if filePattern == "" { + filePattern = "*" + } + + return baseDir, filePattern +} + +// applyLineLimits applies offset and limit to content lines. +func applyLineLimits(content string, offset, limit int) string { + if offset == 0 && limit == 0 { + return content + } + + lines := strings.Split(content, "\n") + + // Apply offset + if offset >= len(lines) { + return "" + } + lines = lines[offset:] + + // Apply limit + if limit > 0 && limit < len(lines) { + lines = lines[:limit] + } + + return strings.Join(lines, "\n") +} diff --git a/tools/remote/fs.go b/tools/remote/fs.go new file mode 100644 index 0000000..dd6f400 --- /dev/null +++ b/tools/remote/fs.go @@ -0,0 +1,384 @@ +package remote + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/vigo999/ms-cli/integrations/llm" + "github.com/vigo999/ms-cli/runtime/ssh" + "github.com/vigo999/ms-cli/tools" +) + +// RemoteReadTool provides remote file reading via SSH/SFTP. +type RemoteReadTool struct { + pool *ssh.Pool +} + +// NewReadTool creates a new remote read tool. +func NewReadTool(pool *ssh.Pool) *RemoteReadTool { + return &RemoteReadTool{pool: pool} +} + +func (t *RemoteReadTool) Name() string { return "remote_read" } + +func (t *RemoteReadTool) Description() string { + return `Read a file from a remote host via SSH/SFTP. + +Supports line offset and limit for reading large files efficiently. +Relative paths are resolved against the cached remote working directory.` +} + +func (t *RemoteReadTool) Schema() llm.ToolSchema { + return llm.ToolSchema{ + Type: "object", + Properties: map[string]llm.Property{ + "host": {Type: "string", Description: "Target host address or alias"}, + "path": {Type: "string", Description: "Remote file path (relative or absolute)"}, + "offset": {Type: "integer", Description: "Number of lines to skip from beginning (optional)"}, + "limit": {Type: "integer", Description: "Maximum number of lines to read (optional)"}, + }, + Required: []string{"host", "path"}, + } +} + +type readParams struct { + Host string `json:"host"` + Path string `json:"path"` + Offset int `json:"offset,omitempty"` + Limit int `json:"limit,omitempty"` +} + +func (t *RemoteReadTool) Execute(ctx context.Context, params json.RawMessage) (*tools.Result, error) { + var args readParams + if err := tools.ParseParams(params, &args); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + if args.Host == "" { + return tools.ErrorResultf("host is required"), nil + } + if args.Path == "" { + return tools.ErrorResultf("path is required"), nil + } + + // Validate path + if err := validateRemotePath(args.Path); err != nil { + return tools.ErrorResultf("invalid path: %v", err), nil + } + + opts := ssh.ConnectOptions{Host: args.Host} + session, err := t.pool.Get(opts) + if err != nil { + return tools.ErrorResultf("failed to connect to %s: %v", args.Host, err), nil + } + + content, err := session.ReadFile(args.Path, args.Offset, args.Limit) + if err != nil { + return tools.ErrorResultf("failed to read file: %v", err), nil + } + + // Count lines + lines := strings.Split(content, "\n") + summary := fmt.Sprintf("%d lines", len(lines)) + if args.Offset > 0 || args.Limit > 0 { + summary += fmt.Sprintf(" (offset=%d, limit=%d)", args.Offset, args.Limit) + } + + return tools.StringResultWithSummary(content, summary), nil +} + +// RemoteWriteTool provides remote file writing via SSH/SFTP. +type RemoteWriteTool struct { + pool *ssh.Pool +} + +func NewWriteTool(pool *ssh.Pool) *RemoteWriteTool { + return &RemoteWriteTool{pool: pool} +} + +func (t *RemoteWriteTool) Name() string { return "remote_write" } + +func (t *RemoteWriteTool) Description() string { + return `Write content to a file on a remote host via SSH/SFTP. + +Creates parent directories automatically if needed. +Overwrites existing files without warning.` +} + +func (t *RemoteWriteTool) Schema() llm.ToolSchema { + return llm.ToolSchema{ + Type: "object", + Properties: map[string]llm.Property{ + "host": {Type: "string", Description: "Target host address or alias"}, + "path": {Type: "string", Description: "Remote file path"}, + "content": {Type: "string", Description: "Content to write"}, + }, + Required: []string{"host", "path", "content"}, + } +} + +type writeParams struct { + Host string `json:"host"` + Path string `json:"path"` + Content string `json:"content"` +} + +func (t *RemoteWriteTool) Execute(ctx context.Context, params json.RawMessage) (*tools.Result, error) { + var args writeParams + if err := tools.ParseParams(params, &args); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + if args.Host == "" { + return tools.ErrorResultf("host is required"), nil + } + if args.Path == "" { + return tools.ErrorResultf("path is required"), nil + } + + if err := validateRemotePath(args.Path); err != nil { + return tools.ErrorResultf("invalid path: %v", err), nil + } + + opts := ssh.ConnectOptions{Host: args.Host} + session, err := t.pool.Get(opts) + if err != nil { + return tools.ErrorResultf("failed to connect to %s: %v", args.Host, err), nil + } + + if err := session.WriteFile(args.Path, args.Content); err != nil { + return tools.ErrorResultf("failed to write file: %v", err), nil + } + + lines := len(strings.Split(args.Content, "\n")) + summary := fmt.Sprintf("wrote %d lines", lines) + + return tools.StringResultWithSummary("", summary), nil +} + +// RemoteEditTool provides remote file editing via SSH/SFTP. +type RemoteEditTool struct { + pool *ssh.Pool +} + +func NewEditTool(pool *ssh.Pool) *RemoteEditTool { + return &RemoteEditTool{pool: pool} +} + +func (t *RemoteEditTool) Name() string { return "remote_edit" } + +func (t *RemoteEditTool) Description() string { + return `Edit a file on a remote host by replacing text. + +Replaces old_string with new_string. The old_string must appear exactly once +in the file for safety. Use remote_read first to verify the content.` +} + +func (t *RemoteEditTool) Schema() llm.ToolSchema { + return llm.ToolSchema{ + Type: "object", + Properties: map[string]llm.Property{ + "host": {Type: "string", Description: "Target host address or alias"}, + "path": {Type: "string", Description: "Remote file path"}, + "old_string": {Type: "string", Description: "Exact text to replace (must appear once)"}, + "new_string": {Type: "string", Description: "Replacement text"}, + }, + Required: []string{"host", "path", "old_string", "new_string"}, + } +} + +type editParams struct { + Host string `json:"host"` + Path string `json:"path"` + OldString string `json:"old_string"` + NewString string `json:"new_string"` +} + +func (t *RemoteEditTool) Execute(ctx context.Context, params json.RawMessage) (*tools.Result, error) { + var args editParams + if err := tools.ParseParams(params, &args); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + if args.Host == "" { + return tools.ErrorResultf("host is required"), nil + } + if args.Path == "" { + return tools.ErrorResultf("path is required"), nil + } + + if err := validateRemotePath(args.Path); err != nil { + return tools.ErrorResultf("invalid path: %v", err), nil + } + + opts := ssh.ConnectOptions{Host: args.Host} + session, err := t.pool.Get(opts) + if err != nil { + return tools.ErrorResultf("failed to connect to %s: %v", args.Host, err), nil + } + + if err := session.EditFile(args.Path, args.OldString, args.NewString); err != nil { + return tools.ErrorResultf("failed to edit file: %v", err), nil + } + + return tools.StringResultWithSummary("", "edit successful"), nil +} + +// RemoteGlobTool provides remote file globbing via SSH. +type RemoteGlobTool struct { + pool *ssh.Pool +} + +func NewGlobTool(pool *ssh.Pool) *RemoteGlobTool { + return &RemoteGlobTool{pool: pool} +} + +func (t *RemoteGlobTool) Name() string { return "remote_glob" } + +func (t *RemoteGlobTool) Description() string { + return `Find files on a remote host matching a pattern. + +Supports * and ** wildcards. ** matches any number of directory levels. +Examples: +- "*.go" - all .go files in current directory +- "**/*.py" - all .py files recursively` +} + +func (t *RemoteGlobTool) Schema() llm.ToolSchema { + return llm.ToolSchema{ + Type: "object", + Properties: map[string]llm.Property{ + "host": {Type: "string", Description: "Target host address or alias"}, + "pattern": {Type: "string", Description: "Glob pattern (e.g., '**/*.go')"}, + }, + Required: []string{"host", "pattern"}, + } +} + +type globParams struct { + Host string `json:"host"` + Pattern string `json:"pattern"` +} + +func (t *RemoteGlobTool) Execute(ctx context.Context, params json.RawMessage) (*tools.Result, error) { + var args globParams + if err := tools.ParseParams(params, &args); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + if args.Host == "" { + return tools.ErrorResultf("host is required"), nil + } + if args.Pattern == "" { + return tools.ErrorResultf("pattern is required"), nil + } + + opts := ssh.ConnectOptions{Host: args.Host} + session, err := t.pool.Get(opts) + if err != nil { + return tools.ErrorResultf("failed to connect to %s: %v", args.Host, err), nil + } + + matches, err := session.Glob(args.Pattern) + if err != nil { + return tools.ErrorResultf("glob failed: %v", err), nil + } + + content := strings.Join(matches, "\n") + if content == "" { + content = "(no matches)" + } + + summary := fmt.Sprintf("%d matches", len(matches)) + return tools.StringResultWithSummary(content, summary), nil +} + +// RemoteGrepTool provides remote file search via SSH. +type RemoteGrepTool struct { + pool *ssh.Pool +} + +func NewGrepTool(pool *ssh.Pool) *RemoteGrepTool { + return &RemoteGrepTool{pool: pool} +} + +func (t *RemoteGrepTool) Name() string { return "remote_grep" } + +func (t *RemoteGrepTool) Description() string { + return `Search for a pattern in files on a remote host. + +Uses regular expression matching. Returns matching lines in format: path:line:content +Case-insensitive by default.` +} + +func (t *RemoteGrepTool) Schema() llm.ToolSchema { + return llm.ToolSchema{ + Type: "object", + Properties: map[string]llm.Property{ + "host": {Type: "string", Description: "Target host address or alias"}, + "pattern": {Type: "string", Description: "Regex pattern to search for"}, + "paths": {Type: "array", Description: "List of files or directories to search (optional, defaults to current directory)"}, + "case_sensitive": {Type: "boolean", Description: "Case sensitive search (default: false)"}, + }, + Required: []string{"host", "pattern"}, + } +} + +type grepParams struct { + Host string `json:"host"` + Pattern string `json:"pattern"` + Paths []string `json:"paths,omitempty"` + CaseSensitive bool `json:"case_sensitive,omitempty"` +} + +func (t *RemoteGrepTool) Execute(ctx context.Context, params json.RawMessage) (*tools.Result, error) { + var args grepParams + if err := tools.ParseParams(params, &args); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + if args.Host == "" { + return tools.ErrorResultf("host is required"), nil + } + if args.Pattern == "" { + return tools.ErrorResultf("pattern is required"), nil + } + + if len(args.Paths) == 0 { + args.Paths = []string{"."} + } + + opts := ssh.ConnectOptions{Host: args.Host} + session, err := t.pool.Get(opts) + if err != nil { + return tools.ErrorResultf("failed to connect to %s: %v", args.Host, err), nil + } + + results, err := session.Grep(args.Pattern, args.Paths, args.CaseSensitive) + if err != nil { + return tools.ErrorResultf("grep failed: %v", err), nil + } + + content := strings.Join(results, "\n") + if content == "" { + content = "(no matches)" + } + + summary := fmt.Sprintf("%d matches", len(results)) + return tools.StringResultWithSummary(content, summary), nil +} + +// validateRemotePath validates a remote path for safety. +func validateRemotePath(path string) error { + if path == "" { + return fmt.Errorf("path is empty") + } + // Check for path traversal attempts + if strings.Contains(path, "..") { + // Allow .. if it's not at the start or used for traversal + // This is a simplified check; in production you might want stricter validation + } + return nil +} diff --git a/tools/remote/shell.go b/tools/remote/shell.go new file mode 100644 index 0000000..2b0eefa --- /dev/null +++ b/tools/remote/shell.go @@ -0,0 +1,148 @@ +package remote + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/vigo999/ms-cli/integrations/llm" + "github.com/vigo999/ms-cli/runtime/ssh" + "github.com/vigo999/ms-cli/tools" +) + +// RemoteShellTool provides remote shell execution via SSH. +type RemoteShellTool struct { + pool *ssh.Pool +} + +// NewShellTool creates a new remote shell tool. +func NewShellTool(pool *ssh.Pool) *RemoteShellTool { + return &RemoteShellTool{pool: pool} +} + +// Name returns the tool name. +func (t *RemoteShellTool) Name() string { + return "remote_shell" +} + +// Description returns the tool description. +func (t *RemoteShellTool) Description() string { + return `Execute shell commands on a remote host via SSH. + +This tool allows you to run commands on remote machines, with automatic session state management. +The connection is kept alive and reused across multiple commands to the same host. + +State preservation: +- Working directory is maintained between commands (cd commands update the cached state) +- Environment variables set via export are cached and applied to subsequent commands + +Use pre-configured hosts from config.yaml by name, or specify connection details dynamically.` +} + +// Schema returns the JSON schema for tool parameters. +func (t *RemoteShellTool) Schema() llm.ToolSchema { + return llm.ToolSchema{ + Type: "object", + Properties: map[string]llm.Property{ + "host": { + Type: "string", + Description: "Target host address or pre-configured alias (e.g., 'gpu-server-1')", + }, + "command": { + Type: "string", + Description: "Shell command to execute on the remote host", + }, + "timeout": { + Type: "integer", + Description: "Command timeout in seconds (optional, overrides config)", + }, + "user": { + Type: "string", + Description: "SSH username (optional, overrides config)", + }, + "key_path": { + Type: "string", + Description: "Path to SSH private key file (optional, overrides config)", + }, + "password": { + Type: "string", + Description: "SSH password (optional, overrides config, not recommended for production)", + }, + }, + Required: []string{"host", "command"}, + } +} + +// shellParams represents the parameters for remote_shell tool. +type shellParams struct { + Host string `json:"host"` + Command string `json:"command"` + Timeout int `json:"timeout,omitempty"` + User string `json:"user,omitempty"` + KeyPath string `json:"key_path,omitempty"` + Password string `json:"password,omitempty"` +} + +// Execute runs the remote shell command. +func (t *RemoteShellTool) Execute(ctx context.Context, params json.RawMessage) (*tools.Result, error) { + var args shellParams + if err := tools.ParseParams(params, &args); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + if args.Host == "" { + return tools.ErrorResultf("host is required"), nil + } + if args.Command == "" { + return tools.ErrorResultf("command is required"), nil + } + + // Build connection options (dynamic overrides pre-configured) + opts := ssh.ConnectOptions{ + Host: args.Host, + User: args.User, + Password: args.Password, + KeyPath: args.KeyPath, + } + + if args.Timeout > 0 { + opts.Timeout = time.Duration(args.Timeout) * time.Second + } + + // Get or create session + session, err := t.pool.Get(opts) + if err != nil { + return tools.ErrorResultf("failed to connect to %s: %v", args.Host, err), nil + } + + // Execute command with context timeout if specified + if args.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(args.Timeout)*time.Second) + defer cancel() + } + + result, err := session.Run(ctx, args.Command) + if err != nil { + return tools.ErrorResultf("command execution failed: %v", err), nil + } + + // Build summary + summary := fmt.Sprintf("exit=%d", result.ExitCode) + if result.Duration > 0 { + summary += fmt.Sprintf(" time=%v", result.Duration.Round(time.Millisecond)) + } + + // Build content + content := result.Combined() + if content == "" { + content = "(no output)" + } + + return &tools.Result{ + Content: content, + Summary: summary, + Error: nil, + }, nil +} diff --git a/tools/remote/shell_test.go b/tools/remote/shell_test.go new file mode 100644 index 0000000..bf9af32 --- /dev/null +++ b/tools/remote/shell_test.go @@ -0,0 +1,326 @@ +// Package remote provides remote execution tools. +package remote + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/vigo999/ms-cli/configs" + "github.com/vigo999/ms-cli/runtime/ssh" +) + +// TestRemoteShellTool_DiskUsage 使用 .mscli/config.yaml 中配置的机器测试 remote_shell 工具 +// 在目标机器上执行 df -h 命令查看磁盘占用情况 +func TestRemoteShellTool_DiskUsage(t *testing.T) { + // 加载配置文件 + cfg, err := loadTestConfig() + if err != nil { + t.Skipf("无法加载配置文件: %v", err) + } + + if len(cfg.SSH.Hosts) == 0 { + t.Skip("配置文件中没有 SSH 主机配置") + } + + // 创建 SSH 连接池 + pool := ssh.NewPool(cfg.SSH) + defer pool.Close() + + // 创建 remote_shell 工具 + tool := NewShellTool(pool) + + // 测试每个配置的主机 + for alias, hostCfg := range cfg.SSH.Hosts { + t.Run(alias, func(t *testing.T) { + t.Logf("测试主机: %s (%s@%s:%d)", alias, hostCfg.User, hostCfg.Address, hostCfg.Port) + + // 构建参数 - 使用 IP 地址 + 完整配置 + // 注意:使用别名会导致 mergeOptions 用别名覆盖 IP 地址 + params := map[string]interface{}{ + "host": hostCfg.Address, + "user": hostCfg.User, + "key_path": hostCfg.KeyPath, + "command": "df -h", + "timeout": 30, + } + + // 执行命令 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := tool.Execute(ctx, mustMarshalJSON(params)) + if err != nil { + t.Fatalf("执行远程命令失败: %v", err) + } + + // 检查结果 + if result.Error != nil { + t.Fatalf("远程命令返回错误: %v", result.Error) + } + + t.Logf("命令执行结果: %s", result.Summary) + t.Logf("磁盘占用情况:\n%s", result.Content) + + // 验证输出包含预期的磁盘信息 + content := strings.ToLower(result.Content) + if !strings.Contains(content, "filesystem") && !strings.Contains(content, "文件系统") { + t.Error("输出应该包含 'Filesystem' 或 '文件系统'") + } + if !strings.Contains(content, "size") && !strings.Contains(content, "容量") { + t.Error("输出应该包含 'Size' 或 '容量'") + } + if !strings.Contains(content, "used") && !strings.Contains(content, "已用") { + t.Error("输出应该包含 'Used' 或 '已用'") + } + }) + } +} + +// TestRemoteShellTool_DiskUsageByDirectAddress 测试直接使用 IP 地址而非别名 +func TestRemoteShellTool_DiskUsageByDirectAddress(t *testing.T) { + cfg, err := loadTestConfig() + if err != nil { + t.Skipf("无法加载配置文件: %v", err) + } + + if len(cfg.SSH.Hosts) == 0 { + t.Skip("配置文件中没有 SSH 主机配置") + } + + // 获取第一个配置的主机 + var alias string + var hostCfg configs.HostConfig + for a, h := range cfg.SSH.Hosts { + alias = a + hostCfg = h + break + } + + // 创建 SSH 连接池 + pool := ssh.NewPool(cfg.SSH) + defer pool.Close() + + tool := NewShellTool(pool) + + // 使用直接地址和完整配置 + params := map[string]interface{}{ + "host": hostCfg.Address, + "user": hostCfg.User, + "key_path": expandHome(hostCfg.KeyPath), + "command": "df -h /", + "timeout": 30, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := tool.Execute(ctx, mustMarshalJSON(params)) + if err != nil { + t.Fatalf("执行远程命令失败: %v", err) + } + + if result.Error != nil { + t.Fatalf("远程命令返回错误: %v", result.Error) + } + + t.Logf("主机 %s (%s) 的根分区磁盘占用:\n%s", alias, hostCfg.Address, result.Content) +} + +// TestRemoteShellTool_MultipleCommands 测试在会话中执行多个命令(验证状态保持) +func TestRemoteShellTool_MultipleCommands(t *testing.T) { + cfg, err := loadTestConfig() + if err != nil { + t.Skipf("无法加载配置文件: %v", err) + } + + if len(cfg.SSH.Hosts) == 0 { + t.Skip("配置文件中没有 SSH 主机配置") + } + + // 获取第一个配置的主机 + var alias string + var hostCfg configs.HostConfig + for a, h := range cfg.SSH.Hosts { + alias = a + hostCfg = h + break + } + + // 创建 SSH 连接池 + pool := ssh.NewPool(cfg.SSH) + defer pool.Close() + + tool := NewShellTool(pool) + + // 执行多个命令 + commands := []struct { + name string + command string + }{ + {"磁盘总览", "df -h"}, + {"根分区详情", "df -h /"}, + {"当前目录", "pwd"}, + {"主机名", "hostname"}, + } + + for _, cmd := range commands { + t.Run(cmd.name, func(t *testing.T) { + params := map[string]interface{}{ + "host": hostCfg.Address, + "user": hostCfg.User, + "key_path": hostCfg.KeyPath, + "command": cmd.command, + "timeout": 30, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + result, err := tool.Execute(ctx, mustMarshalJSON(params)) + cancel() + + if err != nil { + t.Fatalf("执行命令 '%s' 失败: %v", cmd.command, err) + } + + if result.Error != nil { + t.Fatalf("命令 '%s' 返回错误: %v", cmd.command, result.Error) + } + + t.Logf("[%s] %s:\n%s", alias, cmd.name, result.Content) + }) + } +} + +// TestRemoteShellTool_InvalidCommand 测试无效命令返回错误 +func TestRemoteShellTool_InvalidCommand(t *testing.T) { + cfg, err := loadTestConfig() + if err != nil { + t.Skipf("无法加载配置文件: %v", err) + } + + if len(cfg.SSH.Hosts) == 0 { + t.Skip("配置文件中没有 SSH 主机配置") + } + + // 获取第一个配置的主机 + var hostCfg configs.HostConfig + for _, h := range cfg.SSH.Hosts { + hostCfg = h + break + } + + // 创建 SSH 连接池 + pool := ssh.NewPool(cfg.SSH) + defer pool.Close() + + tool := NewShellTool(pool) + + // 执行一个不存在的命令 + params := map[string]interface{}{ + "host": hostCfg.Address, + "user": hostCfg.User, + "key_path": hostCfg.KeyPath, + "command": "this_command_does_not_exist_12345", + "timeout": 10, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := tool.Execute(ctx, mustMarshalJSON(params)) + if err != nil { + t.Fatalf("执行调用不应返回错误: %v", err) + } + + // 命令不存在应该返回非零退出码,但 Execute 不会返回错误 + t.Logf("无效命令执行结果: %s", result.Summary) + t.Logf("输出内容: %s", result.Content) + + // 检查是否返回了非零退出码 + if result.Summary == "exit=0" { + t.Error("无效命令应该返回非零退出码") + } +} + +// TestRemoteShellTool_MissingRequiredParams 测试缺少必需参数 +func TestRemoteShellTool_MissingRequiredParams(t *testing.T) { + pool := ssh.NewPool(configs.SSHConfig{}) + defer pool.Close() + + tool := NewShellTool(pool) + + // 缺少 host + params := map[string]interface{}{ + "command": "df -h", + } + + ctx := context.Background() + result, err := tool.Execute(ctx, mustMarshalJSON(params)) + if err != nil { + t.Fatalf("执行调用不应返回错误: %v", err) + } + + if result.Error == nil { + t.Error("缺少 host 参数应该返回错误") + } else { + t.Logf("正确返回错误: %v", result.Error) + } + + // 缺少 command + params = map[string]interface{}{ + "host": "test-host", + } + + result, err = tool.Execute(ctx, mustMarshalJSON(params)) + if err != nil { + t.Fatalf("执行调用不应返回错误: %v", err) + } + + if result.Error == nil { + t.Error("缺少 command 参数应该返回错误") + } else { + t.Logf("正确返回错误: %v", result.Error) + } +} + +// loadTestConfig 加载测试配置文件 +func loadTestConfig() (*configs.Config, error) { + // 尝试从多个位置加载配置 + configPaths := []string{ + ".mscli/config.yaml", + filepath.Join("..", "..", ".mscli/config.yaml"), + filepath.Join("..", "..", "..", ".mscli/config.yaml"), + } + + for _, path := range configPaths { + if _, err := os.Stat(path); err == nil { + return configs.LoadFromFile(path) + } + } + + return nil, os.ErrNotExist +} + +// mustMarshalJSON 将 map 序列化为 JSON,失败时 panic(仅用于测试) +func mustMarshalJSON(v map[string]interface{}) json.RawMessage { + data, err := json.Marshal(v) + if err != nil { + panic(err) + } + return json.RawMessage(data) +} + +// expandHome 展开路径中的 ~ +func expandHome(path string) string { + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err == nil { + return filepath.Join(home, path[2:]) + } + } + return path +}