diff --git a/cmd/axis/agent.go b/cmd/axis/agent.go index 30071d1..0fa3721 100644 --- a/cmd/axis/agent.go +++ b/cmd/axis/agent.go @@ -5,10 +5,12 @@ import ( "context" "encoding/json" "fmt" + "io" "os" "strings" "time" + "github.com/chzyer/readline" "github.com/spf13/cobra" "github.com/toasterbook88/axis/internal/agent" "github.com/toasterbook88/axis/internal/api" @@ -37,6 +39,7 @@ func agentCmd() *cobra.Command { maxTurns int autoApprove bool systemMsg string + resume bool ) cmd := &cobra.Command{ @@ -93,6 +96,18 @@ func agentCmd() *cobra.Command { a := agent.New(cfg) + // Resume previous conversation if requested. + historyPath, err := chat.PersistPath("agent") + if err != nil { + fmt.Fprintf(errW, "warning: cannot determine history path: %v\n", err) + } else if resume { + if err := a.Conversation().LoadFromFile(historyPath); err != nil { + fmt.Fprintf(errW, "warning: could not resume conversation: %v\n", err) + } else if n := a.Conversation().HistoryCount(); n > 0 { + fmt.Fprintf(errW, "Resumed %d messages from previous session.\n", n) + } + } + // Single-shot mode. if len(args) > 0 { instruction := strings.Join(args, " ") @@ -105,19 +120,35 @@ func agentCmd() *cobra.Command { return ExitCodeError{Code: ExitErrCommandFail, Message: fmt.Sprintf("agent failed: %v", err)} } fmt.Fprintln(w) + if historyPath != "" { + _ = a.Conversation().SaveToFile(historyPath) + } return nil } - // Interactive REPL. + // Interactive REPL with readline. fmt.Fprintf(errW, "AXIS Agent [%s] — max %d turns per query, type exit to quit\n\n", ui.Bold(currentModel), maxTurns) - scanner := bufio.NewScanner(os.Stdin) + rlCfg := &readline.Config{ + Prompt: ui.Cyan("agent> "), + InterruptPrompt: "^C", + EOFPrompt: "exit", + } + if historyPath != "" { + rlCfg.HistoryFile = historyPath + ".line" + } + rl, err := readline.NewEx(rlCfg) + if err != nil { + return runPlainAgentREPL(a, w, errW, timeout, historyPath) + } + defer rl.Close() + for { - fmt.Fprint(errW, ui.Cyan("agent> ")) - if !scanner.Scan() { + line, err := rl.Readline() + if err != nil { break } - instruction := strings.TrimSpace(scanner.Text()) + instruction := strings.TrimSpace(line) if instruction == "" { continue } @@ -133,6 +164,14 @@ func agentCmd() *cobra.Command { cancel() fmt.Fprintln(w) } + + if historyPath != "" && a.Conversation().HistoryCount() > 0 { + if err := a.Conversation().SaveToFile(historyPath); err != nil { + fmt.Fprintf(errW, "warning: could not save conversation: %v\n", err) + } else { + fmt.Fprintf(errW, "Saved %d messages to conversation history.\n", a.Conversation().HistoryCount()) + } + } return nil }, } @@ -143,9 +182,40 @@ func agentCmd() *cobra.Command { cmd.Flags().IntVar(&maxTurns, "max-turns", 10, "Maximum agent loop iterations per query") cmd.Flags().BoolVar(&autoApprove, "auto-approve", false, "Auto-approve safe commands (safety score < 70)") cmd.Flags().StringVar(&systemMsg, "system", "", "Extra text appended to system prompt") + cmd.Flags().BoolVar(&resume, "resume", false, "Resume previous conversation from history") return cmd } +// runPlainAgentREPL is the fallback scanner-based REPL when readline is unavailable. +func runPlainAgentREPL(a *agent.Agent, w, errW io.Writer, timeout time.Duration, historyPath string) error { + fmt.Fprintln(errW, ui.Yellow("Note: using plain input mode (no arrow keys or history)")) + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Fprint(errW, ui.Cyan("agent\u003e ")) + if !scanner.Scan() { + break + } + instruction := strings.TrimSpace(scanner.Text()) + if instruction == "" { + continue + } + lower := strings.ToLower(instruction) + if lower == "exit" || lower == "quit" { + break + } + ctx, cancel := agentRequestContext(timeout) + if err := a.Run(ctx, instruction); err != nil { + fmt.Fprintf(errW, "\n%s %v\n", ui.Red("Error:"), err) + } + cancel() + fmt.Fprintln(w) + } + if historyPath != "" && a.Conversation().HistoryCount() > 0 { + _ = a.Conversation().SaveToFile(historyPath) + } + return nil +} + func agentRequestContext(timeout time.Duration) (context.Context, context.CancelFunc) { if timeout <= 0 { return context.WithCancel(context.Background()) diff --git a/cmd/axis/chat.go b/cmd/axis/chat.go index 3f16a89..ee90a0c 100644 --- a/cmd/axis/chat.go +++ b/cmd/axis/chat.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/chzyer/readline" "github.com/spf13/cobra" "github.com/toasterbook88/axis/internal/chat" "github.com/toasterbook88/axis/internal/config" @@ -32,6 +33,7 @@ func chatCmd() *cobra.Command { useContext bool systemMsg string format string + resume bool ) cmd := &cobra.Command{ @@ -68,6 +70,18 @@ func chatCmd() *cobra.Command { sysPrompt := chat.BuildSystemPrompt(cluster, systemMsg) conv.Append(chat.Message{Role: chat.RoleSystem, Content: sysPrompt}) + // Resume previous conversation if requested. + historyPath, err := chat.PersistPath("chat") + if err != nil { + fmt.Fprintf(errW, "warning: cannot determine history path: %v\n", err) + } else if resume { + if err := conv.LoadFromFile(historyPath); err != nil { + fmt.Fprintf(errW, "warning: could not resume conversation: %v\n", err) + } else if n := conv.HistoryCount(); n > 0 { + fmt.Fprintf(errW, "Resumed %d messages from previous session.\n", n) + } + } + fmt.Fprintln(errW, ui.Dim("advisory: chat output is not cluster truth — validate with axis status or axis facts")) // Single-shot mode. @@ -100,19 +114,37 @@ func chatCmd() *cobra.Command { } else { fmt.Fprintln(w) } + // Save conversation after single-shot. + if historyPath != "" { + _ = conv.SaveToFile(historyPath) + } return nil } - // Interactive REPL. + // Interactive REPL with readline. fmt.Fprintf(errW, "AXIS Chat [%s] — type /help for commands, exit to quit\n\n", ui.Bold(currentModel)) - scanner := bufio.NewScanner(os.Stdin) + + cfg := &readline.Config{ + Prompt: ui.Cyan(">>> "), + InterruptPrompt: "^C", + EOFPrompt: "exit", + } + if historyPath != "" { + cfg.HistoryFile = historyPath + ".line" + } + rl, err := readline.NewEx(cfg) + if err != nil { + // Fallback to plain scanner if readline fails (e.g., non-TTY). + return runPlainREPL(cmd.Context(), client, conv, currentModel, w, errW, timeout, historyPath) + } + defer rl.Close() for { - fmt.Fprint(errW, ui.Cyan(">>> ")) - if !scanner.Scan() { + line, err := rl.Readline() + if err != nil { // io.EOF or readline.ErrInterrupt break } - query := strings.TrimSpace(scanner.Text()) + query := strings.TrimSpace(line) if query == "" { continue } @@ -148,6 +180,15 @@ func chatCmd() *cobra.Command { conv.Append(resp) fmt.Fprintln(w) } + + // Save conversation on exit. + if historyPath != "" && conv.HistoryCount() > 0 { + if err := conv.SaveToFile(historyPath); err != nil { + fmt.Fprintf(errW, "warning: could not save conversation: %v\n", err) + } else { + fmt.Fprintf(errW, "Saved %d messages to conversation history.\n", conv.HistoryCount()) + } + } return nil }, } @@ -158,9 +199,55 @@ func chatCmd() *cobra.Command { cmd.Flags().BoolVar(&useContext, "context", false, "Inject live cluster snapshot into system prompt") cmd.Flags().StringVar(&systemMsg, "system", "", "Extra text appended to system prompt") cmd.Flags().StringVar(&format, "format", "text", "Output format for single-shot mode (text, json)") + cmd.Flags().BoolVar(&resume, "resume", false, "Resume previous conversation from history") return cmd } +// runPlainREPL is the fallback scanner-based REPL when readline is unavailable. +func runPlainREPL(ctx context.Context, client *chat.Client, conv *chat.Conversation, currentModel string, w, errW io.Writer, timeout time.Duration, historyPath string) error { + fmt.Fprintln(errW, ui.Yellow("Note: using plain input mode (no arrow keys or history)")) + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Fprint(errW, ui.Cyan(">>> ")) + if !scanner.Scan() { + break + } + query := strings.TrimSpace(scanner.Text()) + if query == "" { + continue + } + lower := strings.ToLower(query) + if lower == "exit" || lower == "quit" { + break + } + if strings.HasPrefix(query, "/") { + nextModel := handleSlashCommand(query, currentModel, conv, errW) + if nextModel != "" { + currentModel = nextModel + client = chat.NewClient(chat.DefaultEndpoint, currentModel) + } + continue + } + conv.Append(chat.Message{Role: chat.RoleUser, Content: query}) + sp := ui.NewSpinner() + sp.Start("Thinking...") + ctx2, cancel := chatRequestContext(timeout) + resp, err := client.ChatStream(ctx2, conv.Messages(), nil, w) + sp.Stop("") + cancel() + if err != nil { + fmt.Fprintf(errW, "\n%s\n", ui.Red("Error: ", err)) + continue + } + conv.Append(resp) + fmt.Fprintln(w) + } + if historyPath != "" && conv.HistoryCount() > 0 { + _ = conv.SaveToFile(historyPath) + } + return nil +} + // handleSlashCommand processes a slash command and returns a new model name // if the model was switched, or empty string otherwise. func handleSlashCommand(input, currentModel string, conv *chat.Conversation, w io.Writer) string { diff --git a/go.mod b/go.mod index 52d1357..65e1070 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( ) require ( + github.com/chzyer/readline v1.5.1 // indirect github.com/google/jsonschema-go v0.4.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 9815d3d..143452f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ al.essio.dev/pkg/shellescape v1.6.0 h1:NxFcEqzFSEVCGN2yq7Huv/9hyCEGVa/TncnOOBBeXHA= al.essio.dev/pkg/shellescape v1.6.0/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWtls1U8I890= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -51,6 +55,7 @@ golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 60bc42f..726bf8a 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -144,11 +144,12 @@ func (a *Agent) Run(ctx context.Context, userPrompt string) error { // Process each tool call. for _, tc := range resp.ToolCalls { + fmt.Fprintf(a.output, "\n▶ Calling %s...\n", tc.Function.Name) result, err := a.dispatchToolCall(ctx, tc) if err != nil { // Feed the error back to the model for self-correction. errMsg := fmt.Sprintf("Error executing tool %q: %s", tc.Function.Name, err.Error()) - fmt.Fprintf(a.output, "\n⚠ %s\n", errMsg) + fmt.Fprintf(a.output, "⚠ %s\n", errMsg) a.conv.Append(chat.Message{ Role: chat.RoleTool, Content: errMsg, @@ -156,7 +157,9 @@ func (a *Agent) Run(ctx context.Context, userPrompt string) error { continue } - fmt.Fprintf(a.output, "\n✓ %s returned %d chars\n", tc.Function.Name, len(result)) + // Print a compact summary line instead of raw char count. + summary := formatToolResultSummary(tc.Function.Name, result) + fmt.Fprintf(a.output, "✓ %s\n", summary) a.conv.Append(chat.Message{ Role: chat.RoleTool, Content: result, @@ -168,6 +171,51 @@ func (a *Agent) Run(ctx context.Context, userPrompt string) error { return nil } +// formatToolResultSummary produces a human-readable one-line summary of a +// tool result for operator feedback. +func formatToolResultSummary(toolName, result string) string { + switch toolName { + case "axis_status": + // Extract first line (cluster summary). + if i := strings.Index(result, "\n"); i > 0 { + return toolName + ": " + strings.TrimSpace(result[:i]) + } + case "axis_summary": + return toolName + ": " + strings.TrimSpace(result) + case "axis_facts": + if i := strings.Index(result, "\n"); i > 0 { + return toolName + ": " + strings.TrimSpace(result[:i]) + } + case "axis_place": + return toolName + ": " + strings.TrimSpace(result) + case "axis_reservations": + if strings.Contains(result, "Active reservations") { + lines := strings.Split(result, "\n") + if len(lines) >= 2 { + count := 0 + for _, l := range lines[1:] { + if strings.HasPrefix(l, "-") { + count++ + } + } + return fmt.Sprintf("%s: found %d nodes with active reservations", toolName, count) + } + } + return toolName + ": no active reservations" + case "read_file": + lines := strings.Count(result, "\n") + return fmt.Sprintf("%s: read %d lines (%d chars)", toolName, lines, len(result)) + case "list_directory": + if i := strings.Index(result, "("); i > 0 && strings.Contains(result, " entries)") { + return toolName + ": " + strings.TrimSpace(result[strings.Index(result, "Directory:")+len("Directory:"):]) + } + return toolName + ": listed directory" + case "run_shell": + return toolName + ": executed shell command" + } + return fmt.Sprintf("%s returned %d chars", toolName, len(result)) +} + // dispatchToolCall handles a single tool call with safety gating and confirmation. func (a *Agent) dispatchToolCall(ctx context.Context, tc chat.ToolCall) (string, error) { name := tc.Function.Name @@ -189,7 +237,6 @@ func (a *Agent) dispatchToolCall(ctx context.Context, tc chat.ToolCall) (string, } // 4. Read-only tools execute directly (no confirmation needed). - fmt.Fprintf(a.output, "\n▶ Executing: %s\n", name) return a.tools.Execute(ctx, name, args) } diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index 0f3cd26..27b71f6 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -7,6 +7,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" @@ -107,7 +109,10 @@ func TestToolRegistryHasAllDefaultTools(t *testing.T) { tc := &ToolContext{} r := NewToolRegistry(tc) - expected := []string{"axis_status", "axis_facts", "axis_place", "run_shell"} + expected := []string{ + "axis_status", "axis_facts", "axis_place", "axis_summary", + "axis_reservations", "read_file", "list_directory", "run_shell", + } for _, name := range expected { if !r.HasTool(name) { t.Errorf("expected tool %q to be registered", name) @@ -142,8 +147,8 @@ func TestToolStatusNilSnapshot(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !strings.Contains(result, "no snapshot available") { - t.Errorf("expected 'no snapshot available', got: %s", result) + if !strings.Contains(result, "No cluster snapshot available") { + t.Errorf("expected 'No cluster snapshot available', got: %s", result) } } @@ -203,6 +208,104 @@ func TestToolShellBlockedByDesign(t *testing.T) { } } +func TestToolSummary(t *testing.T) { + snap := &models.ClusterSnapshot{ + Status: "healthy", + Nodes: []models.NodeFacts{{Name: "test-node"}}, + Summary: models.ClusterSummary{ + TotalNodes: 1, + ReachableNodes: 1, + }, + } + tc := &ToolContext{Snapshot: snap} + r := NewToolRegistry(tc) + + result, err := r.Execute(context.Background(), "axis_summary", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "1 nodes (1 reachable), status: healthy") { + t.Errorf("expected summary result, got: %s", result) + } +} + +func TestToolReadFile(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.txt") + content := []byte("Hello, AXIS!") + if err := os.WriteFile(tmpFile, content, 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + // Change to temp dir so relative paths work. + origDir, _ := os.Getwd() + os.Chdir(tmpDir) + defer os.Chdir(origDir) + + tc := &ToolContext{} + r := NewToolRegistry(tc) + + result, err := r.Execute(context.Background(), "read_file", json.RawMessage(fmt.Sprintf(`{"path":%q}`, "test.txt"))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "Hello, AXIS!") { + t.Errorf("expected file content, got: %s", result) + } +} + +func TestToolReadFilePathValidation(t *testing.T) { + tc := &ToolContext{} + r := NewToolRegistry(tc) + + _, err := r.Execute(context.Background(), "read_file", json.RawMessage(`{"path":"../secret.txt"}`)) + if err == nil { + t.Fatal("expected error for path traversal") + } + // EvalSymlinks may report lstat failure for non-existent paths outside cwd. + if !strings.Contains(err.Error(), "escapes") && !strings.Contains(err.Error(), "cannot resolve") { + t.Errorf("expected 'escapes' or 'cannot resolve' error, got: %s", err.Error()) + } +} + +func TestToolListDirectory(t *testing.T) { + tmpDir := t.TempDir() + for _, name := range []string{"a.txt", "b.txt"} { + path := filepath.Join(tmpDir, name) + if err := os.WriteFile(path, []byte("test"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + } + + origDir, _ := os.Getwd() + os.Chdir(tmpDir) + defer os.Chdir(origDir) + + tc := &ToolContext{} + r := NewToolRegistry(tc) + + result, err := r.Execute(context.Background(), "list_directory", json.RawMessage(`{"path":"."}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "a.txt") || !strings.Contains(result, "b.txt") { + t.Errorf("expected directory listing with files, got: %s", result) + } +} + +func TestToolListDirectoryPathValidation(t *testing.T) { + tc := &ToolContext{} + r := NewToolRegistry(tc) + + _, err := r.Execute(context.Background(), "list_directory", json.RawMessage(`{"path":"/etc/shadow"}`)) + if err == nil { + t.Fatal("expected error for absolute path") + } + if !strings.Contains(err.Error(), "escapes") { + t.Errorf("expected 'escapes' error, got: %s", err.Error()) + } +} + // --- Confirmation Tests --- func TestDefaultConfirmYes(t *testing.T) { @@ -676,7 +779,10 @@ func TestExecuteShellExitError(t *testing.T) { // --- isReadOnlyTool Tests --- func TestIsReadOnlyTool(t *testing.T) { - readOnly := []string{"axis_status", "axis_facts", "axis_place"} + readOnly := []string{ + "axis_status", "axis_facts", "axis_place", "axis_summary", + "axis_reservations", "read_file", "list_directory", + } for _, name := range readOnly { if !isReadOnlyTool(name) { t.Errorf("expected %q to be read-only", name) diff --git a/internal/agent/confirm.go b/internal/agent/confirm.go index cc5fd25..a86964f 100644 --- a/internal/agent/confirm.go +++ b/internal/agent/confirm.go @@ -78,7 +78,8 @@ func AutoApproveConfirm(threshold int, fallback ConfirmFunc) ConfirmFunc { // isReadOnlyTool returns true for tools that only read cluster state. func isReadOnlyTool(name string) bool { switch name { - case "axis_status", "axis_facts", "axis_place": + case "axis_status", "axis_facts", "axis_place", "axis_summary", + "axis_reservations", "read_file", "list_directory": return true } return false diff --git a/internal/agent/summarize.go b/internal/agent/summarize.go new file mode 100644 index 0000000..e9766a4 --- /dev/null +++ b/internal/agent/summarize.go @@ -0,0 +1,137 @@ +package agent + +import ( + "fmt" + "strings" + + "github.com/toasterbook88/axis/internal/models" +) + +// summarizeSnapshot returns a compact human-readable summary of cluster state +// suitable for feeding back to an LLM. Keeps output under ~600 chars. +func summarizeSnapshot(snap *models.ClusterSnapshot) string { + if snap == nil { + return "No cluster snapshot available — cluster may not be configured." + } + + var b strings.Builder + fmt.Fprintf(&b, "Cluster: %d nodes (%d reachable), status: %s\n", + snap.Summary.TotalNodes, snap.Summary.ReachableNodes, snap.Status) + if snap.Summary.TotalRAMMB > 0 { + fmt.Fprintf(&b, "Total RAM: %d MB total, %d MB free\n", + snap.Summary.TotalRAMMB, snap.Summary.TotalFreeRAMMB) + } + + for i, n := range snap.Nodes { + if i >= 6 { + fmt.Fprintf(&b, "... and %d more nodes\n", len(snap.Nodes)-i) + break + } + status := string(n.Status) + if n.Error != "" { + status += " — " + truncate(n.Error, 40) + } + line := fmt.Sprintf("- %s (%s): %s", n.Name, n.Hostname, status) + if n.Resources != nil { + line += fmt.Sprintf(", %d MB free RAM, %d cores", n.Resources.RAMFreeMB, n.Resources.CPUCores) + if len(n.Resources.GPUs) > 0 { + gpuNames := make([]string, 0, len(n.Resources.GPUs)) + for _, g := range n.Resources.GPUs { + gpuNames = append(gpuNames, g.GPUName()) + } + line += fmt.Sprintf(", GPUs: %s", strings.Join(gpuNames, ", ")) + } + } + b.WriteString(line + "\n") + } + + if len(snap.Warnings) > 0 { + b.WriteString("Warnings:\n") + for i, w := range snap.Warnings { + if i >= 3 { + fmt.Fprintf(&b, "... and %d more warnings\n", len(snap.Warnings)-i) + break + } + fmt.Fprintf(&b, "- %s: %s\n", w.Node, truncate(w.Message, 60)) + } + } + + return b.String() +} + +// summarizeNodeFacts returns a compact human-readable summary of a single node. +func summarizeNodeFacts(n models.NodeFacts) string { + var b strings.Builder + fmt.Fprintf(&b, "Node: %s (%s/%s, %s)\n", n.Name, n.OS, n.Arch, n.Hostname) + if n.Resources != nil { + r := n.Resources + fmt.Fprintf(&b, "CPU: %d cores (%s)\n", r.CPUCores, truncate(r.CPUModel, 40)) + fmt.Fprintf(&b, "RAM: %d MB total, %d MB free\n", r.RAMTotalMB, r.RAMFreeMB) + fmt.Fprintf(&b, "Disk: %d GB total, %d GB free\n", r.DiskTotalGB, r.DiskFreeGB) + if r.Load1M > 0 { + fmt.Fprintf(&b, "Load: %.2f (1m)\n", r.Load1M) + } + if len(r.GPUs) > 0 { + for _, g := range r.GPUs { + fmt.Fprintf(&b, "GPU: %s (%s, %d MB VRAM)\n", g.GPUName(), g.Vendor, g.VRAMMB) + } + } + if r.Pressure != "" && r.Pressure != "none" { + fmt.Fprintf(&b, "Pressure: %s\n", r.Pressure) + } + if r.ThermalState != "" && r.ThermalState != "nominal" { + fmt.Fprintf(&b, "Thermal: %s\n", r.ThermalState) + } + } + if len(n.Tools) > 0 { + toolNames := make([]string, 0, len(n.Tools)) + for _, t := range n.Tools { + toolNames = append(toolNames, t.Name) + } + fmt.Fprintf(&b, "Tools: %s\n", strings.Join(toolNames, ", ")) + } + if n.Ollama != nil && n.Ollama.Installed { + fmt.Fprintf(&b, "Ollama: %s (%d models)\n", n.Ollama.Version, len(n.Ollama.Models)) + } + fmt.Fprintf(&b, "Status: %s\n", n.Status) + if n.Error != "" { + fmt.Fprintf(&b, "Error: %s\n", truncate(n.Error, 100)) + } + return b.String() +} + +// summarizePlacementDecision returns a compact human-readable summary. +func summarizePlacementDecision(dec models.PlacementDecision) string { + if !dec.OK { + return "Placement: no suitable node found for this task." + } + var b strings.Builder + fmt.Fprintf(&b, "Placement: %s (fit score: %d/100", dec.Node, dec.FitScore) + if dec.IsLocal { + b.WriteString(", local") + } + b.WriteString(")\n") + if dec.Workload.Class != "" { + fmt.Fprintf(&b, "Workload class: %s\n", dec.Workload.Class) + } + if len(dec.Reasoning) > 0 { + b.WriteString("Reasoning:\n") + for _, r := range dec.Reasoning { + fmt.Fprintf(&b, "- %s\n", r) + } + } + return b.String() +} + +// truncate truncates a string to maxLen runes, appending "..." if truncated. +// Safe for UTF-8 — operates on runes, not bytes. +func truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + if maxLen <= 3 { + return string(runes[:maxLen]) + } + return string(runes[:maxLen-3]) + "..." +} diff --git a/internal/agent/tools.go b/internal/agent/tools.go index 69ea4b8..6d69ede 100644 --- a/internal/agent/tools.go +++ b/internal/agent/tools.go @@ -5,7 +5,10 @@ import ( "context" "encoding/json" "fmt" + "io" + "os" "os/exec" + "path/filepath" "strings" "time" @@ -42,6 +45,10 @@ func NewToolRegistry(tc *ToolContext) *ToolRegistry { r.registerStatus(tc) r.registerFacts(tc) r.registerPlace(tc) + r.registerSummary(tc) + r.registerReservations(tc) + r.registerReadFile() + r.registerListDirectory() r.registerShell() return r } @@ -91,17 +98,13 @@ func (r *ToolRegistry) add(name, description string, params json.RawMessage, exe func (r *ToolRegistry) registerStatus(tc *ToolContext) { r.add("axis_status", - "Return the current AXIS cluster status snapshot as JSON. Use this to answer questions about node health, resources, and cluster state.", + "Return a compact human-readable summary of the current AXIS cluster status. Includes node count, health, resources, and warnings. Use this for cluster overview questions.", json.RawMessage(`{"type":"object","properties":{}}`), func(ctx context.Context, args json.RawMessage) (string, error) { if tc.Snapshot == nil { - return `{"error":"no snapshot available — cluster may not be configured"}`, nil + return "No cluster snapshot available — cluster may not be configured.", nil } - out, err := json.Marshal(tc.Snapshot) - if err != nil { - return "", fmt.Errorf("marshal snapshot: %w", err) - } - return string(out), nil + return summarizeSnapshot(tc.Snapshot), nil }, ) } @@ -110,19 +113,15 @@ func (r *ToolRegistry) registerStatus(tc *ToolContext) { func (r *ToolRegistry) registerFacts(tc *ToolContext) { r.add("axis_facts", - "Return local hardware facts for the current machine (CPU, RAM, disk, GPUs, installed tools).", + "Return a compact human-readable summary of local hardware facts for the current machine (CPU, RAM, disk, GPUs, installed tools, Ollama status).", json.RawMessage(`{"type":"object","properties":{}}`), func(ctx context.Context, args json.RawMessage) (string, error) { if tc.Snapshot != nil { if n, ok := models.FindLocalNode(tc.Snapshot.Nodes); ok { - out, err := json.Marshal(n) - if err != nil { - return "", fmt.Errorf("marshal facts: %w", err) - } - return string(out), nil + return summarizeNodeFacts(n), nil } } - return `{"error":"local node not found in snapshot"}`, nil + return "Local node not found in snapshot.", nil }, ) } @@ -135,7 +134,7 @@ type placeArgs struct { func (r *ToolRegistry) registerPlace(tc *ToolContext) { r.add("axis_place", - "Select the best node for a task description. Returns a placement decision with node, fit score, and reasoning.", + "Select the best node for a task description. Returns a human-readable placement decision with node name, fit score, and reasoning.", json.RawMessage(`{"type":"object","properties":{"description":{"type":"string","description":"What the task needs to do"}},"required":["description"]}`), func(ctx context.Context, args json.RawMessage) (string, error) { var a placeArgs @@ -146,19 +145,186 @@ func (r *ToolRegistry) registerPlace(tc *ToolContext) { return "", fmt.Errorf("axis_place requires a non-empty \"description\" argument") } if tc.Snapshot == nil || len(tc.Snapshot.Nodes) == 0 { - return `{"ok":false,"reasoning":["no nodes available in snapshot"]}`, nil + return "Placement: no nodes available in snapshot.", nil } reqs := placement.InferRequirements(a.Description) decision := placement.SelectBestNode(reqs, tc.Snapshot.Nodes, tc.State) - out, err := json.Marshal(decision) + return summarizePlacementDecision(decision), nil + }, + ) +} + +// --- Tool: axis_summary --- + +func (r *ToolRegistry) registerSummary(tc *ToolContext) { + r.add("axis_summary", + "Return an ultra-compact one-line summary of the cluster (node count, health, total RAM). Good for quick status checks.", + json.RawMessage(`{"type":"object","properties":{}}`), + func(ctx context.Context, args json.RawMessage) (string, error) { + if tc.Snapshot == nil { + return "No cluster snapshot available.", nil + } + var b strings.Builder + fmt.Fprintf(&b, "%d nodes (%d reachable), status: %s", + tc.Snapshot.Summary.TotalNodes, tc.Snapshot.Summary.ReachableNodes, tc.Snapshot.Status) + if tc.Snapshot.Summary.TotalRAMMB > 0 { + fmt.Fprintf(&b, ", %d MB RAM total, %d MB free", + tc.Snapshot.Summary.TotalRAMMB, tc.Snapshot.Summary.TotalFreeRAMMB) + } + if len(tc.Snapshot.Warnings) > 0 { + fmt.Fprintf(&b, ", %d warnings", len(tc.Snapshot.Warnings)) + } + return b.String(), nil + }, + ) +} + +// --- Tool: axis_reservations --- + +func (r *ToolRegistry) registerReservations(tc *ToolContext) { + r.add("axis_reservations", + "List active reservations and task assignments across the cluster.", + json.RawMessage(`{"type":"object","properties":{}}`), + func(ctx context.Context, args json.RawMessage) (string, error) { + if tc.State == nil || len(tc.State.Nodes) == 0 { + return "No reservation state available.", nil + } + var b strings.Builder + fmt.Fprintf(&b, "Active reservations for %d nodes:\n", len(tc.State.Nodes)) + for name, ns := range tc.State.Nodes { + if ns.ActiveTasks == 0 && ns.ReservedMB == 0 { + continue + } + fmt.Fprintf(&b, "- %s: %d active tasks, %d MB reserved\n", name, ns.ActiveTasks, ns.ReservedMB) + if ns.LastTask != "" { + fmt.Fprintf(&b, " Last task: %s\n", truncate(ns.LastTask, 60)) + } + if len(ns.ActiveExecs) > 0 { + fmt.Fprintf(&b, " Active execs: %s\n", strings.Join(ns.ActiveExecs, ", ")) + } + } + return b.String(), nil + }, + ) +} + +// --- Tool: read_file --- + +type readFileArgs struct { + Path string `json:"path"` +} + +func (r *ToolRegistry) registerReadFile() { + r.add("read_file", + "Read the contents of a file at a given path. Returns the file contents as text. Paths are restricted to the current working directory and its subdirectories for safety.", + json.RawMessage(`{"type":"object","properties":{"path":{"type":"string","description":"Relative or absolute file path"}},"required":["path"]}`), + func(ctx context.Context, args json.RawMessage) (string, error) { + var a readFileArgs + if err := json.Unmarshal(args, &a); err != nil { + return "", fmt.Errorf("invalid arguments for read_file: expected {\"path\": \"...\"}, got: %s", string(args)) + } + if a.Path == "" { + return "", fmt.Errorf("read_file requires a non-empty \"path\" argument") + } + clean, err := validateToolPath(a.Path) + if err != nil { + return "", err + } + f, err := os.Open(clean) if err != nil { - return "", fmt.Errorf("marshal decision: %w", err) + return "", fmt.Errorf("cannot read file %q: %w", clean, err) + } + defer f.Close() + + const maxFileSize = 8000 + // Read up to maxFileSize+1 to detect truncation. + limited := io.LimitReader(f, int64(maxFileSize)+1) + data, err := io.ReadAll(limited) + if err != nil { + return "", fmt.Errorf("cannot read file %q: %w", clean, err) + } + content := string(data) + if len(data) > maxFileSize { + content = truncateRune(content, maxFileSize) + "\n... [truncated to 8000 chars]" } - return string(out), nil + return content, nil }, ) } +// --- Tool: list_directory --- + +type listDirArgs struct { + Path string `json:"path"` +} + +func (r *ToolRegistry) registerListDirectory() { + r.add("list_directory", + "List files and directories at a given path. Returns a human-readable directory listing. Paths are restricted to the current working directory and its subdirectories for safety.", + json.RawMessage(`{"type":"object","properties":{"path":{"type":"string","description":"Relative or absolute directory path"}},"required":["path"]}`), + func(ctx context.Context, args json.RawMessage) (string, error) { + var a listDirArgs + if err := json.Unmarshal(args, &a); err != nil { + return "", fmt.Errorf("invalid arguments for list_directory: expected {\"path\": \"...\"}, got: %s", string(args)) + } + if a.Path == "" { + a.Path = "." + } + clean, err := validateToolPath(a.Path) + if err != nil { + return "", err + } + entries, err := os.ReadDir(clean) + if err != nil { + return "", fmt.Errorf("cannot read directory %q: %w", clean, err) + } + var b strings.Builder + const maxDirEntries = 100 + fmt.Fprintf(&b, "Directory: %s (%d entries)\n", clean, len(entries)) + for i, e := range entries { + if i >= maxDirEntries { + fmt.Fprintf(&b, "... and %d more entries\n", len(entries)-i) + break + } + name := e.Name() + if e.IsDir() { + name += "/" + } + b.WriteString(name + "\n") + } + return b.String(), nil + }, + ) +} + +// validateToolPath validates and resolves a path for file tools, preventing +// directory traversal outside the current working directory. Symlinks are +// resolved before the bounds check to prevent symlink-based escapes. +func validateToolPath(p string) (string, error) { + cwd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("cannot determine working directory: %w", err) + } + clean := filepath.Clean(p) + if !filepath.IsAbs(clean) { + clean = filepath.Join(cwd, clean) + } + // Resolve symlinks to their real destination. + resolved, err := filepath.EvalSymlinks(clean) + if err != nil { + return "", fmt.Errorf("cannot resolve path %q: %w", p, err) + } + // Ensure the resolved path is within cwd. + rel, err := filepath.Rel(cwd, resolved) + if err != nil { + return "", fmt.Errorf("invalid path %q: %w", p, err) + } + if strings.HasPrefix(rel, "..") { + return "", fmt.Errorf("path %q escapes working directory", p) + } + return resolved, nil +} + // --- Tool: run_shell --- type shellArgs struct { @@ -229,8 +395,20 @@ func ExecuteShell(ctx context.Context, command string) (string, error) { // Cap output to prevent blowing up the context window. const maxOutput = 4000 - if len(output) > maxOutput { - output = output[:maxOutput] + "\n... [truncated to 4000 chars]" + if len([]rune(output)) > maxOutput { + output = truncateRune(output, maxOutput) + "\n... [truncated to 4000 chars]" } return output, nil } + +// truncateRune truncates a string to maxLen runes, appending "..." if truncated. +func truncateRune(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + if maxLen <= 3 { + return string(runes[:maxLen]) + } + return string(runes[:maxLen-3]) + "..." +} diff --git a/internal/chat/persist.go b/internal/chat/persist.go new file mode 100644 index 0000000..4d5393c --- /dev/null +++ b/internal/chat/persist.go @@ -0,0 +1,84 @@ +package chat + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// conversationHistory stores a serializable conversation record. +type conversationHistory struct { + Messages []Message `json:"messages"` +} + +// PersistPath returns the default path for conversation history files. +func PersistPath(name string) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("cannot determine home directory: %w", err) + } + return filepath.Join(home, ".axis", name+"-history.json"), nil +} + +// SaveToFile writes the conversation (excluding system messages) to the given path. +func (c *Conversation) SaveToFile(path string) error { + var hist conversationHistory + for _, m := range c.messages { + // Skip system messages — they are reconstructed on load. + if m.Role == RoleSystem { + continue + } + hist.Messages = append(hist.Messages, m) + } + data, err := json.MarshalIndent(hist, "", " ") + if err != nil { + return fmt.Errorf("marshal conversation: %w", err) + } + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return fmt.Errorf("create history directory: %w", err) + } + if err := os.WriteFile(path, data, 0600); err != nil { + return fmt.Errorf("write history file: %w", err) + } + return nil +} + +// LoadFromFile restores non-system messages from a file into the conversation. +// If the file does not exist, this is a no-op. +func (c *Conversation) LoadFromFile(path string) error { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("read history file: %w", err) + } + var hist conversationHistory + if err := json.Unmarshal(data, &hist); err != nil { + return fmt.Errorf("unmarshal conversation: %w", err) + } + for _, m := range hist.Messages { + // Skip system messages to avoid duplicates. + if m.Role == RoleSystem { + continue + } + c.messages = append(c.messages, m) + } + // Compact if needed. + if c.maxChars > 0 { + c.compact() + } + return nil +} + +// HistoryCount returns the number of non-system messages in the conversation. +func (c *Conversation) HistoryCount() int { + n := 0 + for _, m := range c.messages { + if m.Role != RoleSystem { + n++ + } + } + return n +} diff --git a/internal/chat/persist_test.go b/internal/chat/persist_test.go new file mode 100644 index 0000000..363d63d --- /dev/null +++ b/internal/chat/persist_test.go @@ -0,0 +1,96 @@ +package chat + +import ( + "os" + "path/filepath" + "testing" +) + +func TestConversationSaveAndLoad(t *testing.T) { + tmp := filepath.Join(t.TempDir(), "test-history.json") + + c := NewConversation(4096) + c.Append(Message{Role: RoleSystem, Content: "sys prompt"}) + c.Append(Message{Role: RoleUser, Content: "hello"}) + c.Append(Message{Role: RoleAssistant, Content: "hi there"}) + c.Append(Message{Role: RoleTool, Content: "tool result"}) + + if err := c.SaveToFile(tmp); err != nil { + t.Fatalf("save: %v", err) + } + + // Verify file exists and does NOT contain system messages. + data, err := os.ReadFile(tmp) + if err != nil { + t.Fatalf("read: %v", err) + } + if string(data) == "" { + t.Fatal("history file is empty") + } + if contains(t, string(data), "sys prompt") { + t.Error("history file should not contain system messages") + } + + // Load into fresh conversation. + c2 := NewConversation(4096) + c2.Append(Message{Role: RoleSystem, Content: "reconstructed sys prompt"}) + if err := c2.LoadFromFile(tmp); err != nil { + t.Fatalf("load: %v", err) + } + + msgs := c2.Messages() + if len(msgs) != 4 { + t.Fatalf("expected 4 messages after load, got %d", len(msgs)) + } + if msgs[0].Role != RoleSystem || msgs[0].Content != "reconstructed sys prompt" { + t.Errorf("system message mismatch: %+v", msgs[0]) + } + if msgs[1].Content != "hello" { + t.Errorf("user message mismatch: %q", msgs[1].Content) + } + if msgs[2].Content != "hi there" { + t.Errorf("assistant message mismatch: %q", msgs[2].Content) + } + if msgs[3].Content != "tool result" { + t.Errorf("tool message mismatch: %q", msgs[3].Content) + } +} + +func TestConversationLoadMissingFile(t *testing.T) { + tmp := filepath.Join(t.TempDir(), "nonexistent.json") + c := NewConversation(4096) + c.Append(Message{Role: RoleSystem, Content: "sys"}) + + err := c.LoadFromFile(tmp) + if err != nil { + t.Fatalf("loading missing file should be a no-op, got: %v", err) + } + if c.Len() != 1 { + t.Fatalf("expected 1 message, got %d", c.Len()) + } +} + +func TestConversationHistoryCount(t *testing.T) { + c := NewConversation(4096) + c.Append(Message{Role: RoleSystem, Content: "sys"}) + c.Append(Message{Role: RoleUser, Content: "a"}) + c.Append(Message{Role: RoleAssistant, Content: "b"}) + + if c.HistoryCount() != 2 { + t.Errorf("expected history count 2, got %d", c.HistoryCount()) + } +} + +func contains(t *testing.T, s, substr string) bool { + t.Helper() + return len(substr) > 0 && len(s) >= len(substr) && (s == substr || len(s) > len(substr) && s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || findSubstr(s, substr)) +} + +func findSubstr(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}