From b96332b95a54f3cfbc9ab1b6a734dd44e2847e1a Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 27 Mar 2026 14:10:03 +0800 Subject: [PATCH 1/2] refactor: simplify startup TUI and fix config zero-value handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Startup: reduce interactive setup from 5 form groups to zero interaction. `crust start` defaults to auto mode; `--manual` prompts for endpoint+key. DB encryption key, telemetry, retention, port, and disable-builtin are now CLI-flag-only (already existed as --db-key, --telemetry, etc.). Config: use rawConfig with *int/*float64 for YAML parsing so Load() can distinguish absent fields (nil → use default) from explicit zeros (non-nil → copy as-is, let Validate reject invalid values). Fixes stale config files with max_buffer_events: 0 silently breaking the daemon. --- README.md | 4 +- docs/cli.md | 22 ++-- docs/tui.md | 6 +- internal/config/config.go | 102 ++++++++++++++++- internal/config/config_test.go | 60 +++++++++- internal/tui/startup/common.go | 118 ++++--------------- internal/tui/startup/startup.go | 159 +++----------------------- internal/tui/startup/startup_notui.go | 11 +- main.go | 47 +++++--- 9 files changed, 236 insertions(+), 293 deletions(-) diff --git a/README.md b/README.md index 602ff44a..e2f29271 100644 --- a/README.md +++ b/README.md @@ -89,10 +89,10 @@ docker run -p 9090:9090 crust Then start the gateway: ```bash -crust start --auto +crust start ``` -Auto mode detects your LLM provider from the model name — no endpoint URL or API key configuration needed. Your agent's existing auth is passed through. +Auto mode is the default — it detects your LLM provider from the model name with zero configuration. Your agent's existing auth is passed through. ## Agent Setup diff --git a/docs/cli.md b/docs/cli.md index 67b09926..90954243 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -4,10 +4,11 @@ ```bash # Gateway -crust start --auto # Auto mode (recommended) -crust start --endpoint URL --api-key KEY # Manual mode -crust start --auto --block-mode replace # Show block messages to agent -crust start --foreground --auto # Foreground mode (for Docker) +crust start # Auto mode (default, zero interaction) +crust start --manual # Prompt for endpoint URL + API key +crust start --endpoint URL --api-key KEY # Manual mode via flags +crust start --block-mode replace # Show block messages to agent +crust start --foreground # Foreground mode (for Docker) crust stop # Stop the gateway crust status [--json] [--live] # Check if running crust status --live --api-addr HOST:PORT # Remote dashboard (Docker) @@ -38,7 +39,8 @@ crust uninstall # Complete removal | Flag | Description | |------|-------------| -| `--auto` | Resolve providers from model names | +| `--auto` | Resolve providers from model names (default when no flags given) | +| `--manual` | Prompt interactively for endpoint URL and API key | | `--endpoint URL` | LLM API endpoint URL | | `--api-key KEY` | API key (prefer `LLM_API_KEY` env var) | | `--foreground` | Run in foreground (for Docker/containers) | @@ -108,17 +110,17 @@ Works with or without the daemon running. When the daemon is running, `crust sta ## Examples ```bash -# Interactive setup +# Auto mode (default, zero interaction) crust start -# Auto mode with env-based auth -crust start --auto +# Manual mode — interactive prompt for endpoint + API key +crust start --manual -# Manual mode with explicit endpoint +# Manual mode with explicit flags LLM_API_KEY=sk-xxx crust start --endpoint https://openrouter.ai/api/v1 # Docker/container mode -crust start --foreground --auto --listen-address 0.0.0.0 +crust start --foreground --listen-address 0.0.0.0 # Follow logs crust logs -f diff --git a/docs/tui.md b/docs/tui.md index 058af802..100c7707 100644 --- a/docs/tui.md +++ b/docs/tui.md @@ -34,7 +34,7 @@ internal/tui/ columns.go AlignColumns() for ANSI-aware two-column alignment banner/ Gradient ASCII art banner with reveal animation + RevealLines spinner/ Animated dot spinner with success glow effect - startup/ Interactive huh-based setup wizard with themed forms + startup/ Manual endpoint setup (huh form for --manual mode) terminal/ Terminal emulator detection and capability bitfield progress/ Determinate progress bar for multi-step operations dashboard/ Live status dashboard with auto-refreshing metrics + stats tab @@ -238,8 +238,8 @@ docker run -d -t -p 9090:9090 crust # Production: plain text logs (no -t) docker run -d -p 9090:9090 crust -# Interactive setup inside container -docker run -it --entrypoint crust crust start --foreground +# Interactive setup inside container (manual mode) +docker run -it --entrypoint crust crust start --foreground --manual # View logs (styled with -t, plain without) docker logs diff --git a/internal/config/config.go b/internal/config/config.go index 8f1bef04..c6f9fb8f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -131,7 +131,7 @@ type SecurityConfig struct { BlockMode types.BlockMode `yaml:"block_mode"` // "remove" (default) or "replace" (substitute with a text warning block) } -// Validate validates the SecurityConfig and sets defaults. +// Validate validates the SecurityConfig and sets defaults for nil fields. func (c *SecurityConfig) Validate() error { // Validate and default BlockMode if c.BlockMode == types.BlockModeUnset { @@ -282,6 +282,93 @@ func (c *Config) Validate() error { return errors.New(sb.String()) } +// rawConfig mirrors Config but uses pointers for fields where zero is +// invalid so YAML can distinguish "absent" (nil → use default) from +// "explicitly set to 0" (non-nil → copy, let Validate() reject). +type rawConfig struct { + Server struct { + Port *int `yaml:"port"` + LogLevel types.LogLevel `yaml:"log_level"` + NoColor bool `yaml:"no_color"` + } `yaml:"server"` + Upstream UpstreamConfig `yaml:"upstream"` + Storage StorageConfig `yaml:"storage"` + API APIConfig `yaml:"api"` + Telemetry struct { + Enabled bool `yaml:"enabled"` + RetentionDays int `yaml:"retention_days"` + ServiceName string `yaml:"service_name"` + SampleRate *float64 `yaml:"sample_rate"` + } `yaml:"telemetry"` + Security struct { + Enabled bool `yaml:"enabled"` + BufferStreaming bool `yaml:"buffer_streaming"` + MaxBufferEvents *int `yaml:"max_buffer_events"` + BufferTimeout *int `yaml:"buffer_timeout"` + BlockMode types.BlockMode `yaml:"block_mode"` + } `yaml:"security"` + Rules RulesConfig `yaml:"rules"` +} + +// applyTo merges parsed YAML onto defaults. +// nil = absent in YAML → keep default. non-nil = explicitly set → copy. +func (r *rawConfig) applyTo(dst *Config) { + if r.Server.Port != nil { + dst.Server.Port = *r.Server.Port + } + if r.Server.LogLevel != "" { + dst.Server.LogLevel = r.Server.LogLevel + } + dst.Server.NoColor = r.Server.NoColor + + if r.Upstream.URL != "" { + dst.Upstream.URL = r.Upstream.URL + } + dst.Upstream.Timeout = r.Upstream.Timeout + if len(r.Upstream.Providers) > 0 { + dst.Upstream.Providers = r.Upstream.Providers + } + + if r.Storage.DBPath != "" { + dst.Storage.DBPath = r.Storage.DBPath + } + if r.Storage.EncryptionKey != "" { + dst.Storage.EncryptionKey = r.Storage.EncryptionKey + } + + if r.API.SocketPath != "" { + dst.API.SocketPath = r.API.SocketPath + } + + dst.Telemetry.Enabled = r.Telemetry.Enabled + dst.Telemetry.RetentionDays = r.Telemetry.RetentionDays + if r.Telemetry.ServiceName != "" { + dst.Telemetry.ServiceName = r.Telemetry.ServiceName + } + if r.Telemetry.SampleRate != nil { + dst.Telemetry.SampleRate = *r.Telemetry.SampleRate + } + + dst.Security.Enabled = r.Security.Enabled + dst.Security.BufferStreaming = r.Security.BufferStreaming + if r.Security.MaxBufferEvents != nil { + dst.Security.MaxBufferEvents = *r.Security.MaxBufferEvents + } + if r.Security.BufferTimeout != nil { + dst.Security.BufferTimeout = *r.Security.BufferTimeout + } + if r.Security.BlockMode != types.BlockModeUnset { + dst.Security.BlockMode = r.Security.BlockMode + } + + dst.Rules.Enabled = r.Rules.Enabled + if r.Rules.UserDir != "" { + dst.Rules.UserDir = r.Rules.UserDir + } + dst.Rules.DisableBuiltin = r.Rules.DisableBuiltin + dst.Rules.Watch = r.Rules.Watch +} + // isUnknownFieldError returns true if the error is from yaml.Decoder.KnownFields(true) // detecting an unrecognized key (e.g. typo like "servr:"). func isUnknownFieldError(err error) bool { @@ -302,15 +389,15 @@ func Load(path string) (*Config, error) { return nil, err } - // Try strict decode to warn about unknown fields (typos like "servr:") + // Parse into rawConfig (pointer fields detect absent vs explicit zero). + var raw rawConfig dec := yaml.NewDecoder(bytes.NewReader(data)) dec.KnownFields(true) - if err := dec.Decode(cfg); err != nil { + if err := dec.Decode(&raw); err != nil { if isUnknownFieldError(err) { cfgLog.Warn("config has unknown fields (ignored): %v", err) - // Re-parse without strict mode for forward compatibility - cfg = DefaultConfig() - if err2 := yaml.Unmarshal(data, cfg); err2 != nil { + raw = rawConfig{} + if err2 := yaml.Unmarshal(data, &raw); err2 != nil { return nil, fmt.Errorf("config parse error: %w", err2) } } else { @@ -318,6 +405,9 @@ func Load(path string) (*Config, error) { } } + // Merge: nil → keep default, non-nil → copy (Validate catches bad values). + raw.applyTo(cfg) + // Expand environment variables in provider API keys. // Collect referenced env var names so the daemon can propagate them. for name, prov := range cfg.Upstream.Providers { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index d85cc646..be56bc5a 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -92,7 +92,6 @@ func TestValidate_PortRange(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "server.port") { t.Errorf("port 99999 should fail: %v", err) } - } func TestValidate_LogLevel(t *testing.T) { @@ -207,6 +206,63 @@ func TestValidate_RetentionDays(t *testing.T) { } } +func TestLoad_AbsentFieldsUseDefaults(t *testing.T) { + // Fields absent from YAML should keep defaults (rawConfig pointers are nil). + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + data := []byte("security:\n enabled: true\n buffer_streaming: true\n block_mode: replace\n") + if err := os.WriteFile(cfgPath, data, 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(cfgPath) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + // Absent → defaults preserved + if cfg.Server.Port != 9090 { + t.Errorf("Port = %d, want default 9090", cfg.Server.Port) + } + if cfg.Security.MaxBufferEvents != 50000 { + t.Errorf("MaxBufferEvents = %d, want default 50000", cfg.Security.MaxBufferEvents) + } + if cfg.Security.BufferTimeout != 120 { + t.Errorf("BufferTimeout = %d, want default 120", cfg.Security.BufferTimeout) + } + if cfg.Telemetry.SampleRate != 1.0 { + t.Errorf("SampleRate = %g, want default 1.0", cfg.Telemetry.SampleRate) + } +} + +func TestLoad_ExplicitZeroPreserved(t *testing.T) { + // Explicit zero in YAML must be preserved (not replaced with default) + // so that Validate() can reject it. + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + data := []byte("server:\n port: 0\nsecurity:\n buffer_streaming: true\n max_buffer_events: 0\n buffer_timeout: 0\n") + if err := os.WriteFile(cfgPath, data, 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(cfgPath) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if cfg.Server.Port != 0 { + t.Errorf("Port = %d, want explicit 0", cfg.Server.Port) + } + if cfg.Security.MaxBufferEvents != 0 { + t.Errorf("MaxBufferEvents = %d, want explicit 0", cfg.Security.MaxBufferEvents) + } + if cfg.Security.BufferTimeout != 0 { + t.Errorf("BufferTimeout = %d, want explicit 0", cfg.Security.BufferTimeout) + } + // Validate should reject these + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for explicit zeros") + } +} + func TestValidate_ProviderURL(t *testing.T) { cfg := DefaultConfig() cfg.Upstream.Providers = map[string]ProviderConfig{ @@ -244,7 +300,7 @@ func TestValidate_BlockMode(t *testing.T) { func TestValidate_MultipleErrors(t *testing.T) { cfg := DefaultConfig() - cfg.Server.Port = 0 + cfg.Server.Port = -1 // use -1 since mergeConfig would skip 0 cfg.Server.LogLevel = types.LogLevel("invalid") cfg.Upstream.Timeout = -1 cfg.Telemetry.SampleRate = 5.0 diff --git a/internal/tui/startup/common.go b/internal/tui/startup/common.go index 842d0d26..c649295e 100644 --- a/internal/tui/startup/common.go +++ b/internal/tui/startup/common.go @@ -6,16 +6,14 @@ import ( "fmt" "net/url" "os" - "strconv" "strings" "golang.org/x/term" - "github.com/BakeLens/crust/internal/rules" "github.com/BakeLens/crust/internal/tui" ) -// Config holds the configuration collected from the startup prompts +// Config holds the configuration collected from startup. type Config struct { // Mode AutoMode bool // auto mode: resolve provider from model name (per-provider keys or client auth) @@ -34,7 +32,7 @@ type Config struct { Canceled bool } -// Validate validates the startup configuration +// Validate validates the startup configuration. func (c *Config) Validate() error { if !c.AutoMode { if c.EndpointURL == "" { @@ -59,7 +57,7 @@ func (c *Config) Validate() error { return nil } -// ValidationErrors returns human-readable validation errors +// ValidationErrors returns human-readable validation errors. func (c *Config) ValidationErrors() []string { err := c.Validate() if err == nil { @@ -68,14 +66,6 @@ func (c *Config) ValidationErrors() []string { return []string{err.Error()} } -// DefaultProxyPort should match config.DefaultConfig -const DefaultProxyPort = 9090 - -// RunStartup runs the startup prompts and returns the configuration -func RunStartup(defaultEndpoint string) (Config, error) { - return RunStartupWithPort(defaultEndpoint, DefaultProxyPort) -} - // readPassword reads a password from the terminal without echoing. func readPassword() (string, error) { fd := int(os.Stdin.Fd()) //nolint:gosec // Fd() fits in int on all supported platforms @@ -96,104 +86,38 @@ func readPassword() (string, error) { return strings.TrimSpace(password), nil } -// runStartupReader runs plain text prompts using bufio.Reader. +// runManualReader prompts for endpoint URL and API key using plain text. // Used as fallback when plain mode is active (piped, NO_COLOR, etc.) // and as the sole implementation in notui builds. -func runStartupReader(defaultEndpoint string, defaultProxyPort int) (Config, error) { +func runManualReader(defaultEndpoint string) (Config, error) { reader := bufio.NewReader(os.Stdin) - config := Config{ - ProxyPort: defaultProxyPort, - RetentionDays: 7, - } + var config Config - fmt.Println(tui.Separator("Configuration")) + fmt.Println(tui.Separator("Manual Endpoint")) fmt.Println() prompt := ">" - fmt.Printf(" %s Use auto mode? (resolve provider from model name, per-provider keys or client auth) [Y/n]: ", prompt) - modeAnswer, _ := reader.ReadString('\n') - modeAnswer = strings.TrimSpace(strings.ToLower(modeAnswer)) - - if modeAnswer == "" || modeAnswer == "y" || modeAnswer == "yes" { //nolint:goconst - config.AutoMode = true - fmt.Println() - tui.PrintInfo("Auto mode enabled") - tui.PrintInfo("Providers will be resolved from model names") - tui.PrintInfo("Clients must provide their own auth headers") - fmt.Println() - } else { - config.AutoMode = false - - fmt.Printf(" %s Endpoint URL [%s]: ", prompt, defaultEndpoint) - endpoint, err := reader.ReadString('\n') - if err != nil { - return config, fmt.Errorf("failed to read endpoint: %w", err) - } - endpoint = strings.TrimSpace(endpoint) - if endpoint == "" { - endpoint = defaultEndpoint - } - config.EndpointURL = endpoint - - fmt.Printf(" %s API Key: ", prompt) - apiKey, err := readPassword() - if err != nil { - return config, fmt.Errorf("failed to read API key: %w", err) - } - config.APIKey = apiKey - fmt.Println() + fmt.Printf(" %s Endpoint URL [%s]: ", prompt, defaultEndpoint) + endpoint, err := reader.ReadString('\n') + if err != nil { + return config, fmt.Errorf("failed to read endpoint: %w", err) } + endpoint = strings.TrimSpace(endpoint) + if endpoint == "" { + endpoint = defaultEndpoint + } + config.EndpointURL = endpoint - fmt.Println(tui.Separator("Security")) - fmt.Println() - - fmt.Printf(" %s DB Encryption Key (optional, press Enter to skip): ", prompt) - dbKey, err := readPassword() + fmt.Printf(" %s API Key: ", prompt) + apiKey, err := readPassword() if err != nil { - return config, fmt.Errorf("failed to read DB key: %w", err) + return config, fmt.Errorf("failed to read API key: %w", err) } - config.EncryptionKey = dbKey + config.APIKey = apiKey fmt.Println() - fmt.Println(tui.Separator("Advanced")) fmt.Println() - fmt.Printf(" %s Configure advanced options? [y/N]: ", prompt) - advAnswer, _ := reader.ReadString('\n') - advAnswer = strings.TrimSpace(strings.ToLower(advAnswer)) - - if advAnswer == "y" || advAnswer == "yes" { - fmt.Println() - - fmt.Printf(" %s Enable telemetry? [y/N]: ", prompt) - telAnswer, _ := reader.ReadString('\n') - telAnswer = strings.TrimSpace(strings.ToLower(telAnswer)) - config.TelemetryEnabled = telAnswer == "y" || telAnswer == "yes" - - fmt.Printf(" %s Retention days (0=forever) [%d]: ", prompt, config.RetentionDays) - retStr, _ := reader.ReadString('\n') - retStr = strings.TrimSpace(retStr) - if retStr != "" { - if days, err := strconv.Atoi(retStr); err == nil && days >= 0 && days <= 36500 { - config.RetentionDays = days - } - } + tui.PrintInfo("Manual mode — " + config.EndpointURL) - fmt.Printf(" %s Disable builtin rules? (%d locked rules remain active) [y/N]: ", prompt, rules.CountLockedBuiltinRules()) - rulesAnswer, _ := reader.ReadString('\n') - rulesAnswer = strings.TrimSpace(strings.ToLower(rulesAnswer)) - config.DisableBuiltinRules = rulesAnswer == "y" || rulesAnswer == "yes" - - fmt.Printf(" %s Proxy port [%d]: ", prompt, config.ProxyPort) - proxyStr, _ := reader.ReadString('\n') - proxyStr = strings.TrimSpace(proxyStr) - if proxyStr != "" { - if port, err := strconv.Atoi(proxyStr); err == nil && port >= 1 && port <= 65535 { - config.ProxyPort = port - } - } - - } - - fmt.Println() return config, nil } diff --git a/internal/tui/startup/startup.go b/internal/tui/startup/startup.go index b55fa63e..f15949b9 100644 --- a/internal/tui/startup/startup.go +++ b/internal/tui/startup/startup.go @@ -6,28 +6,20 @@ import ( "errors" "fmt" "net/url" - "strconv" "github.com/charmbracelet/huh" "github.com/charmbracelet/lipgloss" - "github.com/BakeLens/crust/internal/rules" "github.com/BakeLens/crust/internal/tui" - "github.com/BakeLens/crust/internal/tui/banner" ) -// RunStartupWithPort runs the startup prompts with a custom default proxy port. -// Uses huh forms for interactive input when a TTY is available. -// Falls back to plain mode for non-interactive contexts. -func RunStartupWithPort(defaultEndpoint string, defaultProxyPort int) (Config, error) { - fmt.Println() - banner.PrintBanner("") - fmt.Println() - +// RunManualSetup prompts for endpoint URL and API key using a huh form. +// Falls back to plain-text prompts for non-interactive contexts. +func RunManualSetup(defaultEndpoint string) (Config, error) { if tui.IsPlainMode() { - return runStartupReader(defaultEndpoint, defaultProxyPort) + return runManualReader(defaultEndpoint) } - return runStartupForm(defaultEndpoint, defaultProxyPort) + return runManualForm(defaultEndpoint) } // crustTheme returns a huh theme using the Crust synthwave color palette. @@ -71,38 +63,13 @@ func crustTheme() *huh.Theme { return t } -// runStartupForm runs the interactive huh form-based wizard. -func runStartupForm(defaultEndpoint string, defaultProxyPort int) (Config, error) { - cfg := Config{ - ProxyPort: defaultProxyPort, - RetentionDays: 7, - } - - // Form field values (huh binds to pointers) - var mode = "auto" +// runManualForm prompts for endpoint URL and API key via huh form. +func runManualForm(defaultEndpoint string) (Config, error) { + var cfg Config var endpointURL = defaultEndpoint var apiKey string - var encryptionKey string - var showAdvanced bool - var telemetryEnabled bool - var retentionStr = "7" - var disableBuiltin bool - var proxyPortStr = strconv.Itoa(defaultProxyPort) form := huh.NewForm( - // Group 1: Mode selection - huh.NewGroup( - huh.NewSelect[string](). - Title("Connection Mode"). - Description("How should Crust connect to LLM providers?"). - Options( - huh.NewOption("Auto — resolve provider from model name, clients bring own auth", "auto"), - huh.NewOption("Manual — specify endpoint URL and API key", "manual"), - ). - Value(&mode), - ).Title("Configuration"), - - // Group 2: Manual mode settings (hidden in auto mode) huh.NewGroup( huh.NewInput(). Title("Endpoint URL"). @@ -129,74 +96,10 @@ func runStartupForm(defaultEndpoint string, defaultProxyPort int) (Config, error } return nil }), - ).Title("Endpoint").WithHideFunc(func() bool { - return mode == "auto" - }), - - // Group 3: Security - huh.NewGroup( - huh.NewInput(). - Title("DB Encryption Key"). - Description("Optional — protects telemetry database (min 16 chars, Enter to skip)"). - EchoMode(huh.EchoModePassword). - Value(&encryptionKey). - Validate(func(s string) error { - if s != "" && len(s) < 16 { - return errors.New("must be at least 16 characters") - } - return nil - }), - ).Title("Security"), - - // Group 4: Advanced options toggle - huh.NewGroup( - huh.NewConfirm(). - Title("Configure advanced options?"). - Description("Telemetry, retention, rules, and port settings"). - Value(&showAdvanced), - ).Title("Advanced"), - - // Group 5: Advanced settings (hidden unless toggled) - huh.NewGroup( - huh.NewConfirm(). - Title("Enable telemetry?"). - Description("Record API traces and tool call logs"). - Value(&telemetryEnabled), - huh.NewInput(). - Title("Retention days"). - Description("How long to keep telemetry data (0 = forever)"). - Placeholder("7"). - Value(&retentionStr). - Validate(func(s string) error { - if s == "" { - return nil - } - days, err := strconv.Atoi(s) - if err != nil { - return errors.New("must be a number") - } - if days < 0 || days > 36500 { - return errors.New("must be 0-36500") - } - return nil - }), - huh.NewConfirm(). - Title("Disable builtin rules?"). - Description(fmt.Sprintf("Only use user-defined rules (%d locked rules remain active)", rules.CountLockedBuiltinRules())). - Value(&disableBuiltin), - huh.NewInput(). - Title("Proxy port"). - Description("Port for the proxy server"). - Placeholder(strconv.Itoa(defaultProxyPort)). - Value(&proxyPortStr). - Validate(validatePort), - ).Title("Advanced Settings").WithHideFunc(func() bool { - return !showAdvanced - }), + ).Title("Manual Endpoint"), ).WithTheme(crustTheme()) - err := form.Run() - if err != nil { + if err := form.Run(); err != nil { if errors.Is(err, huh.ErrUserAborted) { cfg.Canceled = true return cfg, nil @@ -204,47 +107,11 @@ func runStartupForm(defaultEndpoint string, defaultProxyPort int) (Config, error return cfg, fmt.Errorf("startup form error: %w", err) } - // Map form values to config - cfg.AutoMode = mode == "auto" - if !cfg.AutoMode { - cfg.EndpointURL = endpointURL - cfg.APIKey = apiKey - } - cfg.EncryptionKey = encryptionKey + cfg.EndpointURL = endpointURL + cfg.APIKey = apiKey - if showAdvanced { - cfg.TelemetryEnabled = telemetryEnabled - cfg.DisableBuiltinRules = disableBuiltin - if days, err := strconv.Atoi(retentionStr); err == nil { - cfg.RetentionDays = days - } - if port, err := strconv.Atoi(proxyPortStr); err == nil { - cfg.ProxyPort = port - } - } - - // Print summary fmt.Println() - if cfg.AutoMode { - tui.PrintInfo("Auto mode — providers resolved from model names") - } else { - tui.PrintInfo("Manual mode — " + cfg.EndpointURL) - } + tui.PrintInfo("Manual mode — " + cfg.EndpointURL) return cfg, nil } - -// validatePort validates a port number string. -func validatePort(s string) error { - if s == "" { - return nil - } - port, err := strconv.Atoi(s) - if err != nil { - return errors.New("must be a number") - } - if port < 1 || port > 65535 { - return errors.New("must be 1-65535") - } - return nil -} diff --git a/internal/tui/startup/startup_notui.go b/internal/tui/startup/startup_notui.go index e2156faf..4f3a27e1 100644 --- a/internal/tui/startup/startup_notui.go +++ b/internal/tui/startup/startup_notui.go @@ -2,12 +2,7 @@ package startup -import "fmt" - -// RunStartupWithPort runs the startup prompts with a custom default proxy port (plain text, no TUI). -func RunStartupWithPort(defaultEndpoint string, defaultProxyPort int) (Config, error) { - fmt.Println() - fmt.Println("CRUST - Secure Gateway for AI Agents") - fmt.Println() - return runStartupReader(defaultEndpoint, defaultProxyPort) +// RunManualSetup prompts for endpoint URL and API key (plain text, no TUI). +func RunManualSetup(defaultEndpoint string) (Config, error) { + return runManualReader(defaultEndpoint) } diff --git a/main.go b/main.go index 2f1dc8b4..7a2ffe95 100644 --- a/main.go +++ b/main.go @@ -164,6 +164,7 @@ func runStart(args []string) { apiKey := startFlags.String("api-key", "", "API key for the endpoint (saved to OS keyring)") dbKey := startFlags.String("db-key", "", "Database encryption key (auto-generated if not set)") autoMode := startFlags.Bool("auto", false, "Auto mode: resolve providers from model names (per-provider keys or client auth)") + manualMode := startFlags.Bool("manual", false, "Manual mode: prompt for endpoint URL and API key") // Advanced options proxyPort := startFlags.Int("proxy-port", 0, "Proxy server port (default from config)") @@ -220,35 +221,43 @@ func runStart(args []string) { // Interactive mode - collect configuration via TUI var startupCfg startup.Config - if *autoMode || (*endpoint != "" && *apiKey != "") { - // Flags provided — skip interactive prompts, but still show the banner - fmt.Println() - banner.PrintBanner(Version) - fmt.Println() - startupCfg = startup.Config{ - AutoMode: *autoMode, - EndpointURL: *endpoint, - APIKey: *apiKey, - EncryptionKey: *dbKey, - TelemetryEnabled: *telemetryEnabled, - RetentionDays: *retentionDays, - DisableBuiltinRules: *disableBuiltin, - ProxyPort: *proxyPort, - } - } else { - // Run interactive prompts (asks auto vs manual mode first) - startupCfg, err = startup.RunStartupWithPort(cfg.Upstream.URL, cfg.Server.Port) + fmt.Println() + banner.PrintBanner(Version) + fmt.Println() + + switch { + case *manualMode && *endpoint == "": + // --manual without --endpoint: prompt for endpoint + API key + startupCfg, err = startup.RunManualSetup(cfg.Upstream.URL) if err != nil { tui.PrintError(fmt.Sprintf("Startup error: %v", err)) os.Exit(1) } - if startupCfg.Canceled { tui.PrintInfo("Startup canceled") os.Exit(0) } + + case *endpoint != "" && *apiKey != "": + // Explicit endpoint + key via flags + startupCfg = startup.Config{ + EndpointURL: *endpoint, + APIKey: *apiKey, + } + + default: + // Default: auto mode (zero interaction) + startupCfg = startup.Config{AutoMode: true} + tui.PrintInfo("Auto mode — providers resolved from model names") } + // Apply CLI flag overrides + startupCfg.EncryptionKey = *dbKey + startupCfg.TelemetryEnabled = *telemetryEnabled + startupCfg.RetentionDays = *retentionDays + startupCfg.DisableBuiltinRules = *disableBuiltin + startupCfg.ProxyPort = *proxyPort + // Build args for daemon process daemonArgs := daemon.StartArgs{ ConfigPath: *configPath, From 27d3f2a0ccb03123c2911b8cdab7825f6866fe2b Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 27 Mar 2026 14:25:20 +0800 Subject: [PATCH 2/2] fix(ci): increase fuzz timeouts for heavy targets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move FuzzWebSearchURLBypass out of 12s parallel batch into sequential 15s slot — it creates a full engine per iteration and times out under CPU contention. Bump FuzzPipeBypass and FuzzForkBombDetection from 12s to 15s to avoid "context deadline exceeded" shutdown races. --- .github/workflows/ci.yml | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d70d4b6..ba8cde1c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -136,20 +136,20 @@ jobs: /tmp/rules.fuzz -test.fuzz=FuzzMCPToolBypass -test.fuzztime=30s -test.short -test.fuzzcachedir=/tmp/fuzz-cache /tmp/rules.fuzz -test.fuzz=FuzzConfusableBypass -test.fuzztime=30s -test.short -test.fuzzcachedir=/tmp/fuzz-cache - - name: Fuzz tests (full suite, 12s each — parallelized) + - name: Fuzz tests (full suite, 15s each — parallelized) run: | # Run full suite in parallel batches (2 at a time) to cut wall time # while keeping enough CPU per fuzzer for meaningful coverage. # Each fuzz target gets its own process so failures are isolated. - # 12s (not 10s) to avoid spurious "context deadline exceeded" from - # the Go fuzz framework when CPU-contended shutdown races the timer. + # 15s to avoid spurious "context deadline exceeded" from the Go fuzz + # framework when CPU-contended shutdown races the timer. parallel_fuzz() { local bin="$1"; shift local pids=() local targets=() local failed=0 for target in "$@"; do - "$bin" -test.fuzz="$target" -test.fuzztime=12s -test.short -test.fuzzcachedir=/tmp/fuzz-cache & + "$bin" -test.fuzz="$target" -test.fuzztime=15s -test.short -test.fuzzcachedir=/tmp/fuzz-cache & pids+=($!) targets+=("$target") # Cap concurrency at 2 for better per-target CPU utilization. @@ -179,15 +179,17 @@ jobs: FuzzCommandRegexBypass FuzzHostRegexBypass \ FuzzJSONUnicodeEscapeBypass FuzzEvasionDetectionBypass FuzzGlobCommandBypass \ FuzzPipelineExtraction FuzzContentConfusableBypass FuzzVariableExpansionEvasion \ - FuzzShapeDetectionBypass FuzzWebSearchURLBypass \ + FuzzShapeDetectionBypass \ || rc=1 - # FuzzNormalAgentFalsePositive creates a full engine per iteration - # (with DLP/gitleaks) — too slow for the 12s parallel batch. + # FuzzNormalAgentFalsePositive and FuzzWebSearchURLBypass create a + # full engine per iteration (with DLP/gitleaks) — too slow for the + # 12s parallel batch. Run sequentially with more time. /tmp/rules.fuzz -test.fuzz=FuzzNormalAgentFalsePositive -test.fuzztime=15s -test.short -test.fuzzcachedir=/tmp/fuzz-cache || rc=1 + /tmp/rules.fuzz -test.fuzz=FuzzWebSearchURLBypass -test.fuzztime=15s -test.short -test.fuzzcachedir=/tmp/fuzz-cache || rc=1 # FuzzPipeBypass and FuzzForkBombDetection need constrained memory and # single worker to avoid OOM when running alongside other fuzz targets. - GOMEMLIMIT=2GiB /tmp/rules.fuzz -test.fuzz=FuzzPipeBypass -test.fuzztime=12s -test.short -test.fuzzcachedir=/tmp/fuzz-cache -test.parallel=1 || rc=1 - GOMEMLIMIT=2GiB /tmp/rules.fuzz -test.fuzz=FuzzForkBombDetection -test.fuzztime=12s -test.short -test.fuzzcachedir=/tmp/fuzz-cache -test.parallel=1 || rc=1 + GOMEMLIMIT=2GiB /tmp/rules.fuzz -test.fuzz=FuzzPipeBypass -test.fuzztime=15s -test.short -test.fuzzcachedir=/tmp/fuzz-cache -test.parallel=1 || rc=1 + GOMEMLIMIT=2GiB /tmp/rules.fuzz -test.fuzz=FuzzForkBombDetection -test.fuzztime=15s -test.short -test.fuzzcachedir=/tmp/fuzz-cache -test.parallel=1 || rc=1 parallel_fuzz /tmp/httpproxy.fuzz \ FuzzParseSSEEventData FuzzCopyHeaders FuzzStripHopByHopHeaders \ FuzzParseEvent FuzzApplyResultToToolCalls \