diff --git a/config.example.yml b/config.example.yml index da44a45..a0cc84c 100644 --- a/config.example.yml +++ b/config.example.yml @@ -1,5 +1,7 @@ deployments_path: ~/flatrun/deployments +system_files_root: / docker_socket: unix:///var/run/docker.sock +default_timeout: 2m0s api: host: 0.0.0.0 port: 8090 @@ -26,6 +28,7 @@ nginx: reload_command: nginx -s reload external: false container_webroot_path: "" + reject_unknown_domains: false certbot: enabled: true image: certbot/certbot @@ -35,6 +38,9 @@ certbot: webroot_path: "" container_webroot_path: "" dns_provider: "" + auto_renewal_enabled: false + renewal_threshold_days: 30 + renewal_check_interval: 12h0m0s logging: level: info format: json @@ -58,13 +64,16 @@ infrastructure: host: "" port: 6379 password: "" -# cluster: -# enabled: false -# server_name: "" # defaults to OS hostname -# advertise_url: "" # reachable URL for this agent (e.g. https://my-server:8090) -# health_interval: "30s" -# request_timeout: "10s" - + powerdns: + enabled: false + container: powerdns + image: powerdns/pdns-auth-48:latest + api_port: 8081 + dns_port: 53 + api_key: "" # set a unique random key per installation; never commit a real one + data_path: "" + default_soa: "" + nameservers: "" security: enabled: true realtime_capture: false @@ -74,8 +83,47 @@ security: auto_block_enabled: true auto_block_threshold: 50 auto_block_duration: 24h0m0s 
+ detection_window: 2m0s + not_found_threshold: 10 + auth_failure_threshold: 5 + unique_paths_threshold: 20 + repeated_hits_threshold: 30 + internal_api_token: "" # Only list proxies you control. Forwarded client IPs are honored solely # when the connecting peer matches; an empty list ignores them, which # prevents X-Forwarded-For / CF-Connecting-IP spoofing. trusted_proxies: [] trust_cf_header: false +audit: + enabled: false + retention_days: 30 + capture_request_body: false + excluded_paths: + - /api/health + sensitive_fields: + - password + - token + - secret + - api_key + - authorization + cleanup_interval: 24h0m0s +cluster: + enabled: false + server_name: "" # defaults to OS hostname + advertise_url: "" + health_interval: 30s + request_timeout: 10s +system_terminal: + protected_mode: + enabled: false +cleanup: + timeout: 2m0s +plans: + ttl: 24h0m0s + retention_days: 30 +ai: + enabled: false + base_url: https://generativelanguage.googleapis.com/v1beta/openai/ + api_key: "" # provider API key; leave empty for keyless local servers + model: gemini-2.5-flash + timeout: 1m0s diff --git a/go.mod b/go.mod index a003fc6..fbf4205 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/compose-spec/compose-go/v2 v2.10.1 github.com/creack/pty v1.1.24 github.com/digitalocean/godo v1.171.0 + github.com/distribution/reference v0.6.0 github.com/docker/docker v28.5.2+incompatible github.com/fsnotify/fsnotify v1.9.0 github.com/gin-contrib/cors v1.7.6 @@ -67,7 +68,6 @@ require ( github.com/containerd/typeurl/v2 v2.2.3 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/distribution/reference v0.6.0 // indirect github.com/docker/buildx v0.31.1 // indirect github.com/docker/cli v29.2.1+incompatible // indirect github.com/docker/compose/v5 v5.1.0 // indirect diff --git a/internal/ai/ai_test.go b/internal/ai/ai_test.go new file mode 100644 index 0000000..f9c3e5d --- /dev/null +++ 
b/internal/ai/ai_test.go @@ -0,0 +1,284 @@ +package ai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + 
"time" + + "github.com/flatrun/agent/pkg/config" +) + +func TestNewDisabled(t *testing.T) { + if _, err := New(&config.AIConfig{Enabled: false}); err != ErrDisabled { + t.Errorf("err = %v, want ErrDisabled", err) + } + if _, err := New(nil); err != ErrDisabled { + t.Errorf("nil cfg err = %v, want ErrDisabled", err) + } +} + +func TestRedactor(t *testing.T) { + r := NewRedactor([]string{"hunter2secret", "short", " spaced-secret-value "}) + + cases := []struct { + name string + in string + contains []string + excludes []string + minCount int + }{ + { + name: "known secret value", + in: "db error: auth failed for password hunter2secret retrying", + excludes: []string{"hunter2secret"}, + minCount: 1, + }, + { + name: "short values stay", + in: "level=short msg=ok", + contains: []string{"short"}, + }, + { + name: "credential assignment", + in: "MYSQL_ROOT_PASSWORD=supersafe123\napi_key: abc123def\nDEBUG=true", + contains: []string{"MYSQL_ROOT_PASSWORD=[REDACTED]", "api_key: [REDACTED]", "DEBUG=true"}, + excludes: []string{"supersafe123", "abc123def"}, + minCount: 2, + }, + { + name: "trimmed secret", + in: "token is spaced-secret-value here", + excludes: []string{"spaced-secret-value"}, + minCount: 1, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + out, count := r.Redact(tc.in) + for _, want := range tc.contains { + if !strings.Contains(out, want) { + t.Errorf("output %q missing %q", out, want) + } + } + for _, banned := range tc.excludes { + if strings.Contains(out, banned) { + t.Errorf("output %q still contains %q", out, banned) + } + } + if count < tc.minCount { + t.Errorf("count = %d, want >= %d", count, tc.minCount) + } + }) + } +} + +func TestOpenAICompatibleComplete(t *testing.T) { + var gotAuth string + var gotPayload map[string]interface{} + fake := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + if r.URL.Path != "/v1/chat/completions" { + 
t.Errorf("path = %s", r.URL.Path) + } + _ = json.NewDecoder(r.Body).Decode(&gotPayload) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "model": "test-model", + "choices": []map[string]interface{}{ + {"message": map[string]string{"role": "assistant", "content": "diagnosis here"}}, + }, + "usage": map[string]int{"prompt_tokens": 10, "completion_tokens": 5}, + }) + })) + defer fake.Close() + + p, err := New(&config.AIConfig{ + Enabled: true, + BaseURL: fake.URL + "/v1/", + APIKey: "sk-test", + Model: "test-model", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + resp, err := p.Complete(context.Background(), Request{ + Messages: []Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatal(err) + } + if resp.Content != "diagnosis here" || resp.Model != "test-model" { + t.Errorf("resp = %+v", resp) + } + if resp.Usage.PromptTokens != 10 { + t.Errorf("usage = %+v", resp.Usage) + } + if gotAuth != "Bearer sk-test" { + t.Errorf("auth header = %q", gotAuth) + } + if gotPayload["model"] != "test-model" { + t.Errorf("payload model = %v", gotPayload["model"]) + } +} + +func TestOpenAICompatibleToolCalling(t *testing.T) { + var sentPayload map[string]interface{} + fake := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewDecoder(r.Body).Decode(&sentPayload) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "model": "test-model", + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "role": "assistant", + "content": "", + "tool_calls": []map[string]interface{}{{ + "id": "call_1", + "type": "function", + "function": map[string]interface{}{"name": "list_networks", "arguments": "{}"}, + }}, + }, + }}, + }) + })) + defer fake.Close() + + p, _ := New(&config.AIConfig{Enabled: true, BaseURL: fake.URL, Model: "test-model", Timeout: 5 * time.Second}) + resp, err := p.Complete(context.Background(), Request{ + Messages: []Message{ + {Role: "user", Content: "what 
networks exist?"}, + {Role: "assistant", ToolCalls: []ToolCall{{ID: "x", Name: "noop", Arguments: "{}"}}}, + {Role: "tool", ToolCallID: "x", Name: "noop", Content: "done"}, + }, + Tools: []Tool{{ + Name: "list_networks", + Description: "List docker networks", + Parameters: map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}, + }}, + }) + if err != nil { + t.Fatal(err) + } + if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].Name != "list_networks" { + t.Fatalf("tool calls = %+v", resp.ToolCalls) + } + + tools := sentPayload["tools"].([]interface{}) + if len(tools) != 1 { + t.Fatalf("tools not sent: %v", sentPayload["tools"]) + } + fn := tools[0].(map[string]interface{})["function"].(map[string]interface{}) + if fn["name"] != "list_networks" { + t.Errorf("tool name = %v", fn["name"]) + } + + // The assistant tool-call message and the tool result must reach the + // wire in OpenAI's nested shape. + msgs := sentPayload["messages"].([]interface{}) + assistant := msgs[1].(map[string]interface{}) + if _, ok := assistant["tool_calls"]; !ok { + t.Error("assistant tool_calls not serialized") + } + toolMsg := msgs[2].(map[string]interface{}) + if toolMsg["tool_call_id"] != "x" || toolMsg["role"] != "tool" { + t.Errorf("tool result message = %v", toolMsg) + } +} + +func TestOpenAICompatibleKeyless(t *testing.T) { + fake := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if auth := r.Header.Get("Authorization"); auth != "" { + t.Errorf("keyless request sent auth header %q", auth) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]string{"content": "ok"}}}, + }) + })) + defer fake.Close() + + p, _ := New(&config.AIConfig{Enabled: true, BaseURL: fake.URL, Model: "llama3"}) + resp, err := p.Complete(context.Background(), Request{Messages: []Message{{Role: "user", Content: "hi"}}}) + if err != nil { + t.Fatal(err) + } + if resp.Model != "llama3" { + 
t.Errorf("model fallback = %q, want configured model", resp.Model) + } +} + +func TestOpenAICompatibleErrorMapping(t *testing.T) { + fake := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":{"message":"invalid api key"}}`)) + })) + defer fake.Close() + + p, _ := New(&config.AIConfig{Enabled: true, BaseURL: fake.URL, Model: "m"}) + _, err := p.Complete(context.Background(), Request{Messages: []Message{{Role: "user", Content: "hi"}}}) + if err == nil || !strings.Contains(err.Error(), "invalid api key") || !strings.Contains(err.Error(), "401") { + t.Errorf("err = %v, want provider message and status", err) + } +} + +func TestBuildAssistMessagesTruncates(t *testing.T) { + long := strings.Repeat("x", contextBudget*2) + intent, ok := GetIntent("diagnose") + if !ok { + t.Fatal("diagnose intent missing") + } + msgs := BuildAssistMessages(intent, "deployment myapp", []Section{ + {Label: "docker-compose.yml", Content: "services: {}", Format: "yaml"}, + {Label: "Recent logs", Content: long}, + }, "why does it crash?", "https://flatrun.dev/docs/") + + if len(msgs) != 2 { + t.Fatalf("got %d messages", len(msgs)) + } + if msgs[0].Role != "system" || msgs[1].Role != "user" { + t.Errorf("roles = %s/%s", msgs[0].Role, msgs[1].Role) + } + if len(msgs[1].Content) > contextBudget+2000 { + t.Errorf("user message not truncated: %d chars", len(msgs[1].Content)) + } + if !strings.Contains(msgs[1].Content, "[... 
truncated ...]") { + t.Error("truncation marker missing") + } + if !strings.Contains(msgs[1].Content, strings.Repeat("x", 100)) { + t.Error("log tail missing from prompt") + } + if !strings.Contains(msgs[1].Content, "why does it crash?") { + t.Error("operator question missing from prompt") + } + if !strings.Contains(msgs[1].Content, "deployment myapp") { + t.Error("scope label missing from prompt") + } + if !strings.Contains(msgs[0].Content, "https://flatrun.dev/docs/") { + t.Error("docs link missing from system prompt") + } +} + +func TestIntentRegistry(t *testing.T) { + for _, key := range []string{"diagnose", "improve", "secure", "explain"} { + intent, ok := GetIntent(key) + if !ok { + t.Errorf("intent %q missing", key) + continue + } + msgs := BuildAssistMessages(intent, "the FlatRun host", []Section{{Label: "Output", Content: "boom"}}, "", "") + hasSuggestionFormat := strings.Contains(msgs[0].Content, "suggestions") + if intent.AllowSuggestions && !hasSuggestionFormat { + t.Errorf("intent %q should request suggestions", key) + } + if !intent.AllowSuggestions && hasSuggestionFormat { + t.Errorf("intent %q should not request suggestions", key) + } + } + if _, ok := GetIntent("nonsense"); ok { + t.Error("unknown intent should not resolve") + } +} diff --git a/internal/ai/assist.go b/internal/ai/assist.go new file mode 100644 index 0000000..48dce49 --- /dev/null +++ b/internal/ai/assist.go @@ -0,0 +1,165 @@ +package ai + +import ( + "fmt" + "strings" +) + +// Intent selects what the model is asked to do with the gathered +// context. Adding a capability is adding an entry here; the pipeline, +// endpoints and UI do not change. +type Intent struct { + Key string + Task string + AllowSuggestions bool +} + +var intents = map[string]Intent{ + "diagnose": { + Key: "diagnose", + AllowSuggestions: true, + Task: `First decide whether the context actually shows a problem. 
Normal startup messages, successful health checks and graceful shutdowns are signs of healthy operation, not failures; restarts without error output are usually operator actions. If everything indicates normal operation, state that plainly under ## Diagnosis and stop; do not propose changes. Never construct a problem so that there is something to fix. + +If there is a problem, answer with sections: +## Diagnosis +The most likely root cause, stated plainly in one or two sentences. +## Evidence +The specific lines or config fragments that support the diagnosis. +## Fix +Concrete steps the operator should take; show the exact command or config snippet when one applies. +If the context is insufficient for a confident diagnosis, say so and list what to check next.`, + }, + "improve": { + Key: "improve", + AllowSuggestions: true, + Task: `Review the context for improvements to reliability, performance and operability. Answer with sections: +## Findings +What could be better and why it matters, grounded in the context. +## Recommendations +Concrete, prioritized changes with the exact config or command for each where applicable.`, + }, + "secure": { + Key: "secure", + AllowSuggestions: true, + Task: `Review the context for security weaknesses and hardening opportunities. Answer with sections: +## Risks +Each weakness found and its impact, grounded in the context. Do not invent vulnerabilities the context does not support. +## Hardening +Concrete, prioritized steps with the exact config or command for each where applicable.`, + }, + "explain": { + Key: "explain", + AllowSuggestions: false, + Task: `Explain what the context shows in plain language for an operator who did not build this system. Answer with sections: +## Summary +What is happening, in a few sentences. 
+## Details +The notable parts explained simply, defining jargon briefly when unavoidable.`, + }, +} + +func GetIntent(key string) (Intent, bool) { + intent, ok := intents[key] + return intent, ok +} + +func IntentKeys() []string { + keys := make([]string, 0, len(intents)) + for k := range intents { + keys = append(keys, k) + } + return keys +} + +const assistBasePrompt = `You are the assistant of FlatRun, a flat-file container hosting platform: a single Go agent manages deployments (each a directory with a docker-compose.yml), Docker networks, an nginx reverse proxy and Let's Encrypt certificates on one host. + +FlatRun conventions: each deployment is a directory containing docker-compose.yml, an optional .env.flatrun env file and a service.yml metadata file. Services join pre-created external Docker networks: the configured proxy network connects apps to the nginx reverse proxy that serves them on the web, and the database network connects apps to shared databases. The agent generates one nginx virtual host per exposed deployment and manages Let's Encrypt certificates. Routing is defined in the deployment metadata: the reverse proxy forwards each domain to a service name and container port stored there; the compose "expose" field plays no role in FlatRun routing or health checks. Data uses bind mounts inside the deployment directory, never named volumes. + +Answers must fit this installation, not Docker in general. The "FlatRun platform context" section in the user message states this host's actual configuration and state; reconcile everything you recommend against it, and when the generic fix and the platform's way of doing things differ, recommend the platform's way. A finding must be supported by the context; when the evidence shows normal operation or is inconclusive, saying so is the correct answer and speculative fixes are wrong. + +%s + +Ground every statement in the provided context; never invent log lines or configuration. 
Secret values appear as [REDACTED]; that is expected and not an error.%s%s` + +const suggestionInstructions = ` + +If concrete steps can be run on the server, append ONE fenced code block with language tag "suggestions" containing a JSON array. Each entry is one of: +{"kind":"exec","service":"","command":"","title":"","reason":""} +{"kind":"service_action","service":"","action":"start|stop|restart|rebuild|pull","title":"","reason":""} +Suggest at most 3 actions, only ones directly supported by the evidence, never destructive commands (no rm, no DROP, no down). The operator reviews and runs them manually. Omit the block entirely when nothing safe applies.` + +const sessionBasePrompt = `You are the assistant of FlatRun, a flat-file container hosting platform: a single Go agent manages deployments (each a directory with a docker-compose.yml), Docker networks, an nginx reverse proxy and Let's Encrypt certificates on one host. + +FlatRun conventions: each deployment is a directory containing docker-compose.yml, an optional .env.flatrun env file and a service.yml metadata file. Services join pre-created external Docker networks: the configured proxy network connects apps to the nginx reverse proxy that serves them on the web, and the database network connects apps to shared databases. Routing is defined in deployment metadata: the reverse proxy forwards each domain to a service name and container port stored there; the compose "expose" field plays no role in FlatRun routing or health checks. Application logs and data live in bind-mounted files inside the deployment directory. + +You can investigate this specific installation with the provided tools: get host information (hostname, public IP), list the networks and deployments that actually exist, read a deployment's metadata, fetch its recent logs, read files it generated, run read-only commands inside service containers, and run read-only commands on the host itself for free-form questions the other tools do not cover. 
Application logs in FlatRun are the containers' captured stdout/stderr; fetch them with the logs tool rather than searching the filesystem for log files. + +How to use tools well: +- If the message already contains the logs or output to analyze, analyze them directly and answer with NO tool calls. +- When you are handed something to analyze (logs, an operation's output), respond with a short summary, then any problems you found with their likely solutions. If nothing is wrong, say so plainly. +- When deeper investigation would genuinely help, make the tool call directly; do not describe the command in prose and ask for permission. The operator's interface shows each call and lets them allow or decline it, so the approval happens there, not in your text. You may add one short sentence saying what you are about to check, then call the tool. +- Prefer one well-chosen lookup over many speculative ones, and never run tools just to appear thorough. + +%s + +Answer in Markdown. Ground every statement in what you observed; if the evidence shows normal operation or is inconclusive, say so plainly rather than inventing a problem. When you recommend a runnable fix, also append it as a "suggestions" block as described above. Secret values appear as [REDACTED]; that is expected.%s` + +// BuildSessionPrompt returns the system prompt for an interactive +// session, optionally scoped to one deployment and referencing the +// docs site. +func BuildSessionPrompt(scope, deployment, docsURL string) string { + scopeLine := "This is a general session about the whole FlatRun instance and its deployments." + if scope == SessionScopeDeployment && deployment != "" { + scopeLine = "This session is focused on the deployment named \"" + deployment + "\". When a tool takes a deployment argument, that deployment is the default." 
+ } + docs := "" + if strings.TrimSpace(docsURL) != "" { + docs = "\n\nProduct documentation you may reference: " + strings.TrimSpace(docsURL) + } + return fmt.Sprintf(sessionBasePrompt, scopeLine+suggestionInstructions, docs) +} + +// Section is one labeled piece of gathered context, already redacted. +type Section struct { + Label string + Content string + Format string +} + +// BuildAssistMessages assembles the chat for an analysis. Sections +// must already be redacted. The newest end of long sections survives +// truncation since it matters most. +func BuildAssistMessages(intent Intent, scopeLabel string, sections []Section, question, docsURL string) []Message { + suggestions := "" + if intent.AllowSuggestions { + suggestions = suggestionInstructions + } + docs := "" + if strings.TrimSpace(docsURL) != "" { + docs = "\n\nProduct documentation you may reference in answers: " + strings.TrimSpace(docsURL) + } + system := fmt.Sprintf(assistBasePrompt, intent.Task, suggestions, docs) + + perSection := contextBudget + if len(sections) > 0 { + perSection = contextBudget / len(sections) + } + + var user strings.Builder + fmt.Fprintf(&user, "Scope: %s\n", scopeLabel) + for _, section := range sections { + format := section.Format + if format == "" { + format = "text" + } + fmt.Fprintf(&user, "\n## %s\n```%s\n%s\n```\n", section.Label, format, TruncateHead(section.Content, perSection)) + } + if strings.TrimSpace(question) != "" { + fmt.Fprintf(&user, "\n## Operator question\n%s\n", strings.TrimSpace(question)) + } + + return []Message{ + {Role: "system", Content: system}, + {Role: "user", Content: user.String()}, + } +} diff --git a/internal/ai/openai.go b/internal/ai/openai.go new file mode 100644 index 0000000..98558ff --- /dev/null +++ b/internal/ai/openai.go @@ -0,0 +1,184 @@ +package ai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/flatrun/agent/pkg/config" +) + +const maxResponseBytes = 4 << 20 
+ +// openAICompatible speaks the OpenAI chat-completions wire format, +// which OpenAI, Ollama, vLLM, LM Studio and most gateways all accept. +type openAICompatible struct { + baseURL string + apiKey string + model string + client *http.Client +} + +func newOpenAICompatible(cfg *config.AIConfig) *openAICompatible { + timeout := cfg.Timeout + if timeout == 0 { + timeout = 60 * time.Second + } + return &openAICompatible{ + baseURL: strings.TrimRight(cfg.BaseURL, "/"), + apiKey: cfg.APIKey, + model: cfg.Model, + client: &http.Client{Timeout: timeout}, + } +} + +func (p *openAICompatible) Name() string { + return "openai-compatible" +} + +// wireMessages converts internal messages to the OpenAI chat wire +// format, where assistant tool calls and tool results are nested +// differently than our flat representation. +func wireMessages(messages []Message) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(messages)) + for _, m := range messages { + wm := map[string]interface{}{"role": m.Role, "content": m.Content} + if len(m.ToolCalls) > 0 { + calls := make([]map[string]interface{}, 0, len(m.ToolCalls)) + for _, tc := range m.ToolCalls { + calls = append(calls, map[string]interface{}{ + "id": tc.ID, + "type": "function", + "function": map[string]interface{}{ + "name": tc.Name, + "arguments": tc.Arguments, + }, + }) + } + wm["tool_calls"] = calls + } + if m.ToolCallID != "" { + wm["tool_call_id"] = m.ToolCallID + } + if m.Name != "" { + wm["name"] = m.Name + } + out = append(out, wm) + } + return out +} + +func wireTools(tools []Tool) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(tools)) + for _, t := range tools { + out = append(out, map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": t.Name, + "description": t.Description, + "parameters": t.Parameters, + }, + }) + } + return out +} + +func (p *openAICompatible) Complete(ctx context.Context, req Request) (*Response, error) { + 
payload := map[string]interface{}{ + "model": p.model, + "messages": wireMessages(req.Messages), + } + if len(req.Tools) > 0 { + payload["tools"] = wireTools(req.Tools) + } + if req.MaxTokens > 0 { + payload["max_tokens"] = req.MaxTokens + } + if req.Temperature > 0 { + payload["temperature"] = req.Temperature + } + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Content-Type", "application/json") + // Local servers (Ollama, LM Studio) typically run keyless. + if p.apiKey != "" { + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("ai provider request failed: %w", err) + } + defer resp.Body.Close() + + raw, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return nil, fmt.Errorf("ai provider response read failed: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var apiErr struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } + msg := strings.TrimSpace(string(raw)) + if json.Unmarshal(raw, &apiErr) == nil && apiErr.Error.Message != "" { + msg = apiErr.Error.Message + } + return nil, fmt.Errorf("ai provider returned %d: %s", resp.StatusCode, msg) + } + + var parsed struct { + Model string `json:"model"` + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + } `json:"choices"` + Usage Usage `json:"usage"` + } + if err := json.Unmarshal(raw, &parsed); err != nil { + return nil, fmt.Errorf("ai provider returned invalid JSON: %w", err) + } + if len(parsed.Choices) == 0 { + return nil, 
fmt.Errorf("ai provider returned no choices") + } + + model := parsed.Model + if model == "" { + model = p.model + } + + choice := parsed.Choices[0].Message + var toolCalls []ToolCall + for _, tc := range choice.ToolCalls { + toolCalls = append(toolCalls, ToolCall{ID: tc.ID, Name: tc.Function.Name, Arguments: tc.Function.Arguments}) + } + + return &Response{ + Content: choice.Content, + ToolCalls: toolCalls, + Model: model, + Usage: parsed.Usage, + }, nil +} diff --git a/internal/ai/prompts.go b/internal/ai/prompts.go new file mode 100644 index 0000000..a50bd46 --- /dev/null +++ b/internal/ai/prompts.go @@ -0,0 +1,26 @@ +package ai + +// contextBudget bounds the prompt so small local models with limited +// context windows still work; long sections are truncated head-first +// since the newest lines matter most. +const contextBudget = 24000 + +// TruncateHead returns s unchanged when it is at most max bytes long; +// otherwise it drops the oldest bytes and returns a truncation marker +// followed by at most the last max bytes. The cut is moved forward to a +// UTF-8 character boundary so the kept tail is never split mid-rune. +// A negative max is treated as zero. +func TruncateHead(s string, max int) string { + if max < 0 { + max = 0 + } + if len(s) <= max { + return s + } + start := len(s) - max + // Skip UTF-8 continuation bytes (10xxxxxx) at the cut point. + for start < len(s) && s[start]&0xC0 == 0x80 { + start++ + } + return "[... truncated ...]\n" + s[start:] +} diff --git a/internal/ai/provider.go b/internal/ai/provider.go new file mode 100644 index 0000000..aa05a46 --- /dev/null +++ b/internal/ai/provider.go @@ -0,0 +1,77 @@ +package ai + +import ( + "context" + "errors" + + "github.com/flatrun/agent/pkg/config" +) + +// ErrDisabled is returned by New when no provider is configured. The +// caller treats it as "feature off", not as a failure. 
+var ErrDisabled = errors.New("ai is not enabled") + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + // Display, when set, is what the UI shows for this turn instead of + // Content. Used to send bulky context (logs, output) to the model + // while showing the operator a short label. Never sent to the + // provider. + Display string `json:"display,omitempty"` + // ToolCalls is set on an assistant message that wants tools run. + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + // ToolCallID and Name identify a role:"tool" result message. 
+ ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` +} + +// ToolCall is one tool invocation requested by the model. Arguments is +// a JSON object string as produced by the model. +type ToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// Tool is a function the model may call. Parameters is a JSON Schema +// object describing the arguments. +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} + +type Request struct { + Messages []Message + Tools []Tool + MaxTokens int + Temperature float64 +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` +} + +type Response struct { + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Model string `json:"model"` + Usage Usage `json:"usage"` +} + +// Provider is the model-agnostic boundary: everything above this +// interface (handlers, prompts, redaction) is provider-neutral, and new +// backends plug in behind it without touching callers. +type Provider interface { + Name() string + Complete(ctx context.Context, req Request) (*Response, error) +} + +func New(cfg *config.AIConfig) (Provider, error) { + if cfg == nil || !cfg.Enabled { + return nil, ErrDisabled + } + return newOpenAICompatible(cfg), nil +} diff --git a/internal/ai/redact.go b/internal/ai/redact.go new file mode 100644 index 0000000..bab2b9f --- /dev/null +++ b/internal/ai/redact.go @@ -0,0 +1,60 @@ +package ai + +import ( + "regexp" + "sort" + "strings" +) + +const redactedMarker = "[REDACTED]" + +// minSecretLength keeps short values like "true", port numbers or +// single words from poisoning the whole text with replacements. 
+const minSecretLength = 6 + +var credentialPattern = regexp.MustCompile(`(?i)([a-z0-9_\-]*(?:password|passwd|secret|token|api_?key))(\s*[=:]\s*)("[^"\n]*"|'[^'\n]*'|[^\s,;]+)`) + +// Redactor removes known secret values and credential-shaped +// assignments from text before it leaves the host. +type Redactor struct { + secrets []string +} + +func NewRedactor(secrets []string) *Redactor { + seen := map[string]struct{}{} + var filtered []string + for _, s := range secrets { + s = strings.TrimSpace(s) + if len(s) < minSecretLength { + continue + } + if _, dup := seen[s]; dup { + continue + } + seen[s] = struct{}{} + filtered = append(filtered, s) + } + // Longest first, so a secret that contains another secret is + // replaced whole instead of being broken by the shorter match. + sort.Slice(filtered, func(i, j int) bool { return len(filtered[i]) > len(filtered[j]) }) + return &Redactor{secrets: filtered} +} + +func (r *Redactor) Redact(text string) (string, int) { + count := 0 + for _, s := range r.secrets { + if n := strings.Count(text, s); n > 0 { + text = strings.ReplaceAll(text, s, redactedMarker) + count += n + } + } + text = credentialPattern.ReplaceAllStringFunc(text, func(m string) string { + sub := credentialPattern.FindStringSubmatch(m) + if sub[3] == redactedMarker { + return m + } + count++ + return sub[1] + sub[2] + redactedMarker + }) + return text, count +} diff --git a/internal/ai/session.go b/internal/ai/session.go new file mode 100644 index 0000000..10a7a62 --- /dev/null +++ b/internal/ai/session.go @@ -0,0 +1,227 @@ +package ai + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +const ( + SessionStatusReady = "ready" + SessionStatusAwaitingApproval = "awaiting_approval" + SessionScopeSystem = "system" + SessionScopeDeployment = "deployment" + maxSessionToolSteps = 8 + maxSessionMessages = 200 +) + +type SessionActor struct { + ID string `json:"id"` + Name string 
`json:"name"` +} + +// Session is one ongoing AI conversation. It owns the full model +// transcript (including tool calls and results) plus the derived state +// the UI needs. Stored as a flat JSON file, true to FlatRun. +type Session struct { + ID string `json:"id"` + Scope string `json:"scope"` + Deployment string `json:"deployment,omitempty"` + AutoRun bool `json:"auto_run"` + Status string `json:"status"` + Model string `json:"model,omitempty"` + CreatedBy SessionActor `json:"created_by"` + Messages []Message `json:"messages"` + Pending []ToolCall `json:"pending,omitempty"` + Suggested []SuggestedAction `json:"suggested_actions"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func NewSession(scope, deployment string, autoRun bool, actor SessionActor, systemPrompt string) *Session { + now := time.Now().UTC() + return &Session{ + ID: "ais_" + uuid.New().String(), + Scope: scope, + Deployment: deployment, + AutoRun: autoRun, + Status: SessionStatusReady, + CreatedBy: actor, + Messages: []Message{{Role: "system", Content: systemPrompt}}, + Suggested: []SuggestedAction{}, + CreatedAt: now, + UpdatedAt: now, + } +} + +func (s *Session) touch() { + s.UpdatedAt = time.Now().UTC() + if len(s.Messages) > maxSessionMessages { + // Keep the system prompt and the most recent window. + head := s.Messages[:1] + tail := s.Messages[len(s.Messages)-(maxSessionMessages-1):] + s.Messages = append(head, tail...) + } +} + +// AddUserMessage records a user turn. When display differs from +// content, the model sees content (e.g. message plus embedded logs) +// while the UI shows display (e.g. a short label). 
+func (s *Session) AddUserMessage(content, display string) { + s.Messages = append(s.Messages, Message{Role: "user", Content: content, Display: display}) + s.touch() +} + +func (s *Session) AddAssistantMessage(content string, toolCalls []ToolCall) { + s.Messages = append(s.Messages, Message{Role: "assistant", Content: content, ToolCalls: toolCalls}) + s.touch() +} + +func (s *Session) AddToolResult(call ToolCall, result string) { + s.Messages = append(s.Messages, Message{Role: "tool", ToolCallID: call.ID, Name: call.Name, Content: result}) + s.touch() +} + +// MaxToolSteps is the per-turn cap on consecutive tool rounds, so a +// misbehaving model cannot loop forever. +func (s *Session) MaxToolSteps() int { return maxSessionToolSteps } + +var sessionIDPattern = regexp.MustCompile(`^ais_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) + +var ErrSessionNotFound = fmt.Errorf("session not found") + +type SessionStore struct { + dir string + mu sync.Mutex +} + +func NewSessionStore(deploymentsPath string) *SessionStore { + return &SessionStore{dir: filepath.Join(deploymentsPath, ".flatrun", "ai-sessions")} +} + +func (st *SessionStore) path(id string) string { + return filepath.Join(st.dir, id+".json") +} + +func (st *SessionStore) Save(sess *Session) error { + if !sessionIDPattern.MatchString(sess.ID) { + return fmt.Errorf("invalid session id %q", sess.ID) + } + st.mu.Lock() + defer st.mu.Unlock() + if err := os.MkdirAll(st.dir, 0700); err != nil { + return err + } + data, err := json.MarshalIndent(sess, "", " ") + if err != nil { + return err + } + tmp := st.path(sess.ID) + ".tmp" + if err := os.WriteFile(tmp, data, 0600); err != nil { + return err + } + return os.Rename(tmp, st.path(sess.ID)) +} + +func (st *SessionStore) Get(id string) (*Session, error) { + if !sessionIDPattern.MatchString(id) { + return nil, ErrSessionNotFound + } + data, err := os.ReadFile(st.path(id)) + if err != nil { + if os.IsNotExist(err) { + return nil, ErrSessionNotFound + } + 
return nil, err + } + var sess Session + if err := json.Unmarshal(data, &sess); err != nil { + return nil, fmt.Errorf("corrupt session file: %w", err) + } + return &sess, nil +} + +func (st *SessionStore) Delete(id string) error { + if !sessionIDPattern.MatchString(id) { + return ErrSessionNotFound + } + return os.Remove(st.path(id)) +} + +// PruneOlderThan removes sessions whose last update predates the +// cutoff, keeping the flat-file directory from growing without bound. +func (st *SessionStore) PruneOlderThan(cutoff time.Time) int { + entries, err := os.ReadDir(st.dir) + if err != nil { + return 0 + } + removed := 0 + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") { + continue + } + id := strings.TrimSuffix(e.Name(), ".json") + sess, err := st.Get(id) + if err != nil { + continue + } + if sess.UpdatedAt.Before(cutoff) { + if st.Delete(id) == nil { + removed++ + } + } + } + return removed +} + +// DisplayToolStep pairs one tool call with its result inside the UI-facing turns +// that DisplayMessages projects from the transcript (the system prompt is dropped).
+type DisplayToolStep struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + Result string `json:"result,omitempty"` +} + +type DisplayTurn struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + ToolSteps []DisplayToolStep `json:"tool_steps,omitempty"` +} + +func (s *Session) DisplayMessages() []DisplayTurn { + results := map[string]string{} + for _, m := range s.Messages { + if m.Role == "tool" { + results[m.ToolCallID] = m.Content + } + } + turns := make([]DisplayTurn, 0, len(s.Messages)) + for _, m := range s.Messages { + switch m.Role { + case "user": + shown := m.Content + if m.Display != "" { + shown = m.Display + } + turns = append(turns, DisplayTurn{Role: "user", Content: shown}) + case "assistant": + turn := DisplayTurn{Role: "assistant", Content: m.Content} + for _, tc := range m.ToolCalls { + turn.ToolSteps = append(turn.ToolSteps, DisplayToolStep{ + Name: tc.Name, + Arguments: tc.Arguments, + Result: results[tc.ID], + }) + } + turns = append(turns, turn) + } + } + return turns +} diff --git a/internal/ai/suggestions.go b/internal/ai/suggestions.go new file mode 100644 index 0000000..40392a4 --- /dev/null +++ b/internal/ai/suggestions.go @@ -0,0 +1,68 @@ +package ai + +import ( + "encoding/json" + "regexp" + "strings" +) + +// SuggestedAction is a machine-actionable proposal extracted from a +// model response. It is never executed by the agent on its own; the +// client decides to run it through the normal, guarded APIs. 
+type SuggestedAction struct { + Kind string `json:"kind"` + Service string `json:"service,omitempty"` + Action string `json:"action,omitempty"` + Command string `json:"command,omitempty"` + Title string `json:"title"` + Reason string `json:"reason,omitempty"` +} + +const ( + SuggestionKindExec = "exec" + SuggestionKindServiceAction = "service_action" +) + +var validServiceActions = map[string]bool{ + "start": true, "stop": true, "restart": true, "rebuild": true, "pull": true, +} + +var suggestionsBlock = regexp.MustCompile("(?s)```suggestions\\s*(.*?)```") + +// ParseSuggestions extracts the fenced suggestions block from a model +// response, returning the response without the block and the valid +// actions found in it. Malformed blocks and invalid entries are +// dropped; diagnosis text is never lost over a bad suggestion. +func ParseSuggestions(content string) (string, []SuggestedAction) { + match := suggestionsBlock.FindStringSubmatch(content) + if match == nil { + return content, nil + } + cleaned := strings.TrimSpace(suggestionsBlock.ReplaceAllString(content, "")) + + var raw []SuggestedAction + if err := json.Unmarshal([]byte(strings.TrimSpace(match[1])), &raw); err != nil { + return cleaned, nil + } + + var valid []SuggestedAction + for _, s := range raw { + if s.Title == "" { + continue + } + switch s.Kind { + case SuggestionKindExec: + if strings.TrimSpace(s.Command) == "" || s.Service == "" { + continue + } + case SuggestionKindServiceAction: + if s.Service == "" || !validServiceActions[s.Action] { + continue + } + default: + continue + } + valid = append(valid, s) + } + return cleaned, valid +} diff --git a/internal/ai/suggestions_test.go b/internal/ai/suggestions_test.go new file mode 100644 index 0000000..2db9d77 --- /dev/null +++ b/internal/ai/suggestions_test.go @@ -0,0 +1,50 @@ +package ai + +import ( + "strings" + "testing" +) + +func TestParseSuggestions(t *testing.T) { + content := "## Diagnosis\nDB is down.\n\n```suggestions\n[\n" + + 
`{"kind":"service_action","service":"db","action":"restart","title":"Restart the database","reason":"connection refused"},` + "\n" + + `{"kind":"exec","service":"web","command":"php artisan config:clear","title":"Clear config cache"},` + "\n" + + `{"kind":"service_action","service":"db","action":"explode","title":"Invalid action"},` + "\n" + + `{"kind":"exec","service":"web","command":"","title":"Empty command"},` + "\n" + + `{"kind":"weird","title":"Unknown kind"}` + "\n]\n```\n" + + cleaned, suggestions := ParseSuggestions(content) + + if strings.Contains(cleaned, "suggestions") || strings.Contains(cleaned, "artisan") { + t.Errorf("block not stripped from analysis: %q", cleaned) + } + if !strings.Contains(cleaned, "DB is down.") { + t.Errorf("analysis text lost: %q", cleaned) + } + if len(suggestions) != 2 { + t.Fatalf("got %d suggestions, want 2 valid: %+v", len(suggestions), suggestions) + } + if suggestions[0].Kind != SuggestionKindServiceAction || suggestions[0].Action != "restart" { + t.Errorf("first = %+v", suggestions[0]) + } + if suggestions[1].Kind != SuggestionKindExec || suggestions[1].Command != "php artisan config:clear" { + t.Errorf("second = %+v", suggestions[1]) + } +} + +func TestParseSuggestionsNoBlock(t *testing.T) { + cleaned, suggestions := ParseSuggestions("## Diagnosis\nAll good.") + if cleaned != "## Diagnosis\nAll good." 
|| suggestions != nil { + t.Errorf("content without block should pass through, got %q %v", cleaned, suggestions) + } +} + +func TestParseSuggestionsMalformedJSON(t *testing.T) { + cleaned, suggestions := ParseSuggestions("Text\n```suggestions\nnot json\n```") + if suggestions != nil { + t.Errorf("malformed block should yield no suggestions, got %v", suggestions) + } + if !strings.Contains(cleaned, "Text") { + t.Errorf("analysis lost: %q", cleaned) + } +} diff --git a/internal/api/ai_handlers.go b/internal/api/ai_handlers.go new file mode 100644 index 0000000..ecbbd31 --- /dev/null +++ b/internal/api/ai_handlers.go @@ -0,0 +1,308 @@ +package api + +import ( + "fmt" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + + "github.com/flatrun/agent/internal/ai" + "github.com/gin-gonic/gin" +) + +func (s *Server) getAIStatus(c *gin.Context) { + resp := gin.H{"enabled": s.aiProvider != nil, "intents": ai.IntentKeys()} + if s.aiProvider != nil { + resp["model"] = s.config.AI.Model + if u, err := url.Parse(s.config.AI.BaseURL); err == nil { + resp["base_url_host"] = u.Host + } + } + c.JSON(http.StatusOK, resp) +} + +type assistSource struct { + Type string `json:"type"` + Label string `json:"label"` + Content string `json:"content"` + Tail int `json:"tail"` +} + +type assistRequest struct { + Intent string `json:"intent"` + Sources []assistSource `json:"sources"` + Question string `json:"question"` +} + +// collectDeploymentSource gathers one labeled context section for a +// deployment scope. New source types (nginx logs, security events, +// linked codebases) register here without touching the pipeline. 
+func (s *Server) collectDeploymentSource(name string, src assistSource) (ai.Section, *apiError) { + switch src.Type { + case "logs": + tail := src.Tail + if tail <= 0 { + tail = 300 + } + if tail > 1000 { + tail = 1000 + } + logs, err := s.manager.GetDeploymentLogs(name, tail) + if err != nil { + return ai.Section{}, apiErrf(http.StatusInternalServerError, "Failed to read logs: %s", err.Error()) + } + return ai.Section{Label: "Recent logs", Content: logs}, nil + case "compose": + content, filename, err := s.manager.GetComposeFile(name) + if err != nil { + return ai.Section{}, apiErrf(http.StatusNotFound, "Compose file not found") + } + return ai.Section{Label: filename, Content: content, Format: "yaml"}, nil + case "provided": + if strings.TrimSpace(src.Content) == "" { + return ai.Section{}, apiErrf(http.StatusBadRequest, "provided source requires content") + } + label := src.Label + if label == "" { + label = "Provided output" + } + return ai.Section{Label: label, Content: src.Content}, nil + default: + return ai.Section{}, apiErrf(http.StatusBadRequest, "unknown source type %q", src.Type) + } +} + +// platformSection states this installation's live configuration so +// analyses are grounded in FlatRun specifics instead of generic Docker +// knowledge: which networks the platform expects, which networks +// actually exist, and how the deployment is wired into them. 
+func (s *Server) platformSection(deploymentName string) ai.Section { + var b strings.Builder + cfg := s.config + + fmt.Fprintf(&b, "Configured proxy network (apps must join this external network to be served on the web): %s\n", cfg.Infrastructure.DefaultProxyNetwork) + fmt.Fprintf(&b, "Configured database network (apps reach shared databases over this external network): %s\n", cfg.Infrastructure.DefaultDatabaseNetwork) + + if networks, err := s.networksManager.ListNetworks(); err == nil { + names := make([]string, 0, len(networks)) + for _, n := range networks { + names = append(names, n.Name) + } + fmt.Fprintf(&b, "Docker networks that currently exist on this host: %s\n", strings.Join(names, ", ")) + } + + if cfg.Infrastructure.Database.Enabled { + fmt.Fprintf(&b, "Shared %s database server is available at host %q on the database network\n", + cfg.Infrastructure.Database.Type, cfg.Infrastructure.Database.Host) + } + fmt.Fprintf(&b, "Nginx reverse proxy managed by FlatRun: %t\n", cfg.Nginx.Enabled) + + if deploymentName != "" { + if deployment, err := s.manager.GetDeployment(deploymentName); err == nil && deployment.Metadata != nil { + meta := deployment.Metadata + if domains := meta.GetDomains(); len(domains) > 0 { + for _, d := range domains { + target := d.Service + if d.ContainerPort > 0 { + target = fmt.Sprintf("%s on container port %d", d.Service, d.ContainerPort) + } + fmt.Fprintf(&b, "Reverse proxy routing: %s forwards to service %s (configured in deployment metadata; the compose expose field is not used for routing)\n", d.Domain, target) + } + } else { + fmt.Fprintf(&b, "This deployment is not exposed through the reverse proxy\n") + } + if meta.HealthCheck.Path != "" { + fmt.Fprintf(&b, "Configured health check path: %s\n", meta.HealthCheck.Path) + } + if len(meta.Databases) > 0 { + aliases := make([]string, 0, len(meta.Databases)) + for _, db := range meta.Databases { + aliases = append(aliases, fmt.Sprintf("%s (%s)", db.Alias, db.Type)) + } + 
fmt.Fprintf(&b, "This deployment uses database(s): %s\n", strings.Join(aliases, ", ")) + } + } + if s.proxyOrchestrator != nil && s.proxyOrchestrator.NginxManager().VirtualHostExists(deploymentName) { + fmt.Fprintf(&b, "A virtual host for this deployment exists in the reverse proxy\n") + } + } + + return ai.Section{Label: "FlatRun platform context", Content: b.String()} +} + +func (s *Server) runAssist(c *gin.Context, scopeLabel string, sections []ai.Section, req assistRequest, secrets []string, validateSuggestions func([]ai.SuggestedAction) []ai.SuggestedAction) { + intent, ok := ai.GetIntent(req.Intent) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "unknown intent; valid intents: " + strings.Join(ai.IntentKeys(), ", ")}) + return + } + + redactor := ai.NewRedactor(secrets) + redactions := 0 + for i := range sections { + redacted, n := redactor.Redact(sections[i].Content) + sections[i].Content = redacted + redactions += n + } + + messages := ai.BuildAssistMessages(intent, scopeLabel, sections, req.Question, s.config.AI.DocsURL) + resp, err := s.aiProvider.Complete(c.Request.Context(), ai.Request{Messages: messages}) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + + analysis, suggestions := ai.ParseSuggestions(resp.Content) + if !intent.AllowSuggestions { + suggestions = nil + } + if suggestions == nil { + suggestions = []ai.SuggestedAction{} + } else if validateSuggestions != nil { + suggestions = validateSuggestions(suggestions) + } + + c.JSON(http.StatusOK, gin.H{ + "analysis": analysis, + "suggested_actions": suggestions, + "intent": intent.Key, + "model": resp.Model, + "redactions": redactions, + }) +} + +func (s *Server) aiAssistDeployment(c *gin.Context) { + if s.aiProvider == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "AI assistant is not enabled", "code": "ai_disabled"}) + return + } + name := c.Param("name") + + var req assistRequest + if err := c.ShouldBindJSON(&req); err != nil 
{ + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.Intent == "" { + req.Intent = "diagnose" + } + if len(req.Sources) == 0 { + req.Sources = []assistSource{{Type: "logs"}, {Type: "compose"}} + } + + if _, err := s.manager.GetDeployment(name); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Deployment not found"}) + return + } + + sections := make([]ai.Section, 0, len(req.Sources)+1) + sections = append(sections, s.platformSection(name)) + for _, src := range req.Sources { + section, aerr := s.collectDeploymentSource(name, src) + if aerr != nil { + respondAPIError(c, aerr) + return + } + sections = append(sections, section) + } + + s.runAssist(c, "deployment "+name, sections, req, s.deploymentSecretValues(name), + func(suggestions []ai.SuggestedAction) []ai.SuggestedAction { + return s.filterSuggestionsForDeployment(name, suggestions) + }) +} + +// aiAssistSystem analyzes caller-provided output at host or agent +// level. The only context the agent adds is the deployment-independent platform section (configured and existing network names, shared database host, proxy flag), so any actor who could +// see the output may request the analysis (NOTE(review): confirm that platform section is safe to share with every such actor). Suggestions are dropped +// because there is no deployment to validate them against.
+func (s *Server) aiAssistSystem(c *gin.Context) { + if s.aiProvider == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "AI assistant is not enabled", "code": "ai_disabled"}) + return + } + + var req assistRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.Intent == "" { + req.Intent = "diagnose" + } + + sections := make([]ai.Section, 0, len(req.Sources)+1) + sections = append(sections, s.platformSection("")) + for _, src := range req.Sources { + if src.Type != "provided" { + c.JSON(http.StatusBadRequest, gin.H{"error": "system scope only accepts provided sources"}) + return + } + if strings.TrimSpace(src.Content) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "provided source requires content"}) + return + } + label := src.Label + if label == "" { + label = "Provided output" + } + sections = append(sections, ai.Section{Label: label, Content: src.Content}) + } + if len(sections) == 1 { + c.JSON(http.StatusBadRequest, gin.H{"error": "at least one provided source is required"}) + return + } + + s.runAssist(c, "the FlatRun host", sections, req, s.systemSecretValues(), + func([]ai.SuggestedAction) []ai.SuggestedAction { return []ai.SuggestedAction{} }) +} + +func (s *Server) systemSecretValues() []string { + return []string{ + s.config.AI.APIKey, + s.config.Auth.JWTSecret, + s.config.Infrastructure.Database.RootPassword, + s.config.Infrastructure.Redis.Password, + s.config.Infrastructure.PowerDNS.APIKey, + } +} + +// deploymentSecretValues collects every secret value that must never +// reach a model provider: the deployment's env values plus the agent's +// own credentials. 
+func (s *Server) deploymentSecretValues(name string) []string { + var secrets []string + envPath := filepath.Join(s.config.DeploymentsPath, name, ".env.flatrun") + if content, err := os.ReadFile(envPath); err == nil { + for _, v := range parseEnvContent(string(content)) { + secrets = append(secrets, v.Value) + } + } + return append(secrets, s.systemSecretValues()...) +} + +// filterSuggestionsForDeployment drops suggestions naming services +// that do not exist in the deployment's compose file, so a +// hallucinated service name can never be acted on. +func (s *Server) filterSuggestionsForDeployment(name string, suggestions []ai.SuggestedAction) []ai.SuggestedAction { + if len(suggestions) == 0 { + return []ai.SuggestedAction{} + } + serviceNames, err := s.manager.GetComposeServiceNames(name) + if err != nil { + return []ai.SuggestedAction{} + } + known := make(map[string]bool, len(serviceNames)) + for _, sn := range serviceNames { + known[sn] = true + } + valid := make([]ai.SuggestedAction, 0, len(suggestions)) + for _, sg := range suggestions { + if known[sg.Service] { + valid = append(valid, sg) + } + } + return valid +} diff --git a/internal/api/ai_handlers_test.go b/internal/api/ai_handlers_test.go new file mode 100644 index 0000000..a19a50a --- /dev/null +++ b/internal/api/ai_handlers_test.go @@ -0,0 +1,304 @@ +package api + +import ( + "context" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/flatrun/agent/internal/ai" + "github.com/flatrun/agent/pkg/models" +) + +type stubProvider struct { + lastRequest ai.Request + response *ai.Response + err error +} + +func (s *stubProvider) Name() string { return "stub" } + +func (s *stubProvider) Complete(_ context.Context, req ai.Request) (*ai.Response, error) { + s.lastRequest = req + if s.err != nil { + return nil, s.err + } + return s.response, nil +} + +func TestAIStatusDisabled(t *testing.T) { + _, _, ts := setupPlanTestServer(t) + + resp, parsed := doJSON(t, http.MethodGet, 
ts.URL+"/api/ai/status", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + if parsed["enabled"] != false { + t.Errorf("enabled = %v, want false", parsed["enabled"]) + } + if _, leaked := parsed["api_key"]; leaked { + t.Error("status response must never contain the api key") + } +} + +func TestAIAnalyzeDisabledReturns503(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/deployments/myapp/ai/analyze", + map[string]interface{}{"intent": "diagnose"}) + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", resp.StatusCode) + } + if parsed["code"] != "ai_disabled" { + t.Errorf("code = %v, want ai_disabled", parsed["code"]) + } +} + +func TestAIAnalyzeOperationRedactsSecrets(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + envContent := "DB_PASSWORD=hunter2secret\nAPP_NAME=myapp\n" + if err := os.WriteFile(filepath.Join(tmpDir, "myapp", ".env.flatrun"), []byte(envContent), 0600); err != nil { + t.Fatal(err) + } + + stub := &stubProvider{response: &ai.Response{Content: "## Diagnosis\nDB auth failure", Model: "stub-model"}} + s.aiProvider = stub + + body := map[string]interface{}{ + "intent": "diagnose", + "sources": []map[string]interface{}{ + { + "type": "provided", + "label": "Failed deploy output", + "content": "FATAL: password authentication failed for hunter2secret\nMYSQL_PASSWORD=other123secret", + }, + {"type": "compose"}, + }, + } + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/deployments/myapp/ai/analyze", body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, body %v", resp.StatusCode, parsed) + } + if parsed["analysis"] != "## Diagnosis\nDB auth failure" { + t.Errorf("analysis = %v", parsed["analysis"]) + } + if parsed["model"] != "stub-model" { + t.Errorf("model = %v", 
parsed["model"]) + } + if parsed["redactions"].(float64) < 2 { + t.Errorf("redactions = %v, want >= 2", parsed["redactions"]) + } + + var prompt strings.Builder + for _, m := range stub.lastRequest.Messages { + prompt.WriteString(m.Content) + } + if strings.Contains(prompt.String(), "hunter2secret") { + t.Error("env secret value leaked into the prompt") + } + if strings.Contains(prompt.String(), "other123secret") { + t.Error("credential-shaped value leaked into the prompt") + } + if !strings.Contains(prompt.String(), "myapp") { + t.Error("prompt missing deployment context") + } +} + +func TestAIAnalyzeReturnsValidatedSuggestions(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + content := "## Diagnosis\nweb is crashing.\n```suggestions\n[" + + `{"kind":"service_action","service":"web","action":"restart","title":"Restart web"},` + + `{"kind":"service_action","service":"ghost","action":"restart","title":"Restart hallucinated service"}` + + "]\n```" + s.aiProvider = &stubProvider{response: &ai.Response{Content: content, Model: "stub"}} + + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/deployments/myapp/ai/analyze", + map[string]interface{}{ + "intent": "diagnose", + "sources": []map[string]interface{}{{"type": "provided", "content": "crash"}}, + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, body %v", resp.StatusCode, parsed) + } + + if strings.Contains(parsed["analysis"].(string), "suggestions") { + t.Error("suggestions block leaked into analysis text") + } + actions := parsed["suggested_actions"].([]interface{}) + if len(actions) != 1 { + t.Fatalf("got %d suggestions, want 1 (hallucinated service dropped): %v", len(actions), actions) + } + first := actions[0].(map[string]interface{}) + if first["service"] != "web" || first["action"] != "restart" { + t.Errorf("suggestion = %v", first) + } +} + +func TestAIAnalyzeIncludesPlatformContext(t *testing.T) { + s, tmpDir, ts := 
setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", &models.ServiceMetadata{ + Name: "myapp", + Domains: []models.DomainConfig{ + {ID: "d1", Service: "web", ContainerPort: 8000, Domain: "myapp.example.com"}, + }, + }) + s.config.Infrastructure.DefaultProxyNetwork = "proxy" + s.config.Infrastructure.DefaultDatabaseNetwork = "database" + s.config.AI.DocsURL = "https://flatrun.dev/docs/" + + stub := &stubProvider{response: &ai.Response{Content: "ok", Model: "stub"}} + s.aiProvider = stub + + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/deployments/myapp/ai/analyze", + map[string]interface{}{ + "intent": "diagnose", + "sources": []map[string]interface{}{{"type": "provided", "content": "network proxyy not found"}}, + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, body %v", resp.StatusCode, parsed) + } + + var prompt strings.Builder + for _, m := range stub.lastRequest.Messages { + prompt.WriteString(m.Content) + } + for _, want := range []string{ + "FlatRun platform context", + "Configured proxy network", + "proxy", + "myapp.example.com forwards to service web on container port 8000", + "https://flatrun.dev/docs/", + } { + if !strings.Contains(prompt.String(), want) { + t.Errorf("prompt missing %q", want) + } + } +} + +func TestAIAnalyzeProviderErrorMapsTo502(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + s.aiProvider = &stubProvider{err: context.DeadlineExceeded} + + resp, _ := doJSON(t, http.MethodPost, ts.URL+"/api/deployments/myapp/ai/analyze", + map[string]interface{}{ + "intent": "diagnose", + "sources": []map[string]interface{}{{"type": "provided", "content": "boom"}}, + }) + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("status = %d, want 502", resp.StatusCode) + } +} + +func TestAIAnalyzeSystem(t *testing.T) { + s, _, ts := setupPlanTestServer(t) + stub := &stubProvider{response: &ai.Response{Content: "## Diagnosis\nThe proxy network is 
missing.", Model: "stub"}} + s.aiProvider = stub + s.config.Infrastructure.Database.RootPassword = "rootpw-secret-1" + + body := map[string]interface{}{ + "intent": "diagnose", + "sources": []map[string]interface{}{{ + "type": "provided", + "label": "Start failed for myapp", + "content": "network proxy declared as external, but could not be found\npassword=rootpw-secret-1", + }}, + } + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/ai/analyze", body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, body %v", resp.StatusCode, parsed) + } + if parsed["analysis"] != "## Diagnosis\nThe proxy network is missing." { + t.Errorf("analysis = %v", parsed["analysis"]) + } + + var prompt strings.Builder + for _, m := range stub.lastRequest.Messages { + prompt.WriteString(m.Content) + } + if strings.Contains(prompt.String(), "rootpw-secret-1") { + t.Error("agent credential leaked into the system diagnosis prompt") + } + if !strings.Contains(prompt.String(), "Start failed for myapp") { + t.Error("source label missing from prompt") + } + + // A provided source is required for system scope. + resp, _ = doJSON(t, http.MethodPost, ts.URL+"/api/ai/analyze", map[string]interface{}{"intent": "diagnose"}) + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("missing sources status = %d, want 400", resp.StatusCode) + } + + // Unknown intents are rejected. 
+ resp, _ = doJSON(t, http.MethodPost, ts.URL+"/api/ai/analyze", map[string]interface{}{ + "intent": "world-domination", + "sources": []map[string]interface{}{{"type": "provided", "content": "x"}}, + }) + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("unknown intent status = %d, want 400", resp.StatusCode) + } +} + +func TestAIConfigKeyMaskedButWritable(t *testing.T) { + s, _, ts := setupPlanTestServer(t) + + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/config/ai.api_key", + map[string]interface{}{"value": "sk-supersecret123"}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("set ai.api_key status = %d, body %v", resp.StatusCode, parsed) + } + if s.config.AI.APIKey != "sk-supersecret123" { + t.Errorf("api key not set, got %q", s.config.AI.APIKey) + } + + entry := parsed["entry"].(map[string]interface{}) + if entry["value"] != nil { + t.Errorf("set response leaked value: %v", entry["value"]) + } + if entry["sensitive"] != true { + t.Errorf("entry not marked sensitive: %v", entry) + } + + resp, parsed = doJSON(t, http.MethodGet, ts.URL+"/api/config/ai.api_key", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("get status = %d", resp.StatusCode) + } + entry = parsed["entry"].(map[string]interface{}) + if entry["value"] != nil { + t.Errorf("get leaked value: %v", entry["value"]) + } +} + +func TestAIRuntimeApplierSwapsProvider(t *testing.T) { + s, _, ts := setupPlanTestServer(t) + if s.aiProvider != nil { + t.Fatal("provider should start nil with ai disabled") + } + + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/config/ai.enabled", + map[string]interface{}{"value": true}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("enable status = %d, body %v", resp.StatusCode, parsed) + } + if parsed["applied"] != true { + t.Errorf("applied = %v, want true", parsed["applied"]) + } + if s.aiProvider == nil { + t.Fatal("provider not constructed by runtime applier") + } + + resp, _ = doJSON(t, http.MethodPut, 
ts.URL+"/api/config/ai.enabled", + map[string]interface{}{"value": false}) + if resp.StatusCode != http.StatusOK { + t.Fatal("disable failed") + } + if s.aiProvider != nil { + t.Error("provider not torn down when disabled") + } +} diff --git a/internal/api/ai_session_handlers.go b/internal/api/ai_session_handlers.go new file mode 100644 index 0000000..f647672 --- /dev/null +++ b/internal/api/ai_session_handlers.go @@ -0,0 +1,295 @@ +package api + +import ( + "fmt" + "net/http" + "strings" + + "github.com/flatrun/agent/internal/ai" + "github.com/flatrun/agent/internal/auth" + "github.com/gin-gonic/gin" +) + +// composeUserMessage merges a short message with optional bulky +// context. The model sees both; the operator's transcript shows only +// the message. +func composeUserMessage(message, context string) (content, display string) { + message = strings.TrimSpace(message) + context = strings.TrimSpace(context) + if context == "" { + return message, "" + } + return message + "\n\n" + context, message +} + +func sessionActorFrom(c *gin.Context) ai.SessionActor { + actor := auth.GetActorFromContext(c) + if actor == nil { + return ai.SessionActor{ID: "anonymous", Name: "anonymous"} + } + a := ai.SessionActor{} + switch { + case actor.User != nil: + a.ID = fmt.Sprintf("%d", actor.User.ID) + a.Name = actor.User.Username + case actor.APIKey != nil: + a.ID = actor.APIKey.KeyID + a.Name = actor.APIKey.Name + } + if a.ID == "" { + a.ID = "anonymous" + } + return a +} + +// canUseSession restricts a session to its creator (or an admin), +// since the transcript may reference resources only they can see. A request with no actor attached is also allowed (NOTE(review): presumably the auth-disabled case; confirm).
+func canUseSession(c *gin.Context, sess *ai.Session) bool { + actor := auth.GetActorFromContext(c) + if actor == nil || actor.Role == auth.RoleAdmin { + return true + } + return sessionActorFrom(c).ID == sess.CreatedBy.ID +} + +// advanceSession runs the tool loop: it calls the model, executes any +// requested tools (auto-run) or pauses for approval, and repeats until +// the model returns a final answer or the step budget is exhausted. +func (s *Server) advanceSession(c *gin.Context, sess *ai.Session) error { + tools := s.aiToolSpecs() + for step := 0; step < sess.MaxToolSteps(); step++ { + resp, err := s.aiProvider.Complete(c.Request.Context(), ai.Request{Messages: sess.Messages, Tools: tools}) + if err != nil { + return err + } + if sess.Model == "" { + sess.Model = resp.Model + } + + if len(resp.ToolCalls) == 0 { + analysis, suggestions := ai.ParseSuggestions(resp.Content) + sess.AddAssistantMessage(analysis, nil) + sess.Suggested = s.scopeSuggestions(sess, suggestions) + sess.Status = ai.SessionStatusReady + return nil + } + + sess.AddAssistantMessage(resp.Content, resp.ToolCalls) + + if !sess.AutoRun { + sess.Pending = resp.ToolCalls + sess.Status = ai.SessionStatusAwaitingApproval + return nil + } + + for _, call := range resp.ToolCalls { + sess.AddToolResult(call, s.runAITool(c, sess.Deployment, call)) + } + } + + sess.AddAssistantMessage("I stopped after investigating several steps without reaching a confident answer. 
Ask a more specific question or check the details directly.", nil) + sess.Status = ai.SessionStatusReady + return nil +} + +func (s *Server) scopeSuggestions(sess *ai.Session, suggestions []ai.SuggestedAction) []ai.SuggestedAction { + if len(suggestions) == 0 { + return []ai.SuggestedAction{} + } + if sess.Scope == ai.SessionScopeDeployment && sess.Deployment != "" { + return s.filterSuggestionsForDeployment(sess.Deployment, suggestions) + } + // System-scope suggestions have no single deployment to validate + // against, so they are not offered as one-click actions. + return []ai.SuggestedAction{} +} + +func (s *Server) sessionResponse(c *gin.Context, sess *ai.Session) { + c.JSON(http.StatusOK, gin.H{ + "id": sess.ID, + "scope": sess.Scope, + "deployment": sess.Deployment, + "auto_run": sess.AutoRun, + "status": sess.Status, + "model": sess.Model, + "messages": sess.DisplayMessages(), + "pending": sess.Pending, + "suggested_actions": sess.Suggested, + }) +} + +func (s *Server) aiRequireEnabled(c *gin.Context) bool { + if s.aiProvider == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "AI assistant is not enabled", "code": "ai_disabled"}) + return false + } + return true +} + +func (s *Server) createAISession(c *gin.Context) { + if !s.aiRequireEnabled(c) { + return + } + var req struct { + Scope string `json:"scope"` + Deployment string `json:"deployment"` + AutoRun bool `json:"auto_run"` + Message string `json:"message"` + Context string `json:"context"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.Scope == "" { + req.Scope = ai.SessionScopeSystem + } + if req.Scope == ai.SessionScopeDeployment { + if req.Deployment == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "deployment is required for a deployment-scoped session"}) + return + } + if !s.requireDeploymentAccess(c, req.Deployment, auth.AccessLevelRead) { + return + } + } + if 
strings.TrimSpace(req.Message) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "message is required"}) + return + } + + prompt := ai.BuildSessionPrompt(req.Scope, req.Deployment, s.config.AI.DocsURL) + sess := ai.NewSession(req.Scope, req.Deployment, req.AutoRun, sessionActorFrom(c), prompt) + content, display := composeUserMessage(req.Message, req.Context) + sess.AddUserMessage(content, display) + + if err := s.advanceSession(c, sess); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + if err := s.aiSessions.Save(sess); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + s.sessionResponse(c, sess) +} + +func (s *Server) loadOwnedSession(c *gin.Context) (*ai.Session, bool) { + sess, err := s.aiSessions.Get(c.Param("id")) + if err == ai.ErrSessionNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"}) + return nil, false + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return nil, false + } + if !canUseSession(c, sess) { + c.JSON(http.StatusForbidden, gin.H{"error": "No access to this session"}) + return nil, false + } + return sess, true +} + +func (s *Server) getAISession(c *gin.Context) { + sess, ok := s.loadOwnedSession(c) + if !ok { + return + } + s.sessionResponse(c, sess) +} + +func (s *Server) postAISessionMessage(c *gin.Context) { + if !s.aiRequireEnabled(c) { + return + } + sess, ok := s.loadOwnedSession(c) + if !ok { + return + } + if sess.Status == ai.SessionStatusAwaitingApproval { + c.JSON(http.StatusConflict, gin.H{"error": "session is waiting for tool approval; resolve it first"}) + return + } + var req struct { + Message string `json:"message"` + Context string `json:"context"` + } + if err := c.ShouldBindJSON(&req); err != nil || strings.TrimSpace(req.Message) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "message is required"}) + return + } + if sess.Scope == 
ai.SessionScopeDeployment && !s.requireDeploymentAccess(c, sess.Deployment, auth.AccessLevelRead) { + return + } + + content, display := composeUserMessage(req.Message, req.Context) + sess.AddUserMessage(content, display) + if err := s.advanceSession(c, sess); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + if err := s.aiSessions.Save(sess); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + s.sessionResponse(c, sess) +} + +func (s *Server) approveAISessionTools(c *gin.Context) { + if !s.aiRequireEnabled(c) { + return + } + sess, ok := s.loadOwnedSession(c) + if !ok { + return + } + if sess.Status != ai.SessionStatusAwaitingApproval { + c.JSON(http.StatusConflict, gin.H{"error": "session has no tools awaiting approval"}) + return + } + var req struct { + // Approved maps tool call id -> whether to run it. Missing or + // false means declined. + Approved map[string]bool `json:"approved"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if sess.Scope == ai.SessionScopeDeployment && !s.requireDeploymentAccess(c, sess.Deployment, auth.AccessLevelRead) { + return + } + + for _, call := range sess.Pending { + if req.Approved[call.ID] { + sess.AddToolResult(call, s.runAITool(c, sess.Deployment, call)) + } else { + sess.AddToolResult(call, "The operator declined to run this command.") + } + } + sess.Pending = nil + sess.Status = ai.SessionStatusReady + + if err := s.advanceSession(c, sess); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + if err := s.aiSessions.Save(sess); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + s.sessionResponse(c, sess) +} + +func (s *Server) deleteAISession(c *gin.Context) { + sess, ok := s.loadOwnedSession(c) + if !ok { + return + } + if err := s.aiSessions.Delete(sess.ID); err != nil { + 
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "Session deleted", "id": sess.ID}) +} diff --git a/internal/api/ai_session_test.go b/internal/api/ai_session_test.go new file mode 100644 index 0000000..3371f16 --- /dev/null +++ b/internal/api/ai_session_test.go @@ -0,0 +1,265 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/flatrun/agent/internal/ai" + "github.com/flatrun/agent/pkg/models" + "github.com/gin-gonic/gin" +) + +func newAIToolContext() *gin.Context { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + return c +} + +// scriptedProvider returns a queued response per Complete call so a +// test can drive a multi-step tool loop. +type scriptedProvider struct { + responses []*ai.Response + calls int + lastReq ai.Request +} + +func (p *scriptedProvider) Name() string { return "scripted" } + +func (p *scriptedProvider) Complete(_ context.Context, req ai.Request) (*ai.Response, error) { + p.lastReq = req + if p.calls >= len(p.responses) { + return &ai.Response{Content: "done", Model: "scripted"}, nil + } + resp := p.responses[p.calls] + p.calls++ + return resp, nil +} + +func (p *scriptedProvider) lastRequestMessages() []ai.Message { return p.lastReq.Messages } + +func TestAISessionAutoRunToolLoop(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", &models.ServiceMetadata{Name: "myapp"}) + + s.aiProvider = &scriptedProvider{responses: []*ai.Response{ + {ToolCalls: []ai.ToolCall{{ID: "c1", Name: "list_deployments", Arguments: "{}"}}, Model: "scripted"}, + {Content: "## Summary\nYou have one deployment, myapp.", Model: "scripted"}, + }} + + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/ai/sessions", map[string]interface{}{ + "scope": "system", + "auto_run": true, + "message": "what deployments do I 
have?", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, body %v", resp.StatusCode, parsed) + } + if parsed["status"] != "ready" { + t.Errorf("status = %v, want ready", parsed["status"]) + } + + messages := parsed["messages"].([]interface{}) + // user, then assistant(tool step), then assistant(final). + if len(messages) < 2 { + t.Fatalf("expected at least 2 turns, got %v", messages) + } + last := messages[len(messages)-1].(map[string]interface{}) + if last["content"] != "## Summary\nYou have one deployment, myapp." { + t.Errorf("final content = %v", last["content"]) + } + + // The tool step must show the executed tool and its result. + foundToolStep := false + for _, m := range messages { + turn := m.(map[string]interface{}) + if steps, ok := turn["tool_steps"].([]interface{}); ok && len(steps) > 0 { + step := steps[0].(map[string]interface{}) + if step["name"] == "list_deployments" && step["result"] != nil { + foundToolStep = true + } + } + } + if !foundToolStep { + t.Error("auto-run did not execute the tool and record its result") + } + + id := parsed["id"].(string) + if _, err := s.aiSessions.Get(id); err != nil { + t.Errorf("session not persisted: %v", err) + } +} + +func TestAISessionHidesBulkyContext(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", &models.ServiceMetadata{Name: "myapp"}) + + stub := &scriptedProvider{responses: []*ai.Response{{Content: "All healthy.", Model: "scripted"}}} + s.aiProvider = stub + + logs := "GET /health 200 OK\nGET /health 200 OK\nGET /health 200 OK" + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/ai/sessions", map[string]interface{}{ + "scope": "deployment", + "deployment": "myapp", + "auto_run": true, + "message": "Analyze the recent logs for myapp.", + "context": "```\n" + logs + "\n```", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, body %v", resp.StatusCode, parsed) + } + + // The displayed user turn must be the short 
message, not the logs. + messages := parsed["messages"].([]interface{}) + first := messages[0].(map[string]interface{}) + if first["content"] != "Analyze the recent logs for myapp." { + t.Errorf("displayed user turn = %q, want the short message", first["content"]) + } + if strings.Contains(first["content"].(string), "GET /health") { + t.Error("bulky logs leaked into the displayed transcript") + } + + // The model, however, must have received the logs. + var prompt strings.Builder + for _, m := range stub.lastRequestMessages() { + prompt.WriteString(m.Content) + } + if !strings.Contains(prompt.String(), "GET /health 200 OK") { + t.Error("logs were not sent to the model") + } +} + +func TestAISessionApprovalGating(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", &models.ServiceMetadata{Name: "myapp"}) + + s.aiProvider = &scriptedProvider{responses: []*ai.Response{ + {ToolCalls: []ai.ToolCall{{ID: "c1", Name: "list_networks", Arguments: "{}"}}, Model: "scripted"}, + {Content: "## Summary\nThe proxy network exists.", Model: "scripted"}, + }} + + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/ai/sessions", map[string]interface{}{ + "scope": "system", + "auto_run": false, + "message": "do my networks exist?", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, body %v", resp.StatusCode, parsed) + } + if parsed["status"] != "awaiting_approval" { + t.Fatalf("status = %v, want awaiting_approval", parsed["status"]) + } + pending := parsed["pending"].([]interface{}) + if len(pending) != 1 { + t.Fatalf("pending = %v, want 1", pending) + } + if pending[0].(map[string]interface{})["name"] != "list_networks" { + t.Errorf("pending tool = %v", pending[0]) + } + + id := parsed["id"].(string) + + // A new message is rejected while approval is pending. 
+ resp, _ = doJSON(t, http.MethodPost, ts.URL+"/api/ai/sessions/"+id+"/messages", + map[string]interface{}{"message": "hello"}) + if resp.StatusCode != http.StatusConflict { + t.Errorf("message during approval = %d, want 409", resp.StatusCode) + } + + // Approve the tool; the loop resumes and finishes. + resp, parsed = doJSON(t, http.MethodPost, ts.URL+"/api/ai/sessions/"+id+"/approve", + map[string]interface{}{"approved": map[string]bool{"c1": true}}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("approve status = %d, body %v", resp.StatusCode, parsed) + } + if parsed["status"] != "ready" { + t.Errorf("status after approve = %v, want ready", parsed["status"]) + } + messages := parsed["messages"].([]interface{}) + last := messages[len(messages)-1].(map[string]interface{}) + if last["content"] != "## Summary\nThe proxy network exists." { + t.Errorf("final content = %v", last["content"]) + } +} + +func TestAISessionDeclineTool(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", &models.ServiceMetadata{Name: "myapp"}) + + s.aiProvider = &scriptedProvider{responses: []*ai.Response{ + {ToolCalls: []ai.ToolCall{{ID: "c1", Name: "list_networks", Arguments: "{}"}}, Model: "scripted"}, + {Content: "I could not inspect the networks because you declined.", Model: "scripted"}, + }} + + _, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/ai/sessions", map[string]interface{}{ + "scope": "system", "auto_run": false, "message": "check networks", + }) + id := parsed["id"].(string) + + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/ai/sessions/"+id+"/approve", + map[string]interface{}{"approved": map[string]bool{"c1": false}}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("decline status = %d", resp.StatusCode) + } + if parsed["status"] != "ready" { + t.Errorf("status = %v", parsed["status"]) + } +} + +func TestAISessionDisabledReturns503(t *testing.T) { + _, _, ts := setupPlanTestServer(t) + resp, parsed := 
doJSON(t, http.MethodPost, ts.URL+"/api/ai/sessions", + map[string]interface{}{"scope": "system", "message": "hi"}) + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", resp.StatusCode) + } + if parsed["code"] != "ai_disabled" { + t.Errorf("code = %v", parsed["code"]) + } +} + +func TestAIToolExecRefusesDestructive(t *testing.T) { + s, tmpDir, _ := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", &models.ServiceMetadata{Name: "myapp"}) + + c := newAIToolContext() + result := s.runAITool(c, "myapp", ai.ToolCall{ + Name: "exec_in_service", + Arguments: `{"service":"web","command":"rm -rf /data"}`, + }) + if !strings.Contains(result, "refused") { + t.Errorf("destructive exec not refused: %q", result) + } +} + +func TestAIToolHostCommandRefusesDestructive(t *testing.T) { + s, _, _ := setupPlanTestServer(t) + c := newAIToolContext() + result := s.runAITool(c, "", ai.ToolCall{ + Name: "run_host_command", + Arguments: `{"command":"rm -rf /"}`, + }) + if !strings.Contains(result, "refused") { + t.Errorf("destructive host command not refused: %q", result) + } +} + +func TestAIToolInstanceInfo(t *testing.T) { + s, _, _ := setupPlanTestServer(t) + c := newAIToolContext() + result := s.runAITool(c, "", ai.ToolCall{Name: "get_instance_info", Arguments: "{}"}) + if !strings.Contains(result, "Hostname:") || !strings.Contains(result, "Public IP:") { + t.Errorf("instance info missing fields: %q", result) + } +} + +func TestAIToolUnknownToolReported(t *testing.T) { + s, _, _ := setupPlanTestServer(t) + c := newAIToolContext() + result := s.runAITool(c, "", ai.ToolCall{Name: "does_not_exist", Arguments: "{}"}) + if !strings.Contains(result, "unknown tool") { + t.Errorf("unknown tool not reported: %q", result) + } +} diff --git a/internal/api/ai_tools.go b/internal/api/ai_tools.go new file mode 100644 index 0000000..833aaa6 --- /dev/null +++ b/internal/api/ai_tools.go @@ -0,0 +1,384 @@ +package api + +import ( + "context" + 
"encoding/json" + "fmt" + "os" + "os/exec" + "regexp" + "sort" + "strings" + "time" + + "github.com/flatrun/agent/internal/ai" + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/setup" + "github.com/gin-gonic/gin" +) + +const maxToolOutputChars = 8000 + +// aiTool is one read-only investigation capability the model can call +// to discover facts about this installation instead of guessing. +type aiTool struct { + Spec ai.Tool + Run func(s *Server, c *gin.Context, boundDeployment string, args map[string]interface{}) (string, error) +} + +// destructiveCommand matches obviously state-changing shell tokens, so +// an auto-run exec can never mutate the system even if the model asks. +var destructiveCommand = regexp.MustCompile(`(?i)\b(rm|rmdir|mv|dd|mkfs|truncate|tee|chmod|chown|kill|killall|pkill|shutdown|reboot|halt|apt|apt-get|yum|apk|dnf|systemctl|service|drop|delete|update|insert|alter)\b|>>?|\bmkdir\b`) + +func truncateToolOutput(s string) string { + if len(s) <= maxToolOutputChars { + return s + } + return s[:maxToolOutputChars] + "\n[... output truncated ...]" +} + +func toolDeployment(boundDeployment string, args map[string]interface{}) string { + if name, ok := args["deployment"].(string); ok && name != "" { + return name + } + return boundDeployment +} + +// toolAllowedDeployment resolves the target deployment and verifies the +// current actor may read it. Returns "" plus an error message string +// the model sees when access is denied. 
+func (s *Server) toolAllowedDeployment(c *gin.Context, boundDeployment string, args map[string]interface{}) (string, error) { + name := toolDeployment(boundDeployment, args) + if name == "" { + return "", fmt.Errorf("no deployment specified") + } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin && !actor.CanAccessDeployment(name, auth.AccessLevelRead) { + return "", fmt.Errorf("you do not have access to deployment %q", name) + } + return name, nil +} + +func argString(args map[string]interface{}, key string) string { + if v, ok := args[key].(string); ok { + return v + } + return "" +} + +func (s *Server) aiToolRegistry() map[string]aiTool { + objSchema := func(props map[string]interface{}, required ...string) map[string]interface{} { + schema := map[string]interface{}{"type": "object", "properties": props} + if len(required) > 0 { + schema["required"] = required + } + return schema + } + strProp := func(desc string) map[string]interface{} { + return map[string]interface{}{"type": "string", "description": desc} + } + + return map[string]aiTool{ + "get_instance_info": { + Spec: ai.Tool{ + Name: "get_instance_info", + Description: "Get information about the FlatRun host itself: its hostname and public IP address.", + Parameters: objSchema(map[string]interface{}{}), + }, + Run: func(s *Server, _ *gin.Context, _ string, _ map[string]interface{}) (string, error) { + hostname, _ := os.Hostname() + var b strings.Builder + fmt.Fprintf(&b, "Hostname: %s\n", hostname) + fmt.Fprintf(&b, "Public IP: %s\n", setup.ResolvePublicIP()) + return b.String(), nil + }, + }, + "run_host_command": { + Spec: ai.Tool{ + Name: "run_host_command", + Description: "Run a READ-ONLY shell command on the FlatRun host (not inside a container) to inspect the machine, for example: ip addr, hostname -I, df -h, docker ps, free -m. Commands that change state are refused. 
Requires system access.", + Parameters: objSchema(map[string]interface{}{ + "command": strProp("The read-only shell command to run on the host."), + }, "command"), + }, + Run: func(s *Server, c *gin.Context, _ string, args map[string]interface{}) (string, error) { + command := argString(args, "command") + if command == "" { + return "", fmt.Errorf("command is required") + } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin && !actor.HasPermission(auth.PermSystemRead) { + return "", fmt.Errorf("running host commands requires system access, which you do not have") + } + if s.config.SystemTerminal.ProtectedMode.Enabled && s.config.SystemTerminal.ProtectedMode.DisableTerminal { + return "", fmt.Errorf("the system terminal is disabled by global protected mode") + } + if destructiveCommand.MatchString(command) { + return "", fmt.Errorf("refused: %q looks like it changes state; only read-only commands are allowed", command) + } + if blocked, rule, _ := protectedCommandBlocked(&s.config.SystemTerminal.ProtectedMode, command); blocked { + return "", fmt.Errorf("command blocked by global protected mode: %s", protectedCommandBlockMessage(command, rule)) + } + ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, "sh", "-c", command) + cmd.Dir = s.config.DeploymentsPath + out, err := cmd.CombinedOutput() + redactor := ai.NewRedactor(s.systemSecretValues()) + redacted, _ := redactor.Redact(string(out)) + if err != nil { + return truncateToolOutput(redacted) + "\n[command exited with error: " + err.Error() + "]", nil + } + return truncateToolOutput(redacted), nil + }, + }, + "list_networks": { + Spec: ai.Tool{ + Name: "list_networks", + Description: "List the Docker networks that currently exist on this host, with their drivers.", + Parameters: objSchema(map[string]interface{}{}), + }, + Run: func(s *Server, _ *gin.Context, _ string, _ map[string]interface{}) (string, 
error) { + networks, err := s.networksManager.ListNetworks() + if err != nil { + return "", err + } + var b strings.Builder + for _, n := range networks { + fmt.Fprintf(&b, "- %s (driver: %s)\n", n.Name, n.Driver) + } + if b.Len() == 0 { + return "No networks found.", nil + } + return b.String(), nil + }, + }, + "list_deployments": { + Spec: ai.Tool{ + Name: "list_deployments", + Description: "List all deployments managed by this FlatRun instance with their current status.", + Parameters: objSchema(map[string]interface{}{}), + }, + Run: func(s *Server, _ *gin.Context, _ string, _ map[string]interface{}) (string, error) { + deployments, err := s.manager.ListDeployments() + if err != nil { + return "", err + } + var b strings.Builder + for _, d := range deployments { + fmt.Fprintf(&b, "- %s (status: %s)\n", d.Name, d.Status) + } + if b.Len() == 0 { + return "No deployments found.", nil + } + return b.String(), nil + }, + }, + "get_platform_config": { + Spec: ai.Tool{ + Name: "get_platform_config", + Description: "Get this installation's FlatRun configuration: the proxy and database network names, whether the managed nginx reverse proxy and shared database are enabled, and the database host and type.", + Parameters: objSchema(map[string]interface{}{}), + }, + Run: func(s *Server, _ *gin.Context, _ string, _ map[string]interface{}) (string, error) { + return s.platformSection("").Content, nil + }, + }, + "get_deployment_metadata": { + Spec: ai.Tool{ + Name: "get_deployment_metadata", + Description: "Get a deployment's FlatRun metadata: its reverse-proxy routing (which service and container port each domain forwards to), exposed domains, health check path and databases. This is the source of truth for routing, not the compose expose field.", + Parameters: objSchema(map[string]interface{}{ + "deployment": strProp("Deployment name. 
Omit to use the session's deployment."), + }), + }, + Run: func(s *Server, c *gin.Context, bound string, args map[string]interface{}) (string, error) { + name, err := s.toolAllowedDeployment(c, bound, args) + if err != nil { + return "", err + } + return s.platformSection(name).Content, nil + }, + }, + "get_deployment_logs": { + Spec: ai.Tool{ + Name: "get_deployment_logs", + Description: "Get a deployment's recent container logs. In FlatRun, application logs are the containers' stdout/stderr captured by Docker, not files on disk, so use this tool to read logs rather than searching the filesystem.", + Parameters: objSchema(map[string]interface{}{ + "deployment": strProp("Deployment name. Omit to use the session's deployment."), + "tail": map[string]interface{}{"type": "integer", "description": "How many recent lines to fetch (default 300, max 1000)."}, + }), + }, + Run: func(s *Server, c *gin.Context, bound string, args map[string]interface{}) (string, error) { + name, err := s.toolAllowedDeployment(c, bound, args) + if err != nil { + return "", err + } + tail := 300 + if v, ok := args["tail"].(float64); ok && int(v) > 0 { + tail = int(v) + } + if tail > 1000 { + tail = 1000 + } + logs, err := s.manager.GetDeploymentLogs(name, tail) + if err != nil { + return "", err + } + if strings.TrimSpace(logs) == "" { + return "The deployment has produced no logs.", nil + } + redactor := ai.NewRedactor(s.deploymentSecretValues(name)) + redacted, _ := redactor.Redact(logs) + return truncateToolOutput(redacted), nil + }, + }, + "list_deployment_files": { + Spec: ai.Tool{ + Name: "list_deployment_files", + Description: "List files and directories inside a deployment's directory at the given path. Use this to find application-generated logs and data files mounted in the deployment.", + Parameters: objSchema(map[string]interface{}{ + "deployment": strProp("Deployment name. Omit to use the session's deployment."), + "path": strProp("Path relative to the deployment directory. 
Defaults to the root."), + }), + }, + Run: func(s *Server, c *gin.Context, bound string, args map[string]interface{}) (string, error) { + name, err := s.toolAllowedDeployment(c, bound, args) + if err != nil { + return "", err + } + path := argString(args, "path") + if path == "" { + path = "/" + } + files, err := s.filesManager.ListFiles(name, path) + if err != nil { + return "", err + } + var b strings.Builder + for _, f := range files { + kind := "file" + if f.IsDir { + kind = "dir" + } + fmt.Fprintf(&b, "- %s (%s, %d bytes)\n", f.Path, kind, f.Size) + } + if b.Len() == 0 { + return "Directory is empty.", nil + } + return b.String(), nil + }, + }, + "read_deployment_file": { + Spec: ai.Tool{ + Name: "read_deployment_file", + Description: "Read a text file inside a deployment's directory, for example an application log file or config the app generated. Secret values are redacted.", + Parameters: objSchema(map[string]interface{}{ + "deployment": strProp("Deployment name. Omit to use the session's deployment."), + "path": strProp("Path to the file relative to the deployment directory."), + }, "path"), + }, + Run: func(s *Server, c *gin.Context, bound string, args map[string]interface{}) (string, error) { + name, err := s.toolAllowedDeployment(c, bound, args) + if err != nil { + return "", err + } + path := argString(args, "path") + if path == "" { + return "", fmt.Errorf("path is required") + } + reader, info, err := s.filesManager.ReadFile(name, path) + if err != nil { + return "", err + } + defer reader.Close() + if info.IsDir { + return "", fmt.Errorf("%s is a directory; use list_deployment_files", path) + } + buf := make([]byte, maxToolOutputChars) + n, _ := reader.Read(buf) + redactor := ai.NewRedactor(s.deploymentSecretValues(name)) + content, _ := redactor.Redact(string(buf[:n])) + return truncateToolOutput(content), nil + }, + }, + "exec_in_service": { + Spec: ai.Tool{ + Name: "exec_in_service", + Description: "Run a READ-ONLY shell command inside a 
deployment's service container to inspect its state (for example: ls, cat, env, ps, netstat, curl localhost). Commands that change state are refused. Secret values in the output are redacted.", + Parameters: objSchema(map[string]interface{}{ + "deployment": strProp("Deployment name. Omit to use the session's deployment."), + "service": strProp("Compose service name to run the command in."), + "command": strProp("The read-only shell command to run."), + }, "service", "command"), + }, + Run: func(s *Server, c *gin.Context, bound string, args map[string]interface{}) (string, error) { + name, err := s.toolAllowedDeployment(c, bound, args) + if err != nil { + return "", err + } + service := argString(args, "service") + command := argString(args, "command") + if service == "" || command == "" { + return "", fmt.Errorf("service and command are required") + } + if destructiveCommand.MatchString(command) { + return "", fmt.Errorf("refused: %q looks like it changes state; only read-only commands are allowed", command) + } + resolved, err := s.manager.ResolveService(name, service) + if err != nil { + return "", err + } + if blocked, reason, perr := s.protectedDeploymentActionBlocked(name, protectedActionExec); perr == nil && blocked { + return "", fmt.Errorf("%s", reason) + } + ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second) + defer cancel() + output, err := s.manager.ComposeExec(ctx, name, resolved, command) + if err != nil { + return "", err + } + redactor := ai.NewRedactor(s.deploymentSecretValues(name)) + redacted, _ := redactor.Redact(output) + return truncateToolOutput(redacted), nil + }, + }, + } +} + +// aiToolSpecs returns the tool schemas to advertise to the model, in a +// stable order. 
+func (s *Server) aiToolSpecs() []ai.Tool { + registry := s.aiToolRegistry() + names := make([]string, 0, len(registry)) + for name := range registry { + names = append(names, name) + } + sort.Strings(names) + specs := make([]ai.Tool, 0, len(names)) + for _, name := range names { + specs = append(specs, registry[name].Spec) + } + return specs +} + +// runAITool executes one tool call, returning the textual result the +// model reads. Errors are returned as content (prefixed) so the model +// can recover rather than the loop aborting. +func (s *Server) runAITool(c *gin.Context, boundDeployment string, call ai.ToolCall) string { + tool, ok := s.aiToolRegistry()[call.Name] + if !ok { + return "Error: unknown tool " + call.Name + } + var args map[string]interface{} + if strings.TrimSpace(call.Arguments) != "" { + if err := json.Unmarshal([]byte(call.Arguments), &args); err != nil { + return "Error: could not parse tool arguments: " + err.Error() + } + } + result, err := tool.Run(s, c, boundDeployment, args) + if err != nil { + return "Error: " + err.Error() + } + return result +} diff --git a/internal/api/cluster_handlers.go b/internal/api/cluster_handlers.go index 9759c94..ae5e9d1 100644 --- a/internal/api/cluster_handlers.go +++ b/internal/api/cluster_handlers.go @@ -88,9 +88,9 @@ func (s *Server) clusterAccept(c *gin.Context) { } var req struct { - InviteToken string `json:"invite_token" binding:"required"` - PeerURL string `json:"peer_url" binding:"required"` - CallbackURL string `json:"callback_url"` + InviteToken string `json:"invite_token" binding:"required"` + PeerURL string `json:"peer_url" binding:"required"` + CallbackURL string `json:"callback_url"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/config_handlers.go b/internal/api/config_handlers.go index 6e39368..7bed313 100644 --- a/internal/api/config_handlers.go +++ b/internal/api/config_handlers.go @@ -5,6 +5,7 @@ 
import ( "net/http" "strings" + "github.com/flatrun/agent/internal/ai" "github.com/flatrun/agent/pkg/config" "github.com/gin-gonic/gin" ) @@ -41,15 +42,41 @@ func (s *Server) updateConfigKey(c *gin.Context) { return } - if err := config.Set(s.config, key, req.Value); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + if planRequested(c) { + s.planConfigUpdate(c, key, req.Value) return } + outcome, err := s.applyConfigUpdate(key, req.Value) + if err != nil { + respondAPIError(c, err) + return + } + + resp := gin.H{ + "entry": outcome.Entry, + "applied": outcome.Applied, + } + if outcome.ApplyErr != nil { + resp["apply_error"] = outcome.ApplyErr.Error() + } + c.JSON(http.StatusOK, resp) +} + +type configUpdateOutcome struct { + Entry config.Entry + Applied bool + ApplyErr error +} + +func (s *Server) applyConfigUpdate(key string, value interface{}) (*configUpdateOutcome, error) { + if err := config.Set(s.config, key, value); err != nil { + return nil, apiErrf(http.StatusBadRequest, "%s", err.Error()) + } + if s.configPath != "" { if err := config.Save(s.config, s.configPath); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "value updated in memory but not persisted: " + err.Error()}) - return + return nil, apiErrf(http.StatusInternalServerError, "value updated in memory but not persisted: %s", err.Error()) } } @@ -61,14 +88,7 @@ func (s *Server) updateConfigKey(c *gin.Context) { } entry, _ := config.Get(s.config, key) - resp := gin.H{ - "entry": entry, - "applied": applied, - } - if applyErr != nil { - resp["apply_error"] = applyErr.Error() - } - c.JSON(http.StatusOK, resp) + return &configUpdateOutcome{Entry: entry, Applied: applied, ApplyErr: applyErr}, nil } func (s *Server) runtimeAppliers() map[string]func(*Server) error { @@ -96,11 +116,28 @@ func (s *Server) runtimeAppliers() map[string]func(*Server) error { _, err := srv.infraManager.RefreshSecurityScripts() return err } + rebuildAIProvider := func(srv *Server) 
error { + provider, err := ai.New(&srv.config.AI) + if err == ai.ErrDisabled { + srv.aiProvider = nil + return nil + } + if err != nil { + return err + } + srv.aiProvider = provider + return nil + } return map[string]func(*Server) error{ "cleanup.timeout": func(srv *Server) error { srv.manager.SetCleanupTimeout(srv.config.Cleanup.Timeout) return nil }, + "ai.enabled": rebuildAIProvider, + "ai.base_url": rebuildAIProvider, + "ai.api_key": rebuildAIProvider, + "ai.model": rebuildAIProvider, + "ai.timeout": rebuildAIProvider, "security.rate_threshold": applyDetectorThresholds, "security.not_found_threshold": applyDetectorThresholds, "security.auth_failure_threshold": applyDetectorThresholds, diff --git a/internal/api/deployment_actions.go b/internal/api/deployment_actions.go new file mode 100644 index 0000000..ac5caa1 --- /dev/null +++ b/internal/api/deployment_actions.go @@ -0,0 +1,231 @@ +package api + +import ( + "fmt" + "log" + "net/http" + + "github.com/flatrun/agent/pkg/models" + "github.com/gin-gonic/gin" + "gopkg.in/yaml.v3" +) + +type apiError struct { + Status int + Msg string +} + +func (e *apiError) Error() string { return e.Msg } + +func apiErrf(status int, format string, args ...interface{}) *apiError { + return &apiError{Status: status, Msg: fmt.Sprintf(format, args...)} +} + +func respondAPIError(c *gin.Context, err error) { + if ae, ok := err.(*apiError); ok { + c.JSON(ae.Status, gin.H{"error": ae.Msg}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) +} + +func cloneMetadata(meta *models.ServiceMetadata) (*models.ServiceMetadata, error) { + if meta == nil { + return nil, nil + } + data, err := yaml.Marshal(meta) + if err != nil { + return nil, err + } + var cp models.ServiceMetadata + if err := yaml.Unmarshal(data, &cp); err != nil { + return nil, err + } + return &cp, nil +} + +type deploymentDeleteOptions struct { + DeleteSSL bool + DeleteDatabase bool + DeleteVhost bool +} + +func (s *Server) 
applyDeploymentDelete(name string, opts deploymentDeleteOptions) ([]string, error) {
+	// applyDeploymentDelete removes the optional resources selected in opts
+	// (virtual host, SSL certificates, databases) on a best-effort basis and
+	// then deletes the deployment itself. It returns a label for every
+	// resource that was removed. Cleanup failures are only logged; the sole
+	// error returned is a failure of the final DeleteDeployment call.
+
+	// A failed lookup is not fatal: the metadata-driven steps below are
+	// skipped (they all check for a nil deployment) and the delete proceeds.
+	deployment, lookupErr := s.manager.GetDeployment(name)
+	if lookupErr != nil {
+		log.Printf("Warning: failed to load deployment %s before delete: %v", name, lookupErr)
+	}
+
+	// The proxy orchestrator is treated as optional by the planned domain
+	// apply functions, which nil-check it. Do the same here so a delete
+	// never dereferences a nil orchestrator.
+	hasProxy := s.proxyOrchestrator != nil
+
+	var deletedItems []string
+
+	if opts.DeleteVhost && hasProxy {
+		if err := s.proxyOrchestrator.TeardownDeployment(name); err != nil {
+			log.Printf("Warning: failed to teardown proxy for %s: %v", name, err)
+		} else {
+			deletedItems = append(deletedItems, "virtual_host")
+		}
+	}
+
+	if opts.DeleteSSL && hasProxy && deployment != nil && deployment.Metadata != nil {
+		domainsToDelete := deployment.Metadata.GetUniqueDomainNames()
+		// Fall back to the networking domain when no domain list is present.
+		if len(domainsToDelete) == 0 && deployment.Metadata.Networking.Domain != "" {
+			domainsToDelete = []string{deployment.Metadata.Networking.Domain}
+		}
+
+		for _, domain := range domainsToDelete {
+			if err := s.proxyOrchestrator.SSLManager().DeleteCertificate(domain); err != nil {
+				log.Printf("Warning: failed to delete SSL certificate for %s: %v", domain, err)
+			} else {
+				deletedItems = append(deletedItems, fmt.Sprintf("ssl_certificate:%s", domain))
+			}
+		}
+	}
+
+	if opts.DeleteDatabase && s.config.Infrastructure.Database.Enabled {
+		if deployment != nil && deployment.Metadata != nil && len(deployment.Metadata.Databases) > 0 {
+			// Only databases flagged IsShared are deleted here.
+			for _, dbConfig := range deployment.Metadata.Databases {
+				if dbConfig.IsShared {
+					if err := s.deleteDatabaseByAlias(name, dbConfig.Alias); err != nil {
+						log.Printf("Warning: failed to delete database %s for %s: %v", dbConfig.Alias, name, err)
+					} else {
+						deletedItems = append(deletedItems, fmt.Sprintf("database:%s", dbConfig.Alias))
+					}
+				}
+			}
+		} else {
+			if err := s.deleteDatabaseForDeployment(name); err != nil {
+				log.Printf("Warning: failed to delete database for %s: %v", name, err)
+			} else {
+				deletedItems = append(deletedItems, "database")
+			}
+		}
+	}
+
+	// Still report what was already removed if the final delete fails.
+	if err := s.manager.DeleteDeployment(name); err != nil {
+		return deletedItems, err
+	}
+
+	return deletedItems, nil
+}
+
+// mutateDomainAdd validates the new domain and appends it to the
+// deployment metadata in memory only; persisting
is the caller's job. +func (s *Server) mutateDomainAdd(deployment *models.Deployment, domain *models.DomainConfig) error { + if domain.Domain == "" { + return apiErrf(http.StatusBadRequest, "Domain is required") + } + + resolved, err := s.resolveService(deployment.Name, domain.Service) + if err != nil { + return apiErrf(http.StatusBadRequest, "%s", err.Error()) + } + domain.Service = resolved + + if domain.ID == "" { + domain.ID = generateDomainID() + } + + if deployment.Metadata == nil { + deployment.Metadata = &models.ServiceMetadata{} + } + + if len(deployment.Metadata.Domains) == 0 && deployment.Metadata.Networking.Expose { + existingService := deployment.Metadata.Networking.Service + if existingService == "" { + existingService = resolved + } + existingDomain := models.DomainConfig{ + ID: "default", + Service: existingService, + ContainerPort: deployment.Metadata.Networking.ContainerPort, + Domain: deployment.Metadata.Networking.Domain, + SSL: deployment.Metadata.SSL, + } + deployment.Metadata.Domains = []models.DomainConfig{existingDomain} + } + + for _, existing := range deployment.Metadata.Domains { + if existing.Domain == domain.Domain && existing.PathPrefix == domain.PathPrefix { + return apiErrf(http.StatusConflict, "Domain %s%s already exists", domain.Domain, domain.PathPrefix) + } + } + + if domain.ContainerPort == 0 && deployment.Metadata.Networking.ContainerPort != 0 { + domain.ContainerPort = deployment.Metadata.Networking.ContainerPort + } + + deployment.Metadata.Domains = append(deployment.Metadata.Domains, *domain) + return nil +} + +// mutateDomainUpdate replaces the domain with the given ID in memory +// only; persisting is the caller's job. 
+func (s *Server) mutateDomainUpdate(deployment *models.Deployment, domainID string, updated *models.DomainConfig) error { + if deployment.Metadata == nil || len(deployment.Metadata.Domains) == 0 { + return apiErrf(http.StatusNotFound, "Domain not found") + } + + if updated.Service != "" { + resolved, err := s.resolveService(deployment.Name, updated.Service) + if err != nil { + return apiErrf(http.StatusBadRequest, "%s", err.Error()) + } + updated.Service = resolved + } + + for i, d := range deployment.Metadata.Domains { + if d.ID == domainID { + updated.ID = domainID + if updated.Service == "" { + updated.Service = d.Service + } + deployment.Metadata.Domains[i] = *updated + return nil + } + } + return apiErrf(http.StatusNotFound, "Domain not found") +} + +// mutateDomainDelete removes the domain with the given ID in memory and +// reports whether the proxy should be torn down (true) or re-rendered +// (false). Persisting is the caller's job. +func mutateDomainDelete(deployment *models.Deployment, domainID string) (bool, error) { + meta := deployment.Metadata + if meta == nil { + return false, apiErrf(http.StatusNotFound, "Domain not found") + } + + // Legacy "default" domain backed by the networking config rather + // than the domains list. 
+ if domainID == "default" && len(meta.Domains) == 0 { + if !meta.Networking.Expose || meta.Networking.Domain == "" { + return false, apiErrf(http.StatusNotFound, "Domain not found") + } + meta.Networking.Expose = false + meta.Networking.Domain = "" + meta.SSL.Enabled = false + meta.SSL.AutoCert = false + return true, nil + } + + if len(meta.Domains) == 0 { + return false, apiErrf(http.StatusNotFound, "Domain not found") + } + + found := false + newDomains := make([]models.DomainConfig, 0) + for _, d := range meta.Domains { + if d.ID == domainID { + found = true + continue + } + newDomains = append(newDomains, d) + } + if !found { + return false, apiErrf(http.StatusNotFound, "Domain not found") + } + + if len(newDomains) == 0 { + meta.Domains = nil + return !meta.Networking.Expose, nil + } + meta.Domains = newDomains + return false, nil +} diff --git a/internal/api/plan_actions.go b/internal/api/plan_actions.go new file mode 100644 index 0000000..373b57b --- /dev/null +++ b/internal/api/plan_actions.go @@ -0,0 +1,608 @@ +package api + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/flatrun/agent/internal/plan" + "github.com/flatrun/agent/internal/proxy" + "github.com/flatrun/agent/pkg/config" + "github.com/flatrun/agent/pkg/models" + "github.com/gin-gonic/gin" + "gopkg.in/yaml.v3" +) + +func renderEnvContent(envVars []EnvVar) string { + var content strings.Builder + for _, env := range envVars { + if env.Key != "" { + content.WriteString(fmt.Sprintf("%s=%s\n", env.Key, env.Value)) + } + } + return content.String() +} + +func diffEnvCounts(current, requested []EnvVar) (added, changed, removed int) { + currentMap := make(map[string]string, len(current)) + for _, e := range current { + currentMap[e.Key] = e.Value + } + requestedKeys := make(map[string]struct{}, len(requested)) + for _, e := range requested { + if e.Key == "" { + continue + } + requestedKeys[e.Key] = struct{}{} + old, ok := currentMap[e.Key] 
+ switch { + case !ok: + added++ + case old != e.Value: + changed++ + } + } + for _, e := range current { + if _, ok := requestedKeys[e.Key]; !ok { + removed++ + } + } + return added, changed, removed +} + +func (s *Server) runningContainerReplacements(name, reason string) []plan.Change { + deployment, err := s.manager.GetDeployment(name) + if err != nil { + return nil + } + var changes []plan.Change + for _, svc := range deployment.Services { + if svc.Status != "running" { + continue + } + changes = append(changes, plan.Change{ + Type: "container", + ID: svc.Name, + Actions: []string{plan.ActionDelete, plan.ActionCreate}, + Reason: reason, + }) + } + return changes +} + +func (s *Server) vhostConfigPath(name string) string { + return filepath.Join(s.proxyOrchestrator.NginxManager().ConfigPath(), name+".conf") +} + +func (s *Server) planEnvUpdate(c *gin.Context, name string, envVars []EnvVar) { + envRel := filepath.Join(name, ".env.flatrun") + beforeBytes, readErr := os.ReadFile(filepath.Join(s.config.DeploymentsPath, envRel)) + exists := readErr == nil + before := string(beforeBytes) + after := renderEnvContent(envVars) + + p := s.newPlan("deployment.env.update", "deployment", name) + p.Snapshot.Files = plan.SnapshotFiles(s.config.DeploymentsPath, envRel) + + if exists && before == after { + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: ".env.flatrun", + Actions: []string{plan.ActionNoOp}, + Reason: "rendered content is identical to the current file", + Sensitive: true, + }) + } else { + action := plan.ActionCreate + var beforePtr *string + if exists { + action = plan.ActionUpdate + beforePtr = plan.StrPtr(before) + } + added, changed, removed := diffEnvCounts(parseEnvContent(before), envVars) + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: ".env.flatrun", + Actions: []string{action}, + Reason: fmt.Sprintf("%d variable(s) added, %d changed, %d removed", added, changed, removed), + Before: beforePtr, + After: plan.StrPtr(after), + 
Sensitive: true, + }) + p.Changes = append(p.Changes, s.runningContainerReplacements(name, + "env file change requires recreating containers; takes effect on the next start or deploy, not on apply")...) + } + + s.finishPlan(c, p, gin.H{"env_vars": envVars}) +} + +func applyPlannedEnvUpdate(s *Server, p *plan.Plan) (gin.H, error) { + var req struct { + EnvVars []EnvVar `json:"env_vars"` + } + if err := json.Unmarshal(p.Request.Body, &req); err != nil { + return nil, apiErrf(http.StatusBadRequest, "invalid plan body: %s", err.Error()) + } + if err := s.writeEnvFile(p.Resource.ID, req.EnvVars); err != nil { + return nil, err + } + return gin.H{"message": "Environment variables updated"}, nil +} + +func (s *Server) planComposeUpdate(c *gin.Context, name, content string) { + current, filename, err := s.manager.GetComposeFile(name) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Deployment not found"}) + return + } + rel := filepath.Join(name, filename) + + p := s.newPlan("deployment.compose.update", "deployment", name) + p.Snapshot.Files = plan.SnapshotFiles(s.config.DeploymentsPath, rel) + + if current == content { + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: filename, + Actions: []string{plan.ActionNoOp}, + Reason: "submitted compose file is identical to the current one", + }) + } else { + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: filename, + Actions: []string{plan.ActionUpdate}, + Reason: "compose configuration replaced with the submitted content", + Before: plan.StrPtr(current), + After: plan.StrPtr(content), + }) + p.Changes = append(p.Changes, s.runningContainerReplacements(name, + "compose change requires recreating containers; takes effect on the next deploy, not on apply")...) 
+ } + + s.finishPlan(c, p, gin.H{"compose_content": content}) +} + +func applyPlannedComposeUpdate(s *Server, p *plan.Plan) (gin.H, error) { + var req struct { + ComposeContent string `json:"compose_content"` + } + if err := json.Unmarshal(p.Request.Body, &req); err != nil { + return nil, apiErrf(http.StatusBadRequest, "invalid plan body: %s", err.Error()) + } + if err := s.validateComposeContent(req.ComposeContent, p.Resource.ID); err != nil { + return nil, apiErrf(http.StatusBadRequest, "%s", err.Error()) + } + if err := s.manager.UpdateDeployment(p.Resource.ID, req.ComposeContent); err != nil { + return nil, err + } + return gin.H{"message": "Deployment updated", "name": p.Resource.ID}, nil +} + +func (s *Server) planDeploymentDelete(c *gin.Context, name string, opts deploymentDeleteOptions) { + deployment, err := s.manager.GetDeployment(name) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Deployment not found"}) + return + } + + p := s.newPlan("deployment.delete", "deployment", name) + + snapshotPaths := []string{filepath.Join(name, "service.yml")} + if _, composeName, cerr := s.manager.GetComposeFile(name); cerr == nil { + snapshotPaths = append(snapshotPaths, filepath.Join(name, composeName)) + } + vhostPath := s.vhostConfigPath(name) + snapshotPaths = append(snapshotPaths, vhostPath) + p.Snapshot.Files = plan.SnapshotFiles(s.config.DeploymentsPath, snapshotPaths...) 
+ + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + "/", + Actions: []string{plan.ActionDelete}, + Reason: "deployment directory is removed, including all configs and bind-mounted data", + }) + for _, svc := range deployment.Services { + p.Changes = append(p.Changes, plan.Change{ + Type: "container", ID: svc.Name, + Actions: []string{plan.ActionDelete}, + Reason: "containers are stopped and removed with the deployment", + }) + } + nginxMgr := s.proxyOrchestrator.NginxManager() + if opts.DeleteVhost && nginxMgr.VirtualHostExists(name) { + var beforePtr *string + if current, verr := nginxMgr.GetVirtualHost(name); verr == nil { + beforePtr = plan.StrPtr(current) + } + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: []string{plan.ActionDelete}, + Reason: "reverse proxy virtual host is removed", + Before: beforePtr, + }) + } + if opts.DeleteSSL && deployment.Metadata != nil { + domains := deployment.Metadata.GetUniqueDomainNames() + if len(domains) == 0 && deployment.Metadata.Networking.Domain != "" { + domains = []string{deployment.Metadata.Networking.Domain} + } + for _, domain := range domains { + if s.proxyOrchestrator.SSLManager().CertificateExists(domain) { + p.Changes = append(p.Changes, plan.Change{ + Type: "certificate", ID: domain, + Actions: []string{plan.ActionDelete}, + Reason: "SSL certificate is deleted with the deployment", + }) + } + } + } + if opts.DeleteDatabase && s.config.Infrastructure.Database.Enabled { + if deployment.Metadata != nil && len(deployment.Metadata.Databases) > 0 { + for _, dbConfig := range deployment.Metadata.Databases { + if dbConfig.IsShared { + p.Changes = append(p.Changes, plan.Change{ + Type: "database", ID: dbConfig.Alias, + Actions: []string{plan.ActionDelete}, + Reason: "shared database and its user are dropped", + }) + } + } + } else { + p.Changes = append(p.Changes, plan.Change{ + Type: "database", ID: "primary", + Actions: []string{plan.ActionDelete}, + 
Reason: "shared database and its user are dropped", + }) + } + } + + s.finishPlan(c, p, gin.H{ + "delete_ssl": opts.DeleteSSL, + "delete_database": opts.DeleteDatabase, + "delete_vhost": opts.DeleteVhost, + }) +} + +func applyPlannedDeploymentDelete(s *Server, p *plan.Plan) (gin.H, error) { + var opts struct { + DeleteSSL bool `json:"delete_ssl"` + DeleteDatabase bool `json:"delete_database"` + DeleteVhost bool `json:"delete_vhost"` + } + if err := json.Unmarshal(p.Request.Body, &opts); err != nil { + return nil, apiErrf(http.StatusBadRequest, "invalid plan body: %s", err.Error()) + } + deletedItems, err := s.applyDeploymentDelete(p.Resource.ID, deploymentDeleteOptions{ + DeleteSSL: opts.DeleteSSL, + DeleteDatabase: opts.DeleteDatabase, + DeleteVhost: opts.DeleteVhost, + }) + if err != nil { + return nil, err + } + return gin.H{"message": "Deployment deleted", "name": p.Resource.ID, "deleted_items": deletedItems}, nil +} + +// planDomainChange simulates a domain mutation on a metadata clone and +// previews the resulting service.yml and virtual host. 
+func (s *Server) planDomainChange(c *gin.Context, deployment *models.Deployment, action string, body interface{}, mutate func(*models.Deployment) (bool, error)) { + name := deployment.Name + + metaRel := filepath.Join(name, "service.yml") + metaBytes, metaErr := os.ReadFile(filepath.Join(s.config.DeploymentsPath, metaRel)) + vhostCurrent, vhostErr := s.proxyOrchestrator.NginxManager().GetVirtualHost(name) + + metaCopy, err := cloneMetadata(deployment.Metadata) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + depCopy := *deployment + depCopy.Metadata = metaCopy + + teardown, err := mutate(&depCopy) + if err != nil { + respondAPIError(c, err) + return + } + + p := s.newPlan(action, "deployment", name) + p.Snapshot.Files = plan.SnapshotFiles(s.config.DeploymentsPath, metaRel, s.vhostConfigPath(name)) + + afterMeta, err := yaml.Marshal(depCopy.Metadata) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + metaChange := plan.Change{ + Type: "file", ID: "service.yml", + Actions: []string{plan.ActionUpdate}, + Reason: "deployment metadata updated with the domain change", + After: plan.StrPtr(string(afterMeta)), + } + if metaErr != nil { + metaChange.Actions = []string{plan.ActionCreate} + metaChange.Reason = "deployment metadata file is created for the domain change" + } else { + metaChange.Before = plan.StrPtr(string(metaBytes)) + } + p.Changes = append(p.Changes, metaChange) + + if teardown { + if vhostErr == nil { + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: []string{plan.ActionDelete}, + Reason: "deployment is no longer exposed; virtual host is removed", + Before: plan.StrPtr(vhostCurrent), + }) + } + } else { + rendered, rerr := s.proxyOrchestrator.RenderDeployment(&depCopy) + switch { + case rerr != nil: + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: 
[]string{plan.ActionUpdate}, + Reason: "virtual host will be regenerated (preview unavailable: " + rerr.Error() + ")", + }) + case rendered == "": + case vhostErr != nil: + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: []string{plan.ActionCreate}, + Reason: "reverse proxy virtual host is created", + After: plan.StrPtr(rendered), + }) + case rendered == vhostCurrent: + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: []string{plan.ActionNoOp}, + Reason: "virtual host is unchanged", + }) + default: + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: []string{plan.ActionUpdate}, + Reason: "reverse proxy virtual host is regenerated", + Before: plan.StrPtr(vhostCurrent), + After: plan.StrPtr(rendered), + }) + } + p.Changes = append(p.Changes, s.pendingCertificateChanges(&depCopy)...) + } + + s.finishPlan(c, p, body) +} + +func (s *Server) pendingCertificateChanges(deployment *models.Deployment) []plan.Change { + if deployment.Metadata == nil { + return nil + } + var changes []plan.Change + seen := map[string]struct{}{} + for _, d := range deployment.Metadata.GetDomains() { + if !d.SSL.Enabled || !d.SSL.AutoCert { + continue + } + if _, dup := seen[d.Domain]; dup { + continue + } + seen[d.Domain] = struct{}{} + if !s.proxyOrchestrator.SSLManager().CertificateExists(d.Domain) { + changes = append(changes, plan.Change{ + Type: "certificate", ID: d.Domain, + Actions: []string{plan.ActionCreate}, + Reason: "certificate is requested from the configured CA on apply", + }) + } + } + return changes +} + +func applyPlannedDomainAdd(s *Server, p *plan.Plan) (gin.H, error) { + name := p.Resource.ID + deployment, err := s.manager.GetDeployment(name) + if err != nil { + return nil, apiErrf(http.StatusNotFound, "Deployment not found") + } + var domain models.DomainConfig + if err := json.Unmarshal(p.Request.Body, &domain); err != nil { + return nil, 
apiErrf(http.StatusBadRequest, "invalid plan body: %s", err.Error())
+	}
+	if err := s.mutateDomainAdd(deployment, &domain); err != nil {
+		return nil, err
+	}
+	if err := s.manager.SaveMetadata(name, deployment.Metadata); err != nil {
+		return nil, apiErrf(http.StatusInternalServerError, "Failed to save domain: %s", err.Error())
+	}
+	var result *proxy.SetupResult
+	if s.proxyOrchestrator != nil {
+		result, err = s.proxyOrchestrator.SetupDeployment(deployment)
+		if err != nil {
+			return nil, apiErrf(http.StatusConflict, "Failed to configure proxy: %s", err.Error())
+		}
+	}
+	return gin.H{"message": "Domain added successfully", "domain": domain, "proxy_result": result}, nil
+}
+
+// applyPlannedDomainUpdate applies a stored domain-update plan. It reloads
+// the deployment named by the plan's resource ID, decodes the replacement
+// domain from the plan body, swaps it in for the domain whose ID is in the
+// "domainId" request parameter, saves the metadata and, when a proxy
+// orchestrator is configured, re-runs proxy setup.
+//
+// Errors are *apiError values: 404 for an unknown deployment or domain,
+// 400 for an undecodable body or unresolvable service, 500 when the
+// metadata cannot be saved, 409 when proxy setup fails.
+func applyPlannedDomainUpdate(s *Server, p *plan.Plan) (gin.H, error) {
+	name := p.Resource.ID
+	domainID := p.Request.Params["domainId"]
+	deployment, err := s.manager.GetDeployment(name)
+	if err != nil {
+		return nil, apiErrf(http.StatusNotFound, "Deployment not found")
+	}
+	var updatedDomain models.DomainConfig
+	if err := json.Unmarshal(p.Request.Body, &updatedDomain); err != nil {
+		return nil, apiErrf(http.StatusBadRequest, "invalid plan body: %s", err.Error())
+	}
+	if err := s.mutateDomainUpdate(deployment, domainID, &updatedDomain); err != nil {
+		return nil, err
+	}
+	if err := s.manager.SaveMetadata(name, deployment.Metadata); err != nil {
+		return nil, apiErrf(http.StatusInternalServerError, "Failed to save domain: %s", err.Error())
+	}
+	// The orchestrator is optional, as in applyPlannedDomainAdd: without it
+	// the metadata change still succeeds and proxy_result is null.
+	var result *proxy.SetupResult
+	if s.proxyOrchestrator != nil {
+		result, err = s.proxyOrchestrator.SetupDeployment(deployment)
+		if err != nil {
+			return nil, apiErrf(http.StatusConflict, "Failed to configure proxy: %s", err.Error())
+		}
+	}
+	return gin.H{"message": "Domain updated successfully", "domain": updatedDomain, "proxy_result": result}, nil
+}
+
+// applyPlannedDomainDelete applies a stored domain-delete plan: it removes
+// the domain whose ID is in the "domainId" request parameter, saves the
+// metadata, and then either tears down or re-renders the proxy. Proxy
+// failures at that stage are logged, not returned.
+func applyPlannedDomainDelete(s *Server, p *plan.Plan) (gin.H, error) {
+	name := p.Resource.ID
+	domainID := p.Request.Params["domainId"]
+	deployment, err := s.manager.GetDeployment(name)
+	if err != nil {
+		return nil, apiErrf(http.StatusNotFound,
"Deployment not found") + } + teardown, err := mutateDomainDelete(deployment, domainID) + if err != nil { + return nil, err + } + if err := s.manager.SaveMetadata(name, deployment.Metadata); err != nil { + return nil, apiErrf(http.StatusInternalServerError, "Failed to save metadata: %s", err.Error()) + } + if s.proxyOrchestrator != nil { + if teardown { + if err := s.proxyOrchestrator.TeardownDeployment(name); err != nil { + log.Printf("Warning: failed to teardown proxy for %s: %v", name, err) + } + } else { + if _, err := s.proxyOrchestrator.SetupDeployment(deployment); err != nil { + log.Printf("Warning: failed to update proxy for %s: %v", name, err) + } + } + } + return gin.H{"message": "Domain deleted successfully"}, nil +} + +func (s *Server) planProxySetup(c *gin.Context, deployment *models.Deployment) { + name := deployment.Name + + p := s.newPlan("proxy.setup", "deployment", name) + metaRel := filepath.Join(name, "service.yml") + p.Snapshot.Files = plan.SnapshotFiles(s.config.DeploymentsPath, metaRel, s.vhostConfigPath(name)) + + rendered, err := s.proxyOrchestrator.RenderDeployment(deployment) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + current, vhostErr := s.proxyOrchestrator.NginxManager().GetVirtualHost(name) + + switch { + case rendered == "": + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: []string{plan.ActionNoOp}, + Reason: "deployment is not configured for exposure; nothing to set up", + }) + case vhostErr != nil: + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: []string{plan.ActionCreate}, + Reason: "reverse proxy virtual host is created and nginx reloaded", + After: plan.StrPtr(rendered), + }) + case rendered == current: + p.Changes = append(p.Changes, plan.Change{ + Type: "file", ID: name + ".conf", + Actions: []string{plan.ActionNoOp}, + Reason: "virtual host already matches the desired 
configuration",
+		})
+	default:
+		p.Changes = append(p.Changes, plan.Change{
+			Type: "file", ID: name + ".conf",
+			Actions: []string{plan.ActionUpdate},
+			Reason:  "reverse proxy virtual host is regenerated and nginx reloaded",
+			Before:  plan.StrPtr(current),
+			After:   plan.StrPtr(rendered),
+		})
+	}
+	if rendered != "" {
+		p.Changes = append(p.Changes, s.pendingCertificateChanges(deployment)...)
+	}
+
+	s.finishPlan(c, p, nil)
+}
+
+// applyPlannedProxySetup applies a stored "proxy.setup" plan: it reloads
+// the deployment named by the plan's resource ID and asks the proxy
+// orchestrator to set it up.
+//
+// It returns a 404 *apiError when the deployment no longer exists and a
+// 503 *apiError when no proxy orchestrator is configured. An error from
+// SetupDeployment is returned unchanged.
+func applyPlannedProxySetup(s *Server, p *plan.Plan) (gin.H, error) {
+	deployment, err := s.manager.GetDeployment(p.Resource.ID)
+	if err != nil {
+		return nil, apiErrf(http.StatusNotFound, "Deployment not found")
+	}
+	// The domain apply functions treat the orchestrator as optional and
+	// nil-check it. Proxy setup cannot be skipped, so report the missing
+	// orchestrator instead of dereferencing nil.
+	if s.proxyOrchestrator == nil {
+		return nil, apiErrf(http.StatusServiceUnavailable, "Proxy orchestrator is not configured")
+	}
+	result, err := s.proxyOrchestrator.SetupDeployment(deployment)
+	if err != nil {
+		return nil, err
+	}
+	return gin.H{"message": "Proxy setup completed", "result": result}, nil
+}
+
+// planConfigUpdate builds a "config.update" plan describing the change of
+// key to value. It responds with 400 when the key cannot be read from the
+// current configuration.
+func (s *Server) planConfigUpdate(c *gin.Context, key string, value interface{}) {
+	entry, err := config.Get(s.config, key)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	p := s.newPlan("config.update", "config", "_global")
+	if s.configPath != "" {
+		// Absolute so snapshot verification never resolves it against
+		// the deployments dir.
+ if absConfig, aerr := filepath.Abs(s.configPath); aerr == nil { + p.Snapshot.Files = plan.SnapshotFiles(s.config.DeploymentsPath, absConfig) + } + } + + beforeJSON, _ := json.Marshal(entry.Value) + afterJSON, _ := json.Marshal(value) + + _, hasApplier := s.runtimeAppliers()[key] + effect := "saved to the config file; takes effect after the agent restarts" + if hasApplier { + effect = "saved to the config file and applied to the running agent immediately" + } + + if string(beforeJSON) == string(afterJSON) { + p.Changes = append(p.Changes, plan.Change{ + Type: "config", ID: key, + Actions: []string{plan.ActionNoOp}, + Reason: "value is unchanged", + }) + } else { + p.Changes = append(p.Changes, plan.Change{ + Type: "config", ID: key, + Actions: []string{plan.ActionUpdate}, + Reason: effect, + Before: plan.StrPtr(string(beforeJSON)), + After: plan.StrPtr(string(afterJSON)), + Sensitive: entry.Sensitive, + }) + } + + s.finishPlan(c, p, gin.H{"value": value}) +} + +func applyPlannedConfigUpdate(s *Server, p *plan.Plan) (gin.H, error) { + key := normalizeConfigKey(p.Request.Params["key"]) + if key == "" { + return nil, apiErrf(http.StatusBadRequest, "plan is missing the config key") + } + var req struct { + Value interface{} `json:"value"` + } + if err := json.Unmarshal(p.Request.Body, &req); err != nil { + return nil, apiErrf(http.StatusBadRequest, "invalid plan body: %s", err.Error()) + } + outcome, err := s.applyConfigUpdate(key, req.Value) + if err != nil { + return nil, err + } + resp := gin.H{"entry": outcome.Entry, "applied": outcome.Applied} + if outcome.ApplyErr != nil { + resp["apply_error"] = outcome.ApplyErr.Error() + } + return resp, nil +} diff --git a/internal/api/plan_apply_test.go b/internal/api/plan_apply_test.go new file mode 100644 index 0000000..cb0d68e --- /dev/null +++ b/internal/api/plan_apply_test.go @@ -0,0 +1,400 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + 
"strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + + "github.com/flatrun/agent/internal/ai" + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/docker" + "github.com/flatrun/agent/internal/files" + "github.com/flatrun/agent/internal/networks" + "github.com/flatrun/agent/internal/plan" + "github.com/flatrun/agent/internal/proxy" + "github.com/flatrun/agent/pkg/config" + "github.com/flatrun/agent/pkg/models" +) + +func setupPlanTestServer(t *testing.T) (*Server, string, *httptest.Server) { + t.Helper() + gin.SetMode(gin.TestMode) + + tmpDir := t.TempDir() + cfg := &config.Config{ + DeploymentsPath: tmpDir, + Auth: config.AuthConfig{Enabled: false}, + Infrastructure: config.InfrastructureConfig{DefaultProxyNetwork: "proxy"}, + Nginx: config.NginxConfig{ConfigPath: filepath.Join(tmpDir, "nginx", "conf.d")}, + Cleanup: config.CleanupConfig{Timeout: 2 * time.Minute}, + Plans: config.PlansConfig{TTL: time.Hour, RetentionDays: 30}, + } + configPath := filepath.Join(tmpDir, "config.yml") + if err := config.Save(cfg, configPath); err != nil { + t.Fatalf("Failed to save config: %v", err) + } + + s := &Server{ + config: cfg, + configPath: configPath, + router: gin.New(), + manager: docker.NewManager(tmpDir), + networksManager: networks.NewManager(), + authMiddleware: auth.NewMiddleware(&cfg.Auth), + proxyOrchestrator: proxy.NewOrchestrator(cfg), + planStore: plan.NewStore(tmpDir), + aiSessions: ai.NewSessionStore(tmpDir), + filesManager: files.NewManager(tmpDir), + } + s.setupRoutes() + + ts := httptest.NewServer(s.router) + t.Cleanup(ts.Close) + return s, tmpDir, ts +} + +func doJSON(t *testing.T, method, url string, body interface{}) (*http.Response, map[string]interface{}) { + t.Helper() + var reader *bytes.Reader + if body != nil { + raw, err := json.Marshal(body) + if err != nil { + t.Fatal(err) + } + reader = bytes.NewReader(raw) + } else { + reader = bytes.NewReader(nil) + } + req, err := http.NewRequest(method, url, reader) + if 
err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + var parsed map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&parsed) + return resp, parsed +} + +func planFromResponse(t *testing.T, parsed map[string]interface{}) map[string]interface{} { + t.Helper() + p, ok := parsed["plan"].(map[string]interface{}) + if !ok { + t.Fatalf("response has no plan object: %v", parsed) + } + return p +} + +func TestEnvPlanLifecycle(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", &models.ServiceMetadata{Name: "myapp", Type: "web"}) + + envBody := map[string]interface{}{ + "env_vars": []map[string]string{{"key": "DB_HOST", "value": "db.internal"}}, + } + + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/myapp/env?plan=true", envBody) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("plan create status = %d, body %v", resp.StatusCode, parsed) + } + planObj := planFromResponse(t, parsed) + planID := planObj["id"].(string) + + if planObj["status"] != "available" { + t.Errorf("status = %v, want available", planObj["status"]) + } + if planObj["action"] != "deployment.env.update" { + t.Errorf("action = %v", planObj["action"]) + } + + // Plan must not mutate anything. + if _, err := os.Stat(filepath.Join(tmpDir, "myapp", ".env.flatrun")); !os.IsNotExist(err) { + t.Fatal("plan creation wrote the env file") + } + + // Plan file is on disk under the resource directory. + planPath := filepath.Join(tmpDir, ".flatrun", "plans", "deployment", "myapp", planID+".json") + if _, err := os.Stat(planPath); err != nil { + t.Fatalf("plan file missing: %v", err) + } + + // Sensitive contents are redacted in responses... 
+ changes := planObj["changes"].([]interface{}) + first := changes[0].(map[string]interface{}) + if first["after"] != plan.RedactedPlaceholder { + t.Errorf("env diff not redacted: %v", first["after"]) + } + // ...but available with include_sensitive. + resp, parsed = doJSON(t, http.MethodGet, ts.URL+"/api/plans/"+planID+"?include_sensitive=true", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("get include_sensitive status = %d", resp.StatusCode) + } + full := planFromResponse(t, parsed) + fullChange := full["changes"].([]interface{})[0].(map[string]interface{}) + if !strings.Contains(fullChange["after"].(string), "DB_HOST=db.internal") { + t.Errorf("include_sensitive should expose the diff, got %v", fullChange["after"]) + } + + // Apply executes the planned mutation. + resp, parsed = doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+planID+"/apply", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("apply status = %d, body %v", resp.StatusCode, parsed) + } + content, err := os.ReadFile(filepath.Join(tmpDir, "myapp", ".env.flatrun")) + if err != nil || !strings.Contains(string(content), "DB_HOST=db.internal") { + t.Fatalf("env file not written by apply: %v %q", err, content) + } + stored, err := s.planStore.Get(planID) + if err != nil || stored.Status != plan.StatusApplied { + t.Fatalf("stored plan status = %v err %v, want applied", stored, err) + } + if stored.AppliedAt == nil || stored.AppliedBy == nil { + t.Error("applied plan missing applied_at/applied_by") + } + + // A second apply is rejected. 
+ resp, _ = doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+planID+"/apply", nil) + if resp.StatusCode != http.StatusConflict { + t.Errorf("second apply status = %d, want 409", resp.StatusCode) + } +} + +func TestPlanDriftMarksObsolete(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + envBody := map[string]interface{}{"env_vars": []map[string]string{{"key": "A", "value": "1"}}} + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/myapp/env?plan=true", envBody) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("plan create status = %d", resp.StatusCode) + } + planID := planFromResponse(t, parsed)["id"].(string) + + // Out-of-band change to the same file the plan read. + if err := os.WriteFile(filepath.Join(tmpDir, "myapp", ".env.flatrun"), []byte("A=changed\n"), 0600); err != nil { + t.Fatal(err) + } + + resp, parsed = doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+planID+"/apply", nil) + if resp.StatusCode != http.StatusConflict { + t.Fatalf("stale apply status = %d, body %v", resp.StatusCode, parsed) + } + if _, ok := parsed["drifted"]; !ok { + t.Error("stale response missing drifted paths") + } + stored, _ := s.planStore.Get(planID) + if stored.Status != plan.StatusObsolete { + t.Errorf("stored status = %s, want obsolete", stored.Status) + } + + // The env file keeps the out-of-band content; apply must not have run. 
+ content, _ := os.ReadFile(filepath.Join(tmpDir, "myapp", ".env.flatrun")) + if string(content) != "A=changed\n" { + t.Errorf("env file mutated by stale apply: %q", content) + } +} + +func TestExpiredPlanRejected(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + s.config.Plans.TTL = -time.Minute + envBody := map[string]interface{}{"env_vars": []map[string]string{{"key": "A", "value": "1"}}} + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/myapp/env?plan=true", envBody) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("plan create status = %d", resp.StatusCode) + } + s.config.Plans.TTL = time.Hour + planID := planFromResponse(t, parsed)["id"].(string) + + resp, _ = doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+planID+"/apply", nil) + if resp.StatusCode != http.StatusGone { + t.Fatalf("expired apply status = %d, want 410", resp.StatusCode) + } + stored, _ := s.planStore.Get(planID) + if stored.Status != plan.StatusExpired { + t.Errorf("stored status = %s, want expired", stored.Status) + } +} + +func TestApplyObsoletesSiblingPlans(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + mkPlan := func(val string) string { + body := map[string]interface{}{"env_vars": []map[string]string{{"key": "A", "value": val}}} + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/myapp/env?plan=true", body) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("plan create status = %d", resp.StatusCode) + } + return planFromResponse(t, parsed)["id"].(string) + } + first := mkPlan("1") + second := mkPlan("2") + + resp, _ := doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+first+"/apply", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("apply status = %d", resp.StatusCode) + } + + sibling, _ := s.planStore.Get(second) + if sibling.Status != plan.StatusObsolete { + t.Errorf("sibling status = %s, want obsolete", 
sibling.Status) + } +} + +func TestProtectedModeBlocksPlanAndApply(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + + // Plan creation on a protected deployment is blocked up front. + createTestDeployment(t, tmpDir, "locked", &models.ServiceMetadata{ + Name: "locked", + ProtectedMode: &models.ProtectedModeConfig{Enabled: true, BlockedActions: []string{"update_env"}}, + }) + envBody := map[string]interface{}{"env_vars": []map[string]string{{"key": "A", "value": "1"}}} + resp, _ := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/locked/env?plan=true", envBody) + if resp.StatusCode != http.StatusLocked { + t.Fatalf("plan on protected deployment status = %d, want 423", resp.StatusCode) + } + + // A plan created before protection turns on is blocked at apply. + createTestDeployment(t, tmpDir, "open", nil) + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/open/env?plan=true", envBody) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("plan create status = %d", resp.StatusCode) + } + planID := planFromResponse(t, parsed)["id"].(string) + + createTestDeployment(t, tmpDir, "open", &models.ServiceMetadata{ + Name: "open", + ProtectedMode: &models.ProtectedModeConfig{Enabled: true, BlockedActions: []string{"update_env"}}, + }) + resp, _ = doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+planID+"/apply", nil) + if resp.StatusCode != http.StatusLocked { + t.Fatalf("apply on protected deployment status = %d, want 423", resp.StatusCode) + } +} + +func TestConfigPlanLifecycle(t *testing.T) { + s, _, ts := setupPlanTestServer(t) + + body := map[string]interface{}{"value": "5m0s"} + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/config/cleanup.timeout?plan=true", body) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("config plan status = %d, body %v", resp.StatusCode, parsed) + } + planObj := planFromResponse(t, parsed) + planID := planObj["id"].(string) + + if got := planObj["resource"].(map[string]interface{}); got["type"] != 
"config" { + t.Errorf("resource = %v", got) + } + if s.config.Cleanup.Timeout != 2*time.Minute { + t.Fatal("config plan mutated live config") + } + + resp, parsed = doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+planID+"/apply", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("config apply status = %d, body %v", resp.StatusCode, parsed) + } + if s.config.Cleanup.Timeout != 5*time.Minute { + t.Errorf("config not applied, timeout = %v", s.config.Cleanup.Timeout) + } + if parsed["applied"] != true { + t.Errorf("runtime applier should have fired, got %v", parsed["applied"]) + } +} + +func TestComposePlanShowsDiff(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + newCompose := "name: myapp\nservices:\n web:\n image: nginx:1.27\n networks:\n - proxy\nnetworks:\n proxy:\n external: true\n" + body := map[string]interface{}{"compose_content": newCompose} + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/myapp?plan=true", body) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("compose plan status = %d, body %v", resp.StatusCode, parsed) + } + planObj := planFromResponse(t, parsed) + changes := planObj["changes"].([]interface{}) + first := changes[0].(map[string]interface{}) + if first["id"] != "docker-compose.yml" { + t.Errorf("change id = %v", first["id"]) + } + if !strings.Contains(first["after"].(string), "nginx:1.27") { + t.Errorf("compose diff missing new content") + } + // Compose file untouched by planning. 
+ content, _ := os.ReadFile(filepath.Join(tmpDir, "myapp", "docker-compose.yml")) + if strings.Contains(string(content), "nginx:1.27") { + t.Fatal("plan creation wrote the compose file") + } +} + +func TestPlanListAndDiscard(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + envBody := map[string]interface{}{"env_vars": []map[string]string{{"key": "A", "value": "1"}}} + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/myapp/env?plan=true", envBody) + if resp.StatusCode != http.StatusCreated { + t.Fatal("plan create failed") + } + planID := planFromResponse(t, parsed)["id"].(string) + + resp, parsed = doJSON(t, http.MethodGet, ts.URL+"/api/plans?deployment=myapp", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("list status = %d", resp.StatusCode) + } + plans := parsed["plans"].([]interface{}) + if len(plans) != 1 { + t.Fatalf("list returned %d plans, want 1", len(plans)) + } + + resp, _ = doJSON(t, http.MethodDelete, ts.URL+"/api/plans/"+planID, nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("discard status = %d", resp.StatusCode) + } + resp, _ = doJSON(t, http.MethodGet, ts.URL+"/api/plans/"+planID, nil) + if resp.StatusCode != http.StatusNotFound { + t.Errorf("get after discard status = %d, want 404", resp.StatusCode) + } +} + +func TestDeletePlanPreviewsScope(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "victim", &models.ServiceMetadata{Name: "victim", Type: "web"}) + + url := fmt.Sprintf("%s/api/deployments/victim?plan=true&delete_vhost=false", ts.URL) + resp, parsed := doJSON(t, http.MethodDelete, url, nil) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("delete plan status = %d, body %v", resp.StatusCode, parsed) + } + planObj := planFromResponse(t, parsed) + planID := planObj["id"].(string) + + if _, err := os.Stat(filepath.Join(tmpDir, "victim")); err != nil { + t.Fatal("plan creation deleted the deployment") + 
} + + resp, parsed = doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+planID+"/apply", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("delete apply status = %d, body %v", resp.StatusCode, parsed) + } + if _, err := os.Stat(filepath.Join(tmpDir, "victim")); !os.IsNotExist(err) { + t.Fatal("deployment directory still exists after applied delete plan") + } +} diff --git a/internal/api/plan_handlers.go b/internal/api/plan_handlers.go new file mode 100644 index 0000000..da03489 --- /dev/null +++ b/internal/api/plan_handlers.go @@ -0,0 +1,425 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/plan" + "github.com/gin-gonic/gin" +) + +type planAction struct { + Permission auth.Permission + AccessLevel string + ProtectedAction string + SuccessStatus int + Apply func(s *Server, p *plan.Plan) (gin.H, error) +} + +func (s *Server) planRegistry() map[string]planAction { + return map[string]planAction{ + "deployment.env.update": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + ProtectedAction: protectedActionUpdateEnv, + Apply: applyPlannedEnvUpdate, + }, + "deployment.compose.update": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + ProtectedAction: protectedActionUpdateDeployment, + Apply: applyPlannedComposeUpdate, + }, + "deployment.delete": { + Permission: auth.PermDeploymentsDelete, + AccessLevel: auth.AccessLevelAdmin, + ProtectedAction: protectedActionDeleteDeployment, + Apply: applyPlannedDeploymentDelete, + }, + "deployment.domain.add": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + SuccessStatus: http.StatusCreated, + Apply: applyPlannedDomainAdd, + }, + "deployment.domain.update": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + Apply: applyPlannedDomainUpdate, + }, + "deployment.domain.delete": { + Permission: 
auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + Apply: applyPlannedDomainDelete, + }, + "proxy.setup": { + Permission: auth.PermCertificatesWrite, + AccessLevel: auth.AccessLevelWrite, + Apply: applyPlannedProxySetup, + }, + "config.update": { + Permission: auth.PermConfigWrite, + Apply: applyPlannedConfigUpdate, + }, + "deployment.service.start": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + Apply: applyPlannedServiceAction("start"), + }, + "deployment.service.stop": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + Apply: applyPlannedServiceAction("stop"), + }, + "deployment.service.restart": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + Apply: applyPlannedServiceAction("restart"), + }, + "deployment.service.pull": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + Apply: applyPlannedServiceAction("pull"), + }, + "deployment.service.rebuild": { + Permission: auth.PermDeploymentsWrite, + AccessLevel: auth.AccessLevelWrite, + ProtectedAction: protectedActionRebuild, + Apply: applyPlannedServiceAction("rebuild"), + }, + } +} + +func planRequested(c *gin.Context) bool { + return c.Query("plan") == "true" +} + +// requirePlannedAction enforces the deployment's require_plan setting: +// when set, direct execution is refused and the caller must create and +// apply a plan instead. Plan creation itself is always allowed, and +// applies bypass this guard because they run the reviewed plan. 
+func (s *Server) requirePlannedAction(c *gin.Context, name string) bool { + if planRequested(c) { + return true + } + deployment, err := s.manager.GetDeployment(name) + if err != nil || deployment.Metadata == nil || !deployment.Metadata.RequirePlan { + return true + } + c.JSON(http.StatusPreconditionRequired, gin.H{ + "error": "This deployment requires changes to be planned and reviewed before they run", + "code": "plan_required", + }) + return false +} + +func planActorFrom(c *gin.Context) plan.Actor { + actor := auth.GetActorFromContext(c) + if actor == nil { + return plan.Actor{Type: "anonymous"} + } + a := plan.Actor{Type: actor.Type} + switch { + case actor.User != nil: + a.ID = fmt.Sprintf("%d", actor.User.ID) + a.Name = actor.User.Username + case actor.APIKey != nil: + a.ID = actor.APIKey.KeyID + a.Name = actor.APIKey.Name + } + return a +} + +func (s *Server) newPlan(action, resourceType, resourceID string) *plan.Plan { + return plan.New(action, plan.Resource{Type: resourceType, ID: resourceID}, plan.Actor{}, s.config.Plans.TTL) +} + +// finishPlan stamps actor and request context onto the plan, persists +// it, and answers the original mutating request with the preview +// instead of executing it. 
+func (s *Server) finishPlan(c *gin.Context, p *plan.Plan, body interface{}) { + p.CreatedBy = planActorFrom(c) + p.Request.Method = c.Request.Method + p.Request.Path = c.Request.URL.Path + + params := map[string]string{} + for _, prm := range c.Params { + params[prm.Key] = prm.Value + } + if len(params) > 0 { + p.Request.Params = params + } + + query := map[string]string{} + for k, v := range c.Request.URL.Query() { + if k == "plan" || len(v) == 0 { + continue + } + query[k] = v[0] + } + if len(query) > 0 { + p.Request.Query = query + } + + if body != nil { + raw, err := json.Marshal(body) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encode plan request: " + err.Error()}) + return + } + p.Request.Body = raw + } + + p.Summarize() + + if err := s.planStore.Save(p); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save plan: " + err.Error()}) + return + } + + c.JSON(http.StatusCreated, gin.H{"plan": p.Redacted()}) +} + +func (s *Server) canViewPlan(c *gin.Context, p *plan.Plan) bool { + actor := auth.GetActorFromContext(c) + if actor == nil { + return true + } + if p.Resource.Type == "deployment" { + return actor.CanAccessDeployment(p.Resource.ID, auth.AccessLevelRead) + } + return actor.HasPermission(auth.PermConfigRead) +} + +func (s *Server) canManagePlan(c *gin.Context, p *plan.Plan) (planAction, *apiError) { + action, ok := s.planRegistry()[p.Action] + if !ok { + return planAction{}, apiErrf(http.StatusConflict, "plan action %q is not supported by this agent", p.Action) + } + actor := auth.GetActorFromContext(c) + if actor != nil { + if !actor.HasPermission(action.Permission) { + return planAction{}, apiErrf(http.StatusForbidden, "Permission denied: %s required", action.Permission) + } + if action.AccessLevel != "" && p.Resource.Type == "deployment" && !actor.CanAccessDeployment(p.Resource.ID, action.AccessLevel) { + return planAction{}, apiErrf(http.StatusForbidden, "No access to this 
deployment") + } + } + return action, nil +} + +// reverifyPlan lazily transitions an available plan to expired or +// obsolete, so reads always reflect reality even when the change that +// invalidated the plan happened outside the API (e.g. an SSH edit). +func (s *Server) reverifyPlan(p *plan.Plan) *plan.Plan { + if p.Status != plan.StatusAvailable { + return p + } + if p.Expired(time.Now().UTC()) { + p.Status = plan.StatusExpired + _ = s.planStore.Save(p) + return p + } + if err := plan.VerifySnapshot(s.config.DeploymentsPath, p.Snapshot.Files); err != nil { + p.Status = plan.StatusObsolete + _ = s.planStore.Save(p) + } + return p +} + +func (s *Server) listPlans(c *gin.Context) { + filter := plan.ListFilter{ + ResourceType: c.Query("resource_type"), + ResourceID: c.Query("deployment"), + } + if filter.ResourceID != "" && filter.ResourceType == "" { + filter.ResourceType = "deployment" + } + + plans, err := s.planStore.List(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + statusFilter := c.Query("status") + out := make([]*plan.Plan, 0, len(plans)) + for _, p := range plans { + if !s.canViewPlan(c, p) { + continue + } + p = s.reverifyPlan(p) + if statusFilter != "" && p.Status != statusFilter { + continue + } + out = append(out, p.Redacted()) + } + c.JSON(http.StatusOK, gin.H{"plans": out}) +} + +func (s *Server) getPlan(c *gin.Context) { + p, err := s.planStore.Get(c.Param("id")) + if err == plan.ErrNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "Plan not found"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !s.canViewPlan(c, p) { + c.JSON(http.StatusForbidden, gin.H{"error": "No access to this plan"}) + return + } + p = s.reverifyPlan(p) + + if c.Query("include_sensitive") == "true" { + if _, aerr := s.canManagePlan(c, p); aerr != nil { + respondAPIError(c, aerr) + return + } + c.JSON(http.StatusOK, gin.H{"plan": p}) 
+ return + } + c.JSON(http.StatusOK, gin.H{"plan": p.Redacted()}) +} + +func (s *Server) applyPlan(c *gin.Context) { + id := c.Param("id") + p, err := s.planStore.Get(id) + if err == plan.ErrNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "Plan not found"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + unlock := s.planStore.LockResource(p.Resource) + defer unlock() + + // Re-read under the resource lock: a concurrent apply may have + // just transitioned this plan. + p, err = s.planStore.Get(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Plan not found"}) + return + } + + action, aerr := s.canManagePlan(c, p) + if aerr != nil { + respondAPIError(c, aerr) + return + } + + if p.Status == plan.StatusAvailable && p.Expired(time.Now().UTC()) { + p.Status = plan.StatusExpired + _ = s.planStore.Save(p) + } + switch p.Status { + case plan.StatusAvailable: + case plan.StatusExpired: + c.JSON(http.StatusGone, gin.H{"error": "Plan has expired; create a new plan", "plan": p.Redacted()}) + return + default: + c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("Plan is %s and can no longer be applied", p.Status), "plan": p.Redacted()}) + return + } + + if action.ProtectedAction != "" && p.Resource.Type == "deployment" { + blocked, reason, perr := s.protectedDeploymentActionBlocked(p.Resource.ID, action.ProtectedAction) + if perr == nil && blocked { + c.JSON(http.StatusLocked, gin.H{"error": reason}) + return + } + } + + if verr := plan.VerifySnapshot(s.config.DeploymentsPath, p.Snapshot.Files); verr != nil { + p.Status = plan.StatusObsolete + _ = s.planStore.Save(p) + resp := gin.H{"error": "Plan is stale: state changed since it was created", "plan": p.Redacted()} + if drift, ok := verr.(*plan.DriftError); ok { + resp["drifted"] = drift.Paths + } + c.JSON(http.StatusConflict, resp) + return + } + + p.Status = plan.StatusApplying + if err := s.planStore.Save(p); err != nil { + 
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to persist plan state: " + err.Error()}) + return + } + + result, applyErr := action.Apply(s, p) + + now := time.Now().UTC() + actor := planActorFrom(c) + p.AppliedAt = &now + p.AppliedBy = &actor + + if applyErr != nil { + p.Status = plan.StatusFailed + p.ApplyError = applyErr.Error() + _ = s.planStore.Save(p) + status := http.StatusInternalServerError + if ae, ok := applyErr.(*apiError); ok { + status = ae.Status + } + c.JSON(status, gin.H{"error": applyErr.Error(), "plan": p.Redacted()}) + return + } + + p.Status = plan.StatusApplied + _ = s.planStore.Save(p) + + // Other still-open plans on this resource were computed against + // state that no longer exists. + siblings, _ := s.planStore.List(plan.ListFilter{ + ResourceType: p.Resource.Type, + ResourceID: p.Resource.ID, + Status: plan.StatusAvailable, + }) + for _, sib := range siblings { + if sib.ID == p.ID { + continue + } + sib.Status = plan.StatusObsolete + _ = s.planStore.Save(sib) + } + + resp := gin.H{"plan": p.Redacted()} + for k, v := range result { + resp[k] = v + } + status := action.SuccessStatus + if status == 0 { + status = http.StatusOK + } + c.JSON(status, resp) +} + +func (s *Server) deletePlan(c *gin.Context) { + p, err := s.planStore.Get(c.Param("id")) + if err == plan.ErrNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "Plan not found"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if _, aerr := s.canManagePlan(c, p); aerr != nil { + respondAPIError(c, aerr) + return + } + if err := s.planStore.Delete(p.ID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "Plan discarded", "id": p.ID}) +} diff --git a/internal/api/require_plan_test.go b/internal/api/require_plan_test.go new file mode 100644 index 0000000..b893554 --- /dev/null +++ b/internal/api/require_plan_test.go 
@@ -0,0 +1,127 @@ +package api + +import ( + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/flatrun/agent/pkg/models" +) + +func TestRequirePlanBlocksDirectMutation(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "guarded", &models.ServiceMetadata{ + Name: "guarded", + RequirePlan: true, + }) + + envBody := map[string]interface{}{"env_vars": []map[string]string{{"key": "A", "value": "1"}}} + + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/guarded/env", envBody) + if resp.StatusCode != http.StatusPreconditionRequired { + t.Fatalf("direct mutation status = %d, want 428", resp.StatusCode) + } + if parsed["code"] != "plan_required" { + t.Errorf("code = %v, want plan_required", parsed["code"]) + } + if _, err := os.Stat(filepath.Join(tmpDir, "guarded", ".env.flatrun")); !os.IsNotExist(err) { + t.Fatal("blocked mutation still wrote the env file") + } + + // The plan path stays open: plan, then apply. 
+ resp, parsed = doJSON(t, http.MethodPut, ts.URL+"/api/deployments/guarded/env?plan=true", envBody) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("plan create status = %d, body %v", resp.StatusCode, parsed) + } + planID := planFromResponse(t, parsed)["id"].(string) + + resp, parsed = doJSON(t, http.MethodPost, ts.URL+"/api/plans/"+planID+"/apply", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("apply status = %d, body %v", resp.StatusCode, parsed) + } + content, err := os.ReadFile(filepath.Join(tmpDir, "guarded", ".env.flatrun")) + if err != nil || !strings.Contains(string(content), "A=1") { + t.Fatalf("apply did not write the env file: %v %q", err, content) + } +} + +func TestRequirePlanOffByDefault(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "open", nil) + + envBody := map[string]interface{}{"env_vars": []map[string]string{{"key": "A", "value": "1"}}} + resp, _ := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/open/env", envBody) + if resp.StatusCode != http.StatusOK { + t.Fatalf("direct mutation status = %d, want 200 when require_plan is off", resp.StatusCode) + } +} + +func TestRequirePlanToggleViaMetadata(t *testing.T) { + s, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "app", &models.ServiceMetadata{Name: "app", Type: "web"}) + + resp, parsed := doJSON(t, http.MethodPut, ts.URL+"/api/deployments/app/metadata", + map[string]interface{}{"require_plan": true}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("metadata update status = %d, body %v", resp.StatusCode, parsed) + } + + deployment, err := s.manager.GetDeployment("app") + if err != nil || deployment.Metadata == nil || !deployment.Metadata.RequirePlan { + t.Fatalf("require_plan not persisted: %+v err %v", deployment, err) + } + + // And the guard is live immediately. 
+ envBody := map[string]interface{}{"env_vars": []map[string]string{{"key": "A", "value": "1"}}} + resp, _ = doJSON(t, http.MethodPut, ts.URL+"/api/deployments/app/env", envBody) + if resp.StatusCode != http.StatusPreconditionRequired { + t.Fatalf("status = %d, want 428 after enabling require_plan", resp.StatusCode) + } +} + +func TestServiceActionPlan(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/deployments/myapp/services/web/restart?plan=true", nil) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("service plan status = %d, body %v", resp.StatusCode, parsed) + } + planObj := planFromResponse(t, parsed) + if planObj["action"] != "deployment.service.restart" { + t.Errorf("action = %v", planObj["action"]) + } + change := planObj["changes"].([]interface{})[0].(map[string]interface{}) + if change["type"] != "service" || change["id"] != "web" { + t.Errorf("change = %v", change) + } + actions := change["actions"].([]interface{}) + if len(actions) != 2 || actions[0] != "delete" || actions[1] != "create" { + t.Errorf("restart should be a replace pair, got %v", actions) + } +} + +func TestServiceActionUnknownService(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "myapp", nil) + + resp, _ := doJSON(t, http.MethodPost, ts.URL+"/api/deployments/myapp/services/nope/start", nil) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("unknown service status = %d, want 400", resp.StatusCode) + } +} + +func TestServiceActionRespectsRequirePlan(t *testing.T) { + _, tmpDir, ts := setupPlanTestServer(t) + createTestDeployment(t, tmpDir, "guarded", &models.ServiceMetadata{ + Name: "guarded", + RequirePlan: true, + }) + + resp, parsed := doJSON(t, http.MethodPost, ts.URL+"/api/deployments/guarded/services/web/stop", nil) + if resp.StatusCode != http.StatusPreconditionRequired { + t.Fatalf("status = %d, want 428, 
body %v", resp.StatusCode, parsed) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 020899e..400b5da 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -20,6 +20,7 @@ import ( "github.com/compose-spec/compose-go/v2/loader" composetypes "github.com/compose-spec/compose-go/v2/types" + "github.com/flatrun/agent/internal/ai" "github.com/flatrun/agent/internal/audit" "github.com/flatrun/agent/internal/auth" "github.com/flatrun/agent/internal/backup" @@ -32,6 +33,7 @@ import ( "github.com/flatrun/agent/internal/files" "github.com/flatrun/agent/internal/infra" "github.com/flatrun/agent/internal/networks" + "github.com/flatrun/agent/internal/plan" "github.com/flatrun/agent/internal/proxy" "github.com/flatrun/agent/internal/scheduler" "github.com/flatrun/agent/internal/security" @@ -81,6 +83,9 @@ type Server struct { setupManager *setup.Manager setupHandlers *setup.Handlers certRenewer *ssl.Renewer + planStore *plan.Store + aiProvider ai.Provider + aiSessions *ai.SessionStore statsMu sync.RWMutex statsCache gin.H @@ -265,6 +270,16 @@ func New(cfg *config.Config, configPath string) *Server { clusterManager: clusterManager, setupManager: setupManager, setupHandlers: setup.NewHandlers(setupManager, authManager), + planStore: plan.NewStore(cfg.DeploymentsPath), + aiSessions: ai.NewSessionStore(cfg.DeploymentsPath), + } + + s.planStore.StartPruneLoop(context.Background(), time.Hour, time.Duration(cfg.Plans.RetentionDays)*24*time.Hour) + + if provider, aiErr := ai.New(&cfg.AI); aiErr == nil { + s.aiProvider = provider + } else if aiErr != ai.ErrDisabled { + log.Printf("Warning: failed to initialize AI provider: %v", aiErr) } if backupManager != nil { @@ -380,6 +395,31 @@ func (s *Server) setupRoutes() { protected.GET("/config/*key", s.authMiddleware.RequirePermission(auth.PermConfigRead), s.getConfigKey) protected.PUT("/config/*key", s.authMiddleware.RequirePermission(auth.PermConfigWrite), s.updateConfigKey) + // Plan endpoints 
(previewed mutations; see internal/plan) + protected.GET("/plans", s.listPlans) + protected.GET("/plans/:id", s.getPlan) + protected.POST("/plans/:id/apply", s.applyPlan) + protected.DELETE("/plans/:id", s.deletePlan) + + // Service-level actions + protected.POST("/deployments/:name/services/:service/start", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.serviceActionHandler("start")) + protected.POST("/deployments/:name/services/:service/stop", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.serviceActionHandler("stop")) + protected.POST("/deployments/:name/services/:service/restart", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.serviceActionHandler("restart")) + protected.POST("/deployments/:name/services/:service/rebuild", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.serviceActionHandler("rebuild")) + protected.POST("/deployments/:name/services/:service/pull", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.serviceActionHandler("pull")) + + // AI assist endpoints + protected.GET("/ai/status", s.getAIStatus) + protected.POST("/ai/analyze", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.aiAssistSystem) + protected.POST("/deployments/:name/ai/analyze", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.aiAssistDeployment) + + // Interactive AI sessions (agentic tool loop) + protected.POST("/ai/sessions", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.createAISession) + protected.GET("/ai/sessions/:id", s.getAISession) + 
protected.POST("/ai/sessions/:id/messages", s.postAISessionMessage) + protected.POST("/ai/sessions/:id/approve", s.approveAISessionTools) + protected.DELETE("/ai/sessions/:id", s.deleteAISession) + // Compose, stats, subdomain (deployment-scoped) protected.GET("/subdomain/generate", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.generateSubdomain) protected.POST("/compose/update", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.updateCompose) @@ -1437,6 +1477,14 @@ func (s *Server) updateDeploymentEnv(c *gin.Context) { return } + if planRequested(c) { + s.planEnvUpdate(c, name, req.EnvVars) + return + } + if !s.requirePlannedAction(c, name) { + return + } + if err := s.writeEnvFile(name, req.EnvVars); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -1502,6 +1550,14 @@ func (s *Server) updateDeployment(c *gin.Context) { return } + if planRequested(c) { + s.planComposeUpdate(c, name, req.ComposeContent) + return + } + if !s.requirePlannedAction(c, name) { + return + } + if err := s.manager.UpdateDeployment(name, req.ComposeContent); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), @@ -1542,7 +1598,9 @@ func (s *Server) updateDeploymentMetadata(c *gin.Context) { return } - if _, ok := sentFields["protected_mode"]; ok { + _, sentProtectedMode := sentFields["protected_mode"] + _, sentRequirePlan := sentFields["require_plan"] + if sentProtectedMode || sentRequirePlan { if !s.requireDeploymentAccess(c, name, auth.AccessLevelAdmin) { return } @@ -1629,6 +1687,9 @@ func mergeMetadata(existing, incoming *models.ServiceMetadata, sentFields map[st if _, ok := sentFields["protected_mode"]; ok { merged.ProtectedMode = incoming.ProtectedMode } + if _, ok := sentFields["require_plan"]; ok { + merged.RequirePlan = incoming.RequirePlan + } return &merged } @@ -1682,58 +1743,22 @@ func (s *Server) deleteDeployment(c *gin.Context) { return } - deleteSSL := 
c.DefaultQuery("delete_ssl", "true") == "true" - deleteDatabase := c.DefaultQuery("delete_database", "false") == "true" - deleteVhost := c.DefaultQuery("delete_vhost", "true") == "true" - - deployment, _ := s.manager.GetDeployment(name) - - var deletedItems []string - - if deleteVhost { - if err := s.proxyOrchestrator.TeardownDeployment(name); err != nil { - log.Printf("Warning: failed to teardown proxy for %s: %v", name, err) - } else { - deletedItems = append(deletedItems, "virtual_host") - } + opts := deploymentDeleteOptions{ + DeleteSSL: c.DefaultQuery("delete_ssl", "true") == "true", + DeleteDatabase: c.DefaultQuery("delete_database", "false") == "true", + DeleteVhost: c.DefaultQuery("delete_vhost", "true") == "true", } - if deployment != nil && deployment.Metadata != nil && deleteSSL { - domainsToDelete := deployment.Metadata.GetUniqueDomainNames() - if len(domainsToDelete) == 0 && deployment.Metadata.Networking.Domain != "" { - domainsToDelete = []string{deployment.Metadata.Networking.Domain} - } - - for _, domain := range domainsToDelete { - if err := s.proxyOrchestrator.SSLManager().DeleteCertificate(domain); err != nil { - log.Printf("Warning: failed to delete SSL certificate for %s: %v", domain, err) - } else { - deletedItems = append(deletedItems, fmt.Sprintf("ssl_certificate:%s", domain)) - } - } + if planRequested(c) { + s.planDeploymentDelete(c, name, opts) + return } - - if deleteDatabase && s.config.Infrastructure.Database.Enabled { - if deployment != nil && deployment.Metadata != nil && len(deployment.Metadata.Databases) > 0 { - for _, dbConfig := range deployment.Metadata.Databases { - if dbConfig.IsShared { - if err := s.deleteDatabaseByAlias(name, dbConfig.Alias); err != nil { - log.Printf("Warning: failed to delete database %s for %s: %v", dbConfig.Alias, name, err) - } else { - deletedItems = append(deletedItems, fmt.Sprintf("database:%s", dbConfig.Alias)) - } - } - } - } else { - if err := s.deleteDatabaseForDeployment(name); err != nil { - 
log.Printf("Warning: failed to delete database for %s: %v", name, err) - } else { - deletedItems = append(deletedItems, "database") - } - } + if !s.requirePlannedAction(c, name) { + return } - if err := s.manager.DeleteDeployment(name); err != nil { + deletedItems, err := s.applyDeploymentDelete(name, opts) + if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), }) @@ -4195,6 +4220,14 @@ func (s *Server) setupProxy(c *gin.Context) { return } + if planRequested(c) { + s.planProxySetup(c, deployment) + return + } + if !s.requirePlannedAction(c, name) { + return + } + result, err := s.proxyOrchestrator.SetupDeployment(deployment) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -4470,56 +4503,22 @@ func (s *Server) addDomain(c *gin.Context) { return } - if domain.Domain == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "Domain is required"}) + if planRequested(c) { + s.planDomainChange(c, deployment, "deployment.domain.add", &domain, func(dep *models.Deployment) (bool, error) { + return false, s.mutateDomainAdd(dep, &domain) + }) return } - resolved, err := s.resolveService(name, domain.Service) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + if !s.requirePlannedAction(c, name) { return } - domain.Service = resolved - if domain.ID == "" { - domain.ID = generateDomainID() - } - - if deployment.Metadata == nil { - deployment.Metadata = &models.ServiceMetadata{} - } - - if len(deployment.Metadata.Domains) == 0 && deployment.Metadata.Networking.Expose { - existingService := deployment.Metadata.Networking.Service - if existingService == "" { - existingService = resolved - } - existingDomain := models.DomainConfig{ - ID: "default", - Service: existingService, - ContainerPort: deployment.Metadata.Networking.ContainerPort, - Domain: deployment.Metadata.Networking.Domain, - SSL: deployment.Metadata.SSL, - } - deployment.Metadata.Domains = []models.DomainConfig{existingDomain} - } - - for _, 
existing := range deployment.Metadata.Domains { - if existing.Domain == domain.Domain && existing.PathPrefix == domain.PathPrefix { - c.JSON(http.StatusConflict, gin.H{ - "error": fmt.Sprintf("Domain %s%s already exists", domain.Domain, domain.PathPrefix), - }) - return - } - } - - if domain.ContainerPort == 0 && deployment.Metadata.Networking.ContainerPort != 0 { - domain.ContainerPort = deployment.Metadata.Networking.ContainerPort + if err := s.mutateDomainAdd(deployment, &domain); err != nil { + respondAPIError(c, err) + return } - deployment.Metadata.Domains = append(deployment.Metadata.Domains, domain) - if err := s.manager.SaveMetadata(name, deployment.Metadata); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save domain: " + err.Error()}) return @@ -4552,41 +4551,25 @@ func (s *Server) updateDomain(c *gin.Context) { return } - if deployment.Metadata == nil || len(deployment.Metadata.Domains) == 0 { - c.JSON(http.StatusNotFound, gin.H{"error": "Domain not found"}) - return - } - var updatedDomain models.DomainConfig if err := c.ShouldBindJSON(&updatedDomain); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid domain data: " + err.Error()}) return } - if updatedDomain.Service != "" { - resolved, err := s.resolveService(name, updatedDomain.Service) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - updatedDomain.Service = resolved + if planRequested(c) { + s.planDomainChange(c, deployment, "deployment.domain.update", &updatedDomain, func(dep *models.Deployment) (bool, error) { + return false, s.mutateDomainUpdate(dep, domainID, &updatedDomain) + }) + return } - found := false - for i, d := range deployment.Metadata.Domains { - if d.ID == domainID { - updatedDomain.ID = domainID - if updatedDomain.Service == "" { - updatedDomain.Service = d.Service - } - deployment.Metadata.Domains[i] = updatedDomain - found = true - break - } + if !s.requirePlannedAction(c, name) { + 
return } - if !found { - c.JSON(http.StatusNotFound, gin.H{"error": "Domain not found"}) + if err := s.mutateDomainUpdate(deployment, domainID, &updatedDomain); err != nil { + respondAPIError(c, err) return } @@ -4618,79 +4601,32 @@ func (s *Server) deleteDomain(c *gin.Context) { return } - if deployment.Metadata == nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Domain not found"}) - return - } - - // Handle legacy "default" domain from networking config - if domainID == "default" && len(deployment.Metadata.Domains) == 0 { - if !deployment.Metadata.Networking.Expose || deployment.Metadata.Networking.Domain == "" { - c.JSON(http.StatusNotFound, gin.H{"error": "Domain not found"}) - return - } - // Clear the legacy networking config - deployment.Metadata.Networking.Expose = false - deployment.Metadata.Networking.Domain = "" - deployment.Metadata.SSL.Enabled = false - deployment.Metadata.SSL.AutoCert = false - - if err := s.manager.SaveMetadata(name, deployment.Metadata); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save metadata: " + err.Error()}) - return - } - - if s.proxyOrchestrator != nil { - if err := s.proxyOrchestrator.TeardownDeployment(name); err != nil { - log.Printf("Warning: failed to teardown proxy for %s: %v", name, err) - } - } - - c.JSON(http.StatusOK, gin.H{"message": "Domain deleted successfully"}) + if planRequested(c) { + s.planDomainChange(c, deployment, "deployment.domain.delete", nil, func(dep *models.Deployment) (bool, error) { + return mutateDomainDelete(dep, domainID) + }) return } - if len(deployment.Metadata.Domains) == 0 { - c.JSON(http.StatusNotFound, gin.H{"error": "Domain not found"}) + if !s.requirePlannedAction(c, name) { return } - found := false - newDomains := make([]models.DomainConfig, 0) - for _, d := range deployment.Metadata.Domains { - if d.ID == domainID { - found = true - continue - } - newDomains = append(newDomains, d) - } - - if !found { - c.JSON(http.StatusNotFound, 
// serviceActionSpec describes one per-service action (start, stop, restart,
// pull, rebuild): how it is named in plans, what it reports on success, and
// how it is executed.
type serviceActionSpec struct {
	// verb is the action name; it matches the key used in serviceActionSpecs.
	verb string
	// planAction is the action identifier handed to newPlan when the action
	// is previewed instead of executed.
	planAction string
	// protectedAction, when non-empty, is passed to
	// requireUnprotectedDeploymentAction before the handler does anything else.
	protectedAction string
	// message is the success message returned in the JSON response.
	message string
	// changeActions and changeReason are copied into the plan.Change that a
	// preview records for the service.
	changeActions []string
	changeReason string
	// run performs the action and returns the output string reported to the caller.
	run func(s *Server, name, service string) (string, error)
}
compose file if missing", + run: func(s *Server, name, service string) (string, error) { + auth, opts := s.deploymentAuthOptions(name) + defer auth.Close() + return s.manager.StartService(name, service, opts...) + }, + }, + "stop": { + verb: "stop", + planAction: "deployment.service.stop", + message: "Service stopped", + changeActions: []string{plan.ActionDelete}, + changeReason: "containers for this service are stopped; data and configuration are untouched", + run: func(s *Server, name, service string) (string, error) { + return s.manager.StopService(name, service) + }, + }, + "restart": { + verb: "restart", + planAction: "deployment.service.restart", + message: "Service restarted", + changeActions: []string{plan.ActionDelete, plan.ActionCreate}, + changeReason: "containers for this service are stopped and started again", + run: func(s *Server, name, service string) (string, error) { + auth, opts := s.deploymentAuthOptions(name) + defer auth.Close() + return s.manager.RestartService(name, service, opts...) + }, + }, + "pull": { + verb: "pull", + planAction: "deployment.service.pull", + message: "Service image pulled", + changeActions: []string{plan.ActionUpdate}, + changeReason: "latest image for this service is pulled from the registry; running containers are unchanged until the next deploy", + run: func(s *Server, name, service string) (string, error) { + auth, opts := s.deploymentAuthOptions(name) + defer auth.Close() + return s.manager.PullService(name, service, opts...) 
+ }, + }, + "rebuild": { + verb: "rebuild", + planAction: "deployment.service.rebuild", + protectedAction: protectedActionRebuild, + message: "Service rebuilt", + changeActions: []string{plan.ActionDelete, plan.ActionCreate}, + changeReason: "service image is rebuilt and its containers are recreated from the current compose file", + run: func(s *Server, name, service string) (string, error) { + auth, opts := s.deploymentAuthOptions(name) + defer auth.Close() + return s.manager.RebuildService(name, service, opts...) + }, + }, + } +} + +func (s *Server) serviceActionHandler(verb string) gin.HandlerFunc { + spec := serviceActionSpecs()[verb] + return func(c *gin.Context) { + name := c.Param("name") + if spec.protectedAction != "" && !s.requireUnprotectedDeploymentAction(c, name, spec.protectedAction) { + return + } + + service, err := s.resolveService(name, c.Param("service")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if planRequested(c) { + s.planServiceAction(c, name, service, spec) + return + } + if !s.requirePlannedAction(c, name) { + return + } + + output, err := spec.run(s, name, service) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error(), "output": output}) + return + } + c.JSON(http.StatusOK, gin.H{ + "message": spec.message, + "name": name, + "service": service, + "output": output, + }) + } +} + +func (s *Server) planServiceAction(c *gin.Context, name, service string, spec serviceActionSpec) { + p := s.newPlan(spec.planAction, "deployment", name) + if _, composeName, err := s.manager.GetComposeFile(name); err == nil { + p.Snapshot.Files = plan.SnapshotFiles(s.config.DeploymentsPath, filepath.Join(name, composeName)) + } + p.Changes = append(p.Changes, plan.Change{ + Type: "service", ID: service, + Actions: spec.changeActions, + Reason: spec.changeReason, + }) + s.finishPlan(c, p, nil) +} + +func applyPlannedServiceAction(verb string) func(*Server, *plan.Plan) (gin.H, 
error) { + return func(s *Server, p *plan.Plan) (gin.H, error) { + spec := serviceActionSpecs()[verb] + service := p.Request.Params["service"] + if service == "" { + return nil, apiErrf(http.StatusBadRequest, "plan is missing the service name") + } + resolved, err := s.resolveService(p.Resource.ID, service) + if err != nil { + return nil, apiErrf(http.StatusBadRequest, "%s", err.Error()) + } + output, err := spec.run(s, p.Resource.ID, resolved) + if err != nil { + return nil, err + } + return gin.H{ + "message": spec.message, + "name": p.Resource.ID, + "service": resolved, + "output": output, + }, nil + } +} diff --git a/internal/docker/compose.go b/internal/docker/compose.go index 2b6682a..ccec648 100644 --- a/internal/docker/compose.go +++ b/internal/docker/compose.go @@ -63,6 +63,30 @@ func (c *ComposeExecutor) Rebuild(deploymentPath string, opts ...RunOption) (str return c.runCompose(deploymentPath, opts, "up", "-d", "--build", "--remove-orphans") } +func (c *ComposeExecutor) StartService(deploymentPath, service string, opts ...RunOption) (string, error) { + output, err := c.runCompose(deploymentPath, opts, "start", service) + if err != nil { + return c.runCompose(deploymentPath, opts, "up", "-d", "--no-deps", service) + } + return output, nil +} + +func (c *ComposeExecutor) StopService(deploymentPath, service string, opts ...RunOption) (string, error) { + return c.runCompose(deploymentPath, opts, "stop", service) +} + +func (c *ComposeExecutor) RestartService(deploymentPath, service string, opts ...RunOption) (string, error) { + return c.runCompose(deploymentPath, opts, "restart", service) +} + +func (c *ComposeExecutor) RebuildService(deploymentPath, service string, opts ...RunOption) (string, error) { + return c.runCompose(deploymentPath, opts, "up", "-d", "--no-deps", "--build", "--force-recreate", service) +} + +func (c *ComposeExecutor) PullService(deploymentPath, service string, opts ...RunOption) (string, error) { + return c.runCompose(deploymentPath, 
opts, "pull", "--ignore-buildable", "--policy", "always", service) +} + func (c *ComposeExecutor) Logs(deploymentPath string, tail int) (string, error) { tailStr := fmt.Sprintf("%d", tail) return c.runCompose(deploymentPath, nil, "logs", "--tail", tailStr) diff --git a/internal/docker/manager.go b/internal/docker/manager.go index e0bd9c0..aeaa7ce 100644 --- a/internal/docker/manager.go +++ b/internal/docker/manager.go @@ -412,6 +412,68 @@ func (m *Manager) RebuildDeployment(name string, opts ...RunOption) (string, err return output, nil } +func (m *Manager) StartService(name, service string, opts ...RunOption) (string, error) { + m.mu.RLock() + deployment, err := m.discovery.GetDeployment(name) + m.mu.RUnlock() + + if err != nil { + return "", err + } + + m.ensureContainerNames(name) + return m.executor.StartService(deployment.Path, service, opts...) +} + +func (m *Manager) StopService(name, service string, opts ...RunOption) (string, error) { + m.mu.RLock() + deployment, err := m.discovery.GetDeployment(name) + m.mu.RUnlock() + + if err != nil { + return "", err + } + + return m.executor.StopService(deployment.Path, service, opts...) +} + +func (m *Manager) RestartService(name, service string, opts ...RunOption) (string, error) { + m.mu.RLock() + deployment, err := m.discovery.GetDeployment(name) + m.mu.RUnlock() + + if err != nil { + return "", err + } + + return m.executor.RestartService(deployment.Path, service, opts...) +} + +func (m *Manager) RebuildService(name, service string, opts ...RunOption) (string, error) { + m.mu.RLock() + deployment, err := m.discovery.GetDeployment(name) + m.mu.RUnlock() + + if err != nil { + return "", err + } + + m.ensureContainerNames(name) + return m.executor.RebuildService(deployment.Path, service, opts...) 
// absentMarker is the pseudo-hash recorded for a file that could not be read.
const absentMarker = "absent"

// DriftError reports the snapshotted paths whose current hash differs from
// the hash recorded in the snapshot.
type DriftError struct {
	Paths []string
}

// Error lists the drifted paths, comma-separated.
func (e *DriftError) Error() string {
	joined := strings.Join(e.Paths, ", ")
	return fmt.Sprintf("state changed since plan was created: %s", joined)
}

// hashFile returns "sha256:<hex digest>" of the file's content. Any read
// error, not only a missing file, yields absentMarker.
func hashFile(path string) string {
	content, readErr := os.ReadFile(path)
	if readErr != nil {
		return absentMarker
	}
	digest := sha256.Sum256(content)
	return "sha256:" + hex.EncodeToString(digest[:])
}

// resolveSnapshotPath returns p unchanged when it is absolute and joins it
// onto base otherwise.
func resolveSnapshotPath(base, p string) string {
	if filepath.IsAbs(p) {
		return p
	}
	return filepath.Join(base, p)
}

// SnapshotFiles hashes each of the given paths and returns the hashes keyed
// by the path exactly as it was passed in. Relative paths are read relative
// to base; absolute paths are read as-is. A file that cannot be read is
// recorded as "absent", so its later creation shows up as a change.
func SnapshotFiles(base string, paths ...string) map[string]string {
	hashes := make(map[string]string, len(paths))
	for i := range paths {
		hashes[paths[i]] = hashFile(resolveSnapshotPath(base, paths[i]))
	}
	return hashes
}

// VerifySnapshot hashes every path in snapshot again and compares it with the
// recorded value. It returns nil when all hashes match, and otherwise a
// *DriftError whose Paths are the mismatching keys in sorted order.
func VerifySnapshot(base string, snapshot map[string]string) error {
	var changed []string
	for key, recorded := range snapshot {
		if current := hashFile(resolveSnapshotPath(base, key)); current != recorded {
			changed = append(changed, key)
		}
	}
	if len(changed) == 0 {
		return nil
	}
	// Map iteration order is random; sort so the error is deterministic.
	sort.Strings(changed)
	return &DriftError{Paths: changed}
}
`json:"params,omitempty"` + Query map[string]string `json:"query,omitempty"` + Body json.RawMessage `json:"body,omitempty"` +} + +type Change struct { + Type string `json:"type"` + ID string `json:"id"` + Actions []string `json:"actions"` + Reason string `json:"reason"` + Before *string `json:"before"` + After *string `json:"after"` + Sensitive bool `json:"sensitive"` +} + +type Snapshot struct { + Files map[string]string `json:"files"` +} + +type Summary struct { + Create int `json:"create"` + Update int `json:"update"` + Replace int `json:"replace"` + Delete int `json:"delete"` + NoOp int `json:"no-op"` +} + +type Plan struct { + FormatVersion int `json:"format_version"` + ID string `json:"id"` + Action string `json:"action"` + Status string `json:"status"` + Resource Resource `json:"resource"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + CreatedBy Actor `json:"created_by"` + AppliedAt *time.Time `json:"applied_at,omitempty"` + AppliedBy *Actor `json:"applied_by,omitempty"` + ApplyError string `json:"apply_error,omitempty"` + Request RequestEnvelope `json:"request"` + Snapshot Snapshot `json:"snapshot"` + Changes []Change `json:"changes"` + Summary Summary `json:"summary"` +} + +func New(action string, resource Resource, actor Actor, ttl time.Duration) *Plan { + now := time.Now().UTC() + return &Plan{ + FormatVersion: FormatVersion, + ID: "pln_" + uuid.New().String(), + Action: action, + Status: StatusAvailable, + Resource: resource, + CreatedAt: now, + ExpiresAt: now.Add(ttl), + CreatedBy: actor, + Snapshot: Snapshot{Files: map[string]string{}}, + } +} + +func StrPtr(s string) *string { + return &s +} + +func (p *Plan) Expired(now time.Time) bool { + return now.After(p.ExpiresAt) +} + +// Summarize recomputes Summary from Changes. An ordered action pair +// (delete+create or create+delete) counts as one replace. 
+func (p *Plan) Summarize() { + s := Summary{} + for _, ch := range p.Changes { + switch { + case len(ch.Actions) == 2: + s.Replace++ + case len(ch.Actions) == 1 && ch.Actions[0] == ActionCreate: + s.Create++ + case len(ch.Actions) == 1 && ch.Actions[0] == ActionUpdate: + s.Update++ + case len(ch.Actions) == 1 && ch.Actions[0] == ActionDelete: + s.Delete++ + default: + s.NoOp++ + } + } + p.Summary = s +} + +// Redacted returns a copy safe for API responses: sensitive change +// contents are masked. The on-disk plan keeps full values (0600, same +// trust domain as .env.flatrun). +func (p *Plan) Redacted() *Plan { + cp := *p + cp.Changes = make([]Change, len(p.Changes)) + hasSensitive := false + for i, ch := range p.Changes { + if ch.Sensitive { + hasSensitive = true + if ch.Before != nil { + ch.Before = StrPtr(RedactedPlaceholder) + } + if ch.After != nil { + ch.After = StrPtr(RedactedPlaceholder) + } + } + cp.Changes[i] = ch + } + if hasSensitive && p.Request.Body != nil { + cp.Request.Body = json.RawMessage(`"` + RedactedPlaceholder + `"`) + } + return &cp +} diff --git a/internal/plan/plan_test.go b/internal/plan/plan_test.go new file mode 100644 index 0000000..514c796 --- /dev/null +++ b/internal/plan/plan_test.go @@ -0,0 +1,209 @@ +package plan + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func newTestPlan(t *testing.T) *Plan { + t.Helper() + return New("deployment.env.update", + Resource{Type: "deployment", ID: "myapp"}, + Actor{ID: "u1", Name: "admin", Type: "user"}, + time.Hour) +} + +func TestNewPlanDefaults(t *testing.T) { + p := newTestPlan(t) + if p.FormatVersion != FormatVersion { + t.Errorf("format_version = %d, want %d", p.FormatVersion, FormatVersion) + } + if p.Status != StatusAvailable { + t.Errorf("status = %q, want available", p.Status) + } + if !strings.HasPrefix(p.ID, "pln_") { + t.Errorf("id %q missing pln_ prefix", p.ID) + } + if !p.ExpiresAt.After(p.CreatedAt) { + t.Error("expires_at not after 
created_at") + } +} + +func TestSummarize(t *testing.T) { + p := newTestPlan(t) + p.Changes = []Change{ + {Actions: []string{ActionCreate}}, + {Actions: []string{ActionUpdate}}, + {Actions: []string{ActionUpdate}}, + {Actions: []string{ActionDelete}}, + {Actions: []string{ActionDelete, ActionCreate}}, + {Actions: []string{ActionNoOp}}, + } + p.Summarize() + want := Summary{Create: 1, Update: 2, Delete: 1, Replace: 1, NoOp: 1} + if p.Summary != want { + t.Errorf("summary = %+v, want %+v", p.Summary, want) + } +} + +func TestRedacted(t *testing.T) { + p := newTestPlan(t) + p.Request.Body = json.RawMessage(`{"env_vars":[{"key":"SECRET","value":"hunter2"}]}`) + p.Changes = []Change{ + {ID: ".env.flatrun", Actions: []string{ActionUpdate}, Before: StrPtr("SECRET=old"), After: StrPtr("SECRET=hunter2"), Sensitive: true}, + {ID: "web", Actions: []string{ActionDelete, ActionCreate}, Reason: "recreate"}, + } + r := p.Redacted() + if *r.Changes[0].Before != RedactedPlaceholder || *r.Changes[0].After != RedactedPlaceholder { + t.Error("sensitive change not redacted") + } + if strings.Contains(string(r.Request.Body), "hunter2") { + t.Error("request body not redacted when plan has sensitive changes") + } + if *p.Changes[0].After != "SECRET=hunter2" { + t.Error("original plan mutated by Redacted") + } + if r.Changes[1].Before != nil { + t.Error("non-sensitive change should be untouched") + } +} + +func TestSnapshotAndVerify(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "a.txt"), []byte("one"), 0600); err != nil { + t.Fatal(err) + } + + snap := SnapshotFiles(dir, "a.txt", "missing.txt") + if snap["missing.txt"] != "absent" { + t.Errorf("missing file hash = %q, want absent", snap["missing.txt"]) + } + if !strings.HasPrefix(snap["a.txt"], "sha256:") { + t.Errorf("hash %q missing sha256 prefix", snap["a.txt"]) + } + + if err := VerifySnapshot(dir, snap); err != nil { + t.Fatalf("unchanged snapshot should verify, got %v", err) + } + + if err := 
os.WriteFile(filepath.Join(dir, "a.txt"), []byte("two"), 0600); err != nil { + t.Fatal(err) + } + err := VerifySnapshot(dir, snap) + drift, ok := err.(*DriftError) + if !ok { + t.Fatalf("want *DriftError, got %T (%v)", err, err) + } + if len(drift.Paths) != 1 || drift.Paths[0] != "a.txt" { + t.Errorf("drift paths = %v, want [a.txt]", drift.Paths) + } + + // Creating a previously absent file is drift too. + if err := os.WriteFile(filepath.Join(dir, "missing.txt"), []byte("x"), 0600); err != nil { + t.Fatal(err) + } + if VerifySnapshot(dir, snap) == nil { + t.Error("created file should count as drift") + } +} + +func TestStoreRoundTrip(t *testing.T) { + dir := t.TempDir() + store := NewStore(dir) + + p := newTestPlan(t) + p.Changes = []Change{{Type: "file", ID: ".env.flatrun", Actions: []string{ActionUpdate}}} + if err := store.Save(p); err != nil { + t.Fatal(err) + } + + onDisk := filepath.Join(store.Root(), "deployment", "myapp", p.ID+".json") + info, err := os.Stat(onDisk) + if err != nil { + t.Fatalf("plan file not at expected path: %v", err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("plan file mode = %v, want 0600", info.Mode().Perm()) + } + + got, err := store.Get(p.ID) + if err != nil { + t.Fatal(err) + } + if got.Action != p.Action || got.Resource != p.Resource { + t.Errorf("round trip mismatch: %+v", got) + } + + plans, err := store.List(ListFilter{ResourceType: "deployment", ResourceID: "myapp"}) + if err != nil { + t.Fatal(err) + } + if len(plans) != 1 { + t.Fatalf("list returned %d plans, want 1", len(plans)) + } + + if err := store.Delete(p.ID); err != nil { + t.Fatal(err) + } + if _, err := store.Get(p.ID); err != ErrNotFound { + t.Errorf("after delete, Get err = %v, want ErrNotFound", err) + } +} + +func TestStoreRejectsBadIDs(t *testing.T) { + store := NewStore(t.TempDir()) + if _, err := store.Get("../../etc/passwd"); err != ErrNotFound { + t.Errorf("traversal id should be ErrNotFound, got %v", err) + } + p := newTestPlan(t) + p.ID = 
"pln_not-a-uuid" + if err := store.Save(p); err == nil { + t.Error("Save should reject malformed plan id") + } +} + +func TestPruneOnce(t *testing.T) { + dir := t.TempDir() + store := NewStore(dir) + now := time.Now().UTC() + + expired := newTestPlan(t) + expired.ExpiresAt = now.Add(-time.Minute) + + oldApplied := newTestPlan(t) + oldApplied.Status = StatusApplied + oldApplied.CreatedAt = now.Add(-31 * 24 * time.Hour) + + oldObsolete := newTestPlan(t) + oldObsolete.Status = StatusObsolete + oldObsolete.CreatedAt = now.Add(-8 * 24 * time.Hour) + + fresh := newTestPlan(t) + + for _, p := range []*Plan{expired, oldApplied, oldObsolete, fresh} { + if err := store.Save(p); err != nil { + t.Fatal(err) + } + } + + store.PruneOnce(now, 30*24*time.Hour) + + got, err := store.Get(expired.ID) + if err != nil || got.Status != StatusExpired { + t.Errorf("ttl-passed plan: status %v err %v, want expired", got, err) + } + if _, err := store.Get(oldApplied.ID); err != ErrNotFound { + t.Errorf("old applied plan should be deleted, got %v", err) + } + if _, err := store.Get(oldObsolete.ID); err != ErrNotFound { + t.Errorf("old obsolete plan should be deleted, got %v", err) + } + got, err = store.Get(fresh.ID) + if err != nil || got.Status != StatusAvailable { + t.Errorf("fresh plan should be untouched, got %v err %v", got, err) + } +} diff --git a/internal/plan/prune.go b/internal/plan/prune.go new file mode 100644 index 0000000..36e8b95 --- /dev/null +++ b/internal/plan/prune.go @@ -0,0 +1,60 @@ +package plan + +import ( + "context" + "log" + "time" +) + +const terminalRetention = 7 * 24 * time.Hour + +// PruneOnce expires available plans past their TTL, deletes applied and +// failed plans older than retention, and deletes expired and obsolete +// plans older than a week. Returns how many files were touched. 
+func (s *Store) PruneOnce(now time.Time, retention time.Duration) int {
+	plans, err := s.List(ListFilter{})
+	if err != nil {
+		log.Printf("Warning: plan prune failed to list plans: %v", err)
+		return 0
+	}
+	touched := 0
+	for _, p := range plans {
+		age := now.Sub(p.CreatedAt)
+		switch {
+		case p.Status == StatusAvailable && p.Expired(now):
+			p.Status = StatusExpired
+			err = s.Save(p)
+		case (p.Status == StatusApplied || p.Status == StatusFailed) && age > retention,
+			(p.Status == StatusExpired || p.Status == StatusObsolete) && age > terminalRetention:
+			err = s.Delete(p.ID)
+		default:
+			continue
+		}
+		if err != nil {
+			log.Printf("Warning: plan prune failed for plan %s: %v", p.ID, err)
+			continue
+		}
+		touched++
+	}
+	return touched
+}
+
+// StartPruneLoop runs PruneOnce every interval in a background goroutine
+// until ctx is cancelled. A non-positive interval starts nothing.
+func (s *Store) StartPruneLoop(ctx context.Context, interval, retention time.Duration) {
+	if interval <= 0 {
+		return
+	}
+	go func() {
+		ticker := time.NewTicker(interval)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-ticker.C:
+				s.PruneOnce(time.Now().UTC(), retention)
+			}
+		}
+	}()
+}
diff --git a/internal/plan/store.go b/internal/plan/store.go
new file mode 100644
index 0000000..432ff0d
--- /dev/null
+++ b/internal/plan/store.go
@@ -0,0 +1,268 @@
+package plan
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"regexp"
+	"sort"
+	"strings"
+	"sync"
+)
+
+// planIDPattern matches the only accepted plan id shape: "pln_" followed
+// by a lowercase hexadecimal UUID. Ids are checked against it before they
+// are used to build or match a file name.
+var planIDPattern = regexp.MustCompile(`^pln_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
+
+// ErrNotFound is returned when no plan file exists for the requested id,
+// including ids that do not match planIDPattern.
+var ErrNotFound = errors.New("plan not found")
+
+// Store persists plans as JSON files laid out as
+// <root>/<resource type>/<resource id>/<plan id>.json and hands out
+// per-resource locks.
+type Store struct {
+	root string
+
+	mu    sync.Mutex             // guards locks
+	locks map[string]*sync.Mutex // one mutex per resource key, created on first use
+}
+
+// NewStore returns a Store rooted at <deploymentsPath>/.flatrun/plans.
+// NewStore itself creates nothing on disk.
+func NewStore(deploymentsPath string) *Store {
+	return &Store{
+		root:  filepath.Join(deploymentsPath, ".flatrun", "plans"),
+		locks: map[string]*sync.Mutex{},
+	}
+}
+
+// Root returns the directory under which plan files are stored.
+func (s *Store) Root() string {
+	return s.root
+}
+
+// LockResource serializes applies on the same resource. Returns the
+// unlock function.
+func (s *Store) LockResource(r Resource) func() {
+	key := r.Type + "/" + r.ID
+	s.mu.Lock()
+	l, ok := s.locks[key]
+	if !ok {
+		l = &sync.Mutex{}
+		s.locks[key] = l
+	}
+	s.mu.Unlock()
+	// Take the resource lock only after releasing s.mu, so waiting on one
+	// resource never stalls callers locking a different one.
+	l.Lock()
+	return l.Unlock
+}
+
+// sanitizeSegment maps seg to a string that is safe to use as a single
+// path component: every rune outside [A-Za-z0-9._-] becomes '_', and the
+// results "", "." and ".." are replaced by "_". Distinct inputs can map to
+// the same segment.
+func sanitizeSegment(seg string) string {
+	var b strings.Builder
+	for _, r := range seg {
+		switch {
+		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '-', r == '_', r == '.':
+			b.WriteRune(r)
+		default:
+			b.WriteRune('_')
+		}
+	}
+	out := b.String()
+	if out == "" || out == "." || out == ".." {
+		out = "_"
+	}
+	return out
+}
+
+// planPath returns the file a plan is stored in, derived from its
+// sanitized resource type and id plus the plan id.
+func (s *Store) planPath(p *Plan) string {
+	return filepath.Join(s.root, sanitizeSegment(p.Resource.Type), sanitizeSegment(p.Resource.ID), p.ID+".json")
+}
+
+// Save writes the plan atomically: the JSON goes to a uniquely named
+// temporary file next to the destination, is fsynced, and is then renamed
+// into place, so a crash never leaves a torn plan file behind. The
+// temporary file is removed if any step fails. Plans whose id does not
+// match planIDPattern are rejected.
+func (s *Store) Save(p *Plan) error {
+	if !planIDPattern.MatchString(p.ID) {
+		return fmt.Errorf("invalid plan id %q", p.ID)
+	}
+	path := s.planPath(p)
+	dir := filepath.Dir(path)
+	if err := os.MkdirAll(dir, 0700); err != nil {
+		return err
+	}
+	data, err := json.MarshalIndent(p, "", " ")
+	if err != nil {
+		return err
+	}
+	// os.CreateTemp picks an unused name and creates the file with mode
+	// 0600, so overlapping saves of the same plan never share a temporary
+	// file. The ".tmp" suffix keeps List and findPath from matching it.
+	f, err := os.CreateTemp(dir, p.ID+".*.tmp")
+	if err != nil {
+		return err
+	}
+	tmp := f.Name()
+	if err := writeSyncClose(f, data); err != nil {
+		os.Remove(tmp)
+		return err
+	}
+	if err := os.Rename(tmp, path); err != nil {
+		os.Remove(tmp)
+		return err
+	}
+	return nil
+}
+
+// writeSyncClose writes data to f, flushes it to stable storage and closes
+// f. The file is closed on every path and the first error is returned.
+func writeSyncClose(f *os.File, data []byte) error {
+	if _, err := f.Write(data); err != nil {
+		f.Close()
+		return err
+	}
+	if err := f.Sync(); err != nil {
+		f.Close()
+		return err
+	}
+	return f.Close()
+}
+
+// findPath walks the store for <id>.json and returns its path. It returns
+// ErrNotFound for ids that do not match planIDPattern, when the store
+// directory does not exist yet, and when no such file is present.
+func (s *Store) findPath(id string) (string, error) {
+	if !planIDPattern.MatchString(id) {
+		return "", ErrNotFound
+	}
+	var found string
+	err := filepath.WalkDir(s.root, func(path string, d os.DirEntry, err error) error {
+		if err != nil {
+			// A missing root or a directory that vanished mid-walk only
+			// means there is nothing to find there; keep walking.
+			if errors.Is(err, os.ErrNotExist) {
+				return nil
+			}
+			return err
+		}
+		if !d.IsDir() && d.Name() == id+".json" {
+			found = path
+			return filepath.SkipAll
+		}
+		return nil
+	})
+	if err != nil {
+		return "", err
+	}
+	if found == "" {
+		return "", ErrNotFound
+	}
+	return found, nil
+}
+
+// Get loads the plan with the given id. It returns ErrNotFound when no
+// such plan exists, including when the file disappears between lookup and
+// read.
+func (s *Store) Get(id string) (*Plan, error) {
+	path, err := s.findPath(id)
+	if err != nil {
+		return nil, err
+	}
+	p, err := readPlanFile(path)
+	if errors.Is(err, os.ErrNotExist) {
+		return nil, ErrNotFound
+	}
+	return p, err
+}
+
+// readPlanFile reads and decodes one plan file. Decode failures are
+// reported as "corrupt plan file <name>" wrapping the JSON error.
+func readPlanFile(path string) (*Plan, error) {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil, err
+	}
+	var p Plan
+	if err := json.Unmarshal(data, &p); err != nil {
+		return nil, fmt.Errorf("corrupt plan file %s: %w", filepath.Base(path), err)
+	}
+	return &p, nil
+}
+
+// ListFilter narrows List results. An empty field matches every plan.
+type ListFilter struct {
+	ResourceType string
+	ResourceID   string
+	Status       string
+}
+
+// List returns every stored plan matching filter, newest first by
+// CreatedAt. Files that cannot be read or decoded are skipped so one bad
+// file does not hide the rest; a missing store directory yields an empty
+// result.
+func (s *Store) List(filter ListFilter) ([]*Plan, error) {
+	var plans []*Plan
+	err := filepath.WalkDir(s.root, func(path string, d os.DirEntry, err error) error {
+		if err != nil {
+			// A missing root or a directory that vanished mid-walk just
+			// contributes no plans.
+			if errors.Is(err, os.ErrNotExist) {
+				return nil
+			}
+			return err
+		}
+		if d.IsDir() || !strings.HasSuffix(d.Name(), ".json") {
+			return nil
+		}
+		p, readErr := readPlanFile(path)
+		if readErr != nil {
+			// Best effort: skip unreadable or corrupt files.
+			return nil
+		}
+		if filter.ResourceType != "" && p.Resource.Type != filter.ResourceType {
+			return nil
+		}
+		if filter.ResourceID != "" && p.Resource.ID != filter.ResourceID {
+			return nil
+		}
+		if filter.Status != "" && p.Status != filter.Status {
+			return nil
+		}
+		plans = append(plans, p)
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	sort.Slice(plans, func(i, j int) bool { return plans[i].CreatedAt.After(plans[j].CreatedAt) })
+	return plans, nil
+}
+
+// Delete removes the plan file for id. It returns ErrNotFound when no such
+// plan exists, including when the file is already gone by the time it is
+// removed.
+func (s *Store) Delete(id string) error {
+	path, err := s.findPath(id)
+	if err != nil {
+		return err
+	}
+	if err := os.Remove(path); err != nil {
+		if errors.Is(err, os.ErrNotExist) {
+			return ErrNotFound
+		}
+		return err
+	}
+	return nil
+}
diff --git a/internal/proxy/orchestrator.go b/internal/proxy/orchestrator.go
index a765ae4..91b3deb 100644
--- a/internal/proxy/orchestrator.go
+++ b/internal/proxy/orchestrator.go
@@ -42,6 +42,35 @@ func (o *Orchestrator) UpdateConfig(cfg *config.Config) {
 	o.ssl.UpdateConfig(&cfg.Certbot, cfg.DeploymentsPath)
 }
 
+// RenderDeployment returns the
virtual host content SetupDeployment +// would install, applying the same SSL downgrade for missing +// certificates, without writing or reloading anything. Returns "" when +// the deployment is not exposed. +func (o *Orchestrator) RenderDeployment(deployment *models.Deployment) (string, error) { + if deployment.Metadata == nil { + return "", nil + } + domains := deployment.Metadata.GetDomains() + if len(domains) == 0 { + return "", nil + } + + rendered := make([]models.DomainConfig, len(domains)) + copy(rendered, domains) + for i := range rendered { + if rendered[i].SSL.Enabled && rendered[i].SSL.AutoCert && !o.ssl.CertificateExists(rendered[i].Domain) { + rendered[i].SSL.Enabled = false + } + } + + metaCopy := *deployment.Metadata + metaCopy.Domains = rendered + deploymentCopy := *deployment + deploymentCopy.Metadata = &metaCopy + + return o.nginx.RenderVirtualHost(&deploymentCopy) +} + func (o *Orchestrator) SetupDeployment(deployment *models.Deployment) (*SetupResult, error) { result := &SetupResult{ DeploymentName: deployment.Name, diff --git a/pkg/config/config.go b/pkg/config/config.go index 199bb52..e31a0f6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -39,6 +39,22 @@ type Config struct { Cluster ClusterConfig `yaml:"cluster"` SystemTerminal SystemTerminalConfig `yaml:"system_terminal"` Cleanup CleanupConfig `yaml:"cleanup"` + Plans PlansConfig `yaml:"plans"` + AI AIConfig `yaml:"ai"` +} + +type AIConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + BaseURL string `yaml:"base_url" json:"base_url"` + APIKey string `yaml:"api_key" json:"api_key"` + Model string `yaml:"model" json:"model"` + Timeout time.Duration `yaml:"timeout" json:"timeout"` + DocsURL string `yaml:"docs_url" json:"docs_url"` +} + +type PlansConfig struct { + TTL time.Duration `yaml:"ttl" json:"ttl"` + RetentionDays int `yaml:"retention_days" json:"retention_days"` } type DomainConfig struct { @@ -393,6 +409,23 @@ func setDefaults(cfg *Config) { if 
cfg.Audit.SensitiveFields == nil { cfg.Audit.SensitiveFields = []string{"password", "token", "secret", "api_key", "authorization"} } + // AI defaults + if cfg.AI.BaseURL == "" { + cfg.AI.BaseURL = "https://api.openai.com/v1" + } + if cfg.AI.Timeout == 0 { + cfg.AI.Timeout = 60 * time.Second + } + if cfg.AI.DocsURL == "" { + cfg.AI.DocsURL = "https://flatrun.dev/docs/" + } + // Plans defaults + if cfg.Plans.TTL == 0 { + cfg.Plans.TTL = 24 * time.Hour + } + if cfg.Plans.RetentionDays == 0 { + cfg.Plans.RetentionDays = 30 + } // Cluster defaults if cfg.Cluster.ServerName == "" { hostname, err := os.Hostname() diff --git a/pkg/config/registry.go b/pkg/config/registry.go index 0f28556..c36b487 100644 --- a/pkg/config/registry.go +++ b/pkg/config/registry.go @@ -24,6 +24,16 @@ var hiddenKeys = map[string]bool{ "auth.api_keys": true, } +// sensitiveKeys are masked on read like hiddenKeys but stay writable +// through the API, so credentials can be configured without ever being +// echoed back. 
+var sensitiveKeys = map[string]bool{ + "ai.api_key": true, + "infrastructure.database.root_password": true, + "infrastructure.redis.password": true, + "infrastructure.powerdns.api_key": true, +} + func Walk(cfg *Config) []Entry { defaults := &Config{} setDefaults(defaults) @@ -33,13 +43,14 @@ func Walk(cfg *Config) []Entry { out := make([]Entry, 0, len(current)) for _, e := range current { - if hiddenKeys[e.Key] { - e.Sensitive = true - e.Value = nil - } if d, ok := defaultMap[e.Key]; ok { e.Default = d } + if hiddenKeys[e.Key] || sensitiveKeys[e.Key] { + e.Sensitive = true + e.Value = nil + e.Default = nil + } out = append(out, e) } sort.Slice(out, func(i, j int) bool { return out[i].Key < out[j].Key }) diff --git a/pkg/models/deployment.go b/pkg/models/deployment.go index a6aea39..45ba1de 100644 --- a/pkg/models/deployment.go +++ b/pkg/models/deployment.go @@ -33,6 +33,7 @@ type ServiceMetadata struct { Security *DeploymentSecurityConfig `yaml:"security,omitempty" json:"security,omitempty"` Backup *BackupSpec `yaml:"backup,omitempty" json:"backup,omitempty"` ProtectedMode *ProtectedModeConfig `yaml:"protected_mode,omitempty" json:"protected_mode,omitempty"` + RequirePlan bool `yaml:"require_plan,omitempty" json:"require_plan,omitempty"` CredentialID string `yaml:"credential_id,omitempty" json:"credential_id,omitempty"` ServiceCredentials map[string]string `yaml:"service_credentials,omitempty" json:"service_credentials,omitempty"` Domains []DomainConfig `yaml:"domains,omitempty" json:"domains,omitempty"`