diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs index 7c9f02945e..acc779356d 100644 --- a/rust/crates/api/src/providers/anthropic.rs +++ b/rust/crates/api/src/providers/anthropic.rs @@ -738,11 +738,7 @@ fn now_unix_timestamp() -> u64 { } fn read_env_non_empty(key: &str) -> Result, ApiError> { - match std::env::var(key) { - Ok(value) if !value.is_empty() => Ok(Some(value)), - Ok(_) | Err(std::env::VarError::NotPresent) => Ok(super::dotenv_value(key)), - Err(error) => Err(ApiError::from(error)), - } + super::read_env_or_config(key) } #[cfg(test)] @@ -763,7 +759,10 @@ fn read_auth_token() -> Option { #[must_use] pub fn read_base_url() -> String { - std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) + super::read_env_or_config("ANTHROPIC_BASE_URL") + .ok() + .flatten() + .unwrap_or_else(|| DEFAULT_BASE_URL.to_string()) } fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index fb97900359..b768f1d5e1 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -247,9 +247,65 @@ pub fn detect_provider_kind(model: &str) -> ProviderKind { if std::env::var_os("OPENAI_BASE_URL").is_some() { return ProviderKind::OpenAi; } + // Fallback: check stored provider config from setup wizard. + if let Some(kind) = stored_provider_kind() { + return kind; + } ProviderKind::Anthropic } +/// Look up a stored provider config value by env var name. +/// Returns the stored API key or base URL when the env var matches the +/// configured provider kind, enabling the setup wizard to persist credentials +/// that work without shell env vars. +pub fn provider_config_value(key: &str) -> Option { + let cwd = std::env::current_dir().ok()?; + let config = runtime::ConfigLoader::default_for(&cwd).load().ok()?; + let provider = config.provider(); + let kind = provider.kind()?; + match (key, kind) { + ("ANTHROPIC_API_KEY" | "ANTHROPIC_AUTH_TOKEN", "anthropic") + | ("XAI_API_KEY", "xai") + | ("OPENAI_API_KEY", "openai") + | ("DASHSCOPE_API_KEY", "dashscope") => provider.api_key().map(ToOwned::to_owned), + ("ANTHROPIC_BASE_URL", "anthropic") + | ("XAI_BASE_URL", "xai") + | ("OPENAI_BASE_URL", "openai") + | ("DASHSCOPE_BASE_URL", "dashscope") => provider.base_url().map(ToOwned::to_owned), + _ => None, + } +} + +/// Read an env var with a 3-tier fallback: process env -> .env file -> stored config. +/// Environment variables always take priority over stored settings. +pub fn read_env_or_config(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => return Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => {} + Err(error) => return Err(ApiError::from(error)), + } + if let Some(value) = dotenv_value(key) { + return Ok(Some(value)); + } + if let Some(value) = provider_config_value(key) { + return Ok(Some(value)); + } + Ok(None) +} + +/// Return the stored `ProviderKind` from config, if set. +fn stored_provider_kind() -> Option { + let cwd = std::env::current_dir().ok()?; + let config = runtime::ConfigLoader::default_for(&cwd).load().ok()?; + let kind = config.provider().kind()?; + match kind { + "anthropic" => Some(ProviderKind::Anthropic), + "xai" => Some(ProviderKind::Xai), + "openai" => Some(ProviderKind::OpenAi), + _ => None, + } +} + #[must_use] pub fn max_tokens_for_model(model: &str) -> u32 { model_token_limit(model).map_or_else( diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 4e4183bd96..3794008722 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -1303,11 +1303,7 @@ fn parse_sse_frame( } fn read_env_non_empty(key: &str) -> Result, ApiError> { - match std::env::var(key) { - Ok(value) if !value.is_empty() => Ok(Some(value)), - Ok(_) | Err(std::env::VarError::NotPresent) => Ok(super::dotenv_value(key)), - Err(error) => Err(ApiError::from(error)), - } + super::read_env_or_config(key) } #[must_use] @@ -1320,7 +1316,10 @@ pub fn has_api_key(key: &str) -> bool { #[must_use] pub fn read_base_url(config: OpenAiCompatConfig) -> String { - std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) + super::read_env_or_config(config.base_url_env) + .ok() + .flatten() + .unwrap_or_else(|| config.default_base_url.to_string()) } fn chat_completions_endpoint(base_url: &str) -> String { diff --git a/rust/crates/commands/src/lib.rs b/rust/crates/commands/src/lib.rs index 9d693935a5..6ef4409a54 100644 --- a/rust/crates/commands/src/lib.rs +++ b/rust/crates/commands/src/lib.rs @@ -313,6 +313,13 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ argument_hint: None, resume_supported: true, }, + SlashCommandSpec { + name: "setup", + aliases: &[], + summary: "Configure provider, API key, and model interactively", + argument_hint: None, + resume_supported: true, + }, SlashCommandSpec { name: "stats", aliases: &[], @@ -1140,6 +1147,7 @@ pub enum SlashCommand { Usage { scope: Option, }, + Setup, Rename { name: Option, }, @@ -1265,6 +1273,7 @@ impl SlashCommand { Self::Theme { .. } => "/theme", Self::Voice { .. } => "/voice", Self::Usage { .. } => "/usage", + Self::Setup => "/setup", Self::Rename { .. } => "/rename", Self::Copy { .. } => "/copy", Self::Hooks { .. } => "/hooks", @@ -1476,6 +1485,7 @@ pub fn validate_slash_command_input( "theme" => SlashCommand::Theme { name: remainder }, "voice" => SlashCommand::Voice { mode: remainder }, "usage" => SlashCommand::Usage { scope: remainder }, + "setup" => SlashCommand::Setup, "rename" => SlashCommand::Rename { name: remainder }, "copy" => SlashCommand::Copy { target: remainder }, "hooks" => SlashCommand::Hooks { args: remainder }, @@ -2537,6 +2547,7 @@ pub fn resolve_skill_path(cwd: &Path, skill: &str) -> std::io::Result { )) } +#[allow(clippy::unnecessary_wraps)] fn render_mcp_report_for( loader: &ConfigLoader, cwd: &Path, @@ -2600,6 +2611,7 @@ fn render_mcp_report_for( } } +#[allow(clippy::unnecessary_wraps)] fn render_mcp_report_json_for( loader: &ConfigLoader, cwd: &Path, @@ -4169,6 +4181,7 @@ pub fn handle_slash_command( | SlashCommand::OutputStyle { .. } | SlashCommand::AddDir { .. } | SlashCommand::History { .. } + | SlashCommand::Setup | SlashCommand::Unknown(_) => None, } } @@ -4706,7 +4719,7 @@ mod tests { assert!(help.contains("aliases: /skill")); assert!(!help.contains("/login")); assert!(!help.contains("/logout")); - assert_eq!(slash_command_specs().len(), 139); + assert_eq!(slash_command_specs().len(), 140); assert!(resume_supported_slash_commands().len() >= 39); } diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index 1566189282..868a15a60d 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -65,6 +65,38 @@ pub struct RuntimeFeatureConfig { sandbox: SandboxConfig, provider_fallbacks: ProviderFallbackConfig, trusted_roots: Vec, + provider: RuntimeProviderConfig, +} + +/// Stored provider configuration from the setup wizard. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RuntimeProviderConfig { + kind: Option, + api_key: Option, + base_url: Option, + model: Option, +} + +impl RuntimeProviderConfig { + #[must_use] + pub fn kind(&self) -> Option<&str> { + self.kind.as_deref() + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + self.api_key.as_deref() + } + + #[must_use] + pub fn base_url(&self) -> Option<&str> { + self.base_url.as_deref() + } + + #[must_use] + pub fn model(&self) -> Option<&str> { + self.model.as_deref() + } } /// Ordered chain of fallback model identifiers used when the primary @@ -315,6 +347,7 @@ impl ConfigLoader { sandbox: parse_optional_sandbox_config(&merged_value)?, provider_fallbacks: parse_optional_provider_fallbacks(&merged_value)?, trusted_roots: parse_optional_trusted_roots(&merged_value)?, + provider: parse_optional_provider_config(&merged_value)?, }; Ok(RuntimeConfig { @@ -414,6 +447,11 @@ impl RuntimeConfig { pub fn trusted_roots(&self) -> &[String] { &self.feature_config.trusted_roots } + + #[must_use] + pub fn provider(&self) -> &RuntimeProviderConfig { + &self.feature_config.provider + } } impl RuntimeFeatureConfig { @@ -483,6 +521,11 @@ impl RuntimeFeatureConfig { pub fn trusted_roots(&self) -> &[String] { &self.trusted_roots } + + #[must_use] + pub fn provider(&self) -> &RuntimeProviderConfig { + &self.provider + } } impl ProviderFallbackConfig { @@ -564,6 +607,92 @@ pub fn default_config_home() -> PathBuf { .unwrap_or_else(|| PathBuf::from(".claw")) } +/// Save provider settings to the user-level `~/.claw/settings.json`. +/// Creates the file and directory if they don't exist. Sets file permissions +/// to `0o600` (owner read/write only) to protect stored API keys. +pub fn save_user_provider_settings( + kind: &str, + api_key: &str, + base_url: Option<&str>, + model: Option<&str>, +) -> Result<(), ConfigError> { + let config_home = default_config_home(); + fs::create_dir_all(&config_home).map_err(ConfigError::Io)?; + let settings_path = config_home.join("settings.json"); + + let mut root = read_settings_root(&settings_path); + + let mut provider = serde_json::Map::new(); + provider.insert("kind".to_string(), serde_json::Value::String(kind.to_string())); + provider.insert("apiKey".to_string(), serde_json::Value::String(api_key.to_string())); + if let Some(base_url) = base_url { + provider.insert("baseUrl".to_string(), serde_json::Value::String(base_url.to_string())); + } else { + provider.remove("baseUrl"); + } + root.insert("provider".to_string(), serde_json::Value::Object(provider)); + if let Some(model) = model { + root.insert("model".to_string(), serde_json::Value::String(model.to_string())); + } else { + root.remove("model"); + } + + write_settings_root(&settings_path, &root)?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + fs::set_permissions(&settings_path, perms).map_err(ConfigError::Io)?; + } + + Ok(()) +} + +/// Remove the `provider` section from the user-level `~/.claw/settings.json`. +pub fn clear_user_provider_settings() -> Result<(), ConfigError> { + let config_home = default_config_home(); + let settings_path = config_home.join("settings.json"); + + if !settings_path.exists() { + return Ok(()); + } + + let mut root = read_settings_root(&settings_path); + if root.remove("provider").is_none() { + return Ok(()); + } + root.remove("model"); + + write_settings_root(&settings_path, &root)?; + + Ok(()) +} + +fn read_settings_root(path: &Path) -> serde_json::Map { + match fs::read_to_string(path) { + Ok(contents) if !contents.trim().is_empty() => { + serde_json::from_str::(&contents) + .ok() + .and_then(|v| v.as_object().cloned()) + .unwrap_or_default() + } + _ => serde_json::Map::new(), + } +} + +fn write_settings_root( + path: &Path, + root: &serde_json::Map, +) -> Result<(), ConfigError> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).map_err(ConfigError::Io)?; + } + let rendered = serde_json::to_string_pretty(&serde_json::Value::Object(root.clone())) + .map_err(|e| ConfigError::Parse(e.to_string()))?; + fs::write(path, format!("{rendered}\n")).map_err(ConfigError::Io) +} + impl RuntimeHookConfig { #[must_use] pub fn new( @@ -950,6 +1079,25 @@ fn parse_optional_oauth_config( })) } +fn parse_optional_provider_config(root: &JsonValue) -> Result { + let Some(provider_value) = root.as_object().and_then(|object| object.get("provider")) else { + return Ok(RuntimeProviderConfig::default()); + }; + let Some(object) = provider_value.as_object() else { + return Ok(RuntimeProviderConfig::default()); + }; + let kind = optional_string(object, "kind", "provider")?.map(str::to_string); + let api_key = optional_string(object, "apiKey", "provider")?.map(str::to_string); + let base_url = optional_string(object, "baseUrl", "provider")?.map(str::to_string); + let model = optional_string(object, "model", "provider")?.map(str::to_string); + Ok(RuntimeProviderConfig { + kind, + api_key, + base_url, + model, + }) +} + fn parse_mcp_server_config( server_name: &str, value: &JsonValue, diff --git a/rust/crates/runtime/src/config_validate.rs b/rust/crates/runtime/src/config_validate.rs index 7a9c1c4adc..ad6cc44c18 100644 --- a/rust/crates/runtime/src/config_validate.rs +++ b/rust/crates/runtime/src/config_validate.rs @@ -197,6 +197,10 @@ const TOP_LEVEL_FIELDS: &[FieldSpec] = &[ name: "trustedRoots", expected: FieldType::StringArray, }, + FieldSpec { + name: "provider", + expected: FieldType::Object, + }, ]; const HOOKS_FIELDS: &[FieldSpec] = &[ @@ -310,6 +314,25 @@ const OAUTH_FIELDS: &[FieldSpec] = &[ }, ]; +const PROVIDER_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: "kind", + expected: FieldType::String, + }, + FieldSpec { + name: "apiKey", + expected: FieldType::String, + }, + FieldSpec { + name: "baseUrl", + expected: FieldType::String, + }, + FieldSpec { + name: "model", + expected: FieldType::String, + }, +]; + const DEPRECATED_FIELDS: &[DeprecatedField] = &[ DeprecatedField { name: "permissionMode", @@ -501,6 +524,15 @@ pub fn validate_config_file( &path_display, )); } + if let Some(provider) = object.get("provider").and_then(JsonValue::as_object) { + result.merge(validate_object_keys( + provider, + PROVIDER_FIELDS, + "provider", + source, + &path_display, + )); + } result } diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index c7d87091fa..85ef40e68f 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -57,12 +57,13 @@ pub use compact::{ get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult, }; pub use config::{ - ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpConfigCollection, - McpManagedProxyServerConfig, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, - McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, + clear_user_provider_settings, save_user_provider_settings, ConfigEntry, ConfigError, + ConfigLoader, ConfigSource, McpConfigCollection, McpManagedProxyServerConfig, + McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig, + McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, ProviderFallbackConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, - RuntimeHookConfig, RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig, - CLAW_SETTINGS_SCHEMA_NAME, + RuntimeHookConfig, RuntimePermissionRuleConfig, RuntimePluginConfig, RuntimeProviderConfig, + ScopedMcpServerConfig, CLAW_SETTINGS_SCHEMA_NAME, }; pub use config_validate::{ check_unsupported_format, format_diagnostics, validate_config_file, ConfigDiagnostic, diff --git a/rust/crates/rusty-claude-cli/src/input.rs b/rust/crates/rusty-claude-cli/src/input.rs index b0664dac44..011065a728 100644 --- a/rust/crates/rusty-claude-cli/src/input.rs +++ b/rust/crates/rusty-claude-cli/src/input.rs @@ -18,6 +18,7 @@ pub enum ReadOutcome { Submit(String), Cancel, Exit, + ProviderSwap, } struct SlashCommandHelper { @@ -86,12 +87,19 @@ impl Hinter for SlashCommandHelper { impl Highlighter for SlashCommandHelper { fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> { self.set_current_line(line); - Cow::Borrowed(line) + // When sentinel is present, show visible prompt instead of invisible char + if line.contains('\x01') { + let display = line.replace('\x01', "\x1b[36m[Provider Swap]\x1b[0m "); + Cow::Owned(display) + } else { + Cow::Borrowed(line) + } } fn highlight_char(&self, line: &str, _pos: usize, _kind: CmdKind) -> bool { self.set_current_line(line); - false + // Re-highlight when sentinel is present to show the prompt + line.contains('\x01') } } @@ -115,6 +123,13 @@ impl LineEditor { editor.set_helper(Some(SlashCommandHelper::new(completions))); editor.bind_sequence(KeyEvent(KeyCode::Char('J'), Modifiers::CTRL), Cmd::Newline); editor.bind_sequence(KeyEvent(KeyCode::Enter, Modifiers::SHIFT), Cmd::Newline); + // Ctrl+P inserts a sentinel character that triggers provider swap. + // The sentinel is invisible but the highlighter shows "[Provider Swap]" prompt. + // User must press Enter to confirm (rustyline cannot chain commands). + editor.bind_sequence( + KeyEvent(KeyCode::Char('P'), Modifiers::CTRL), + Cmd::SelfInsert(1, '\x01'), + ); Self { prompt: prompt.into(), @@ -147,7 +162,14 @@ impl LineEditor { } match self.editor.readline(&self.prompt) { - Ok(line) => Ok(ReadOutcome::Submit(line)), + Ok(line) => { + // Ctrl+P inserts \x01 sentinel — triggers provider swap wizard. + // The sentinel is stripped and we return ProviderSwap to the REPL loop. + if line.contains('\x01') { + return Ok(ReadOutcome::ProviderSwap); + } + Ok(ReadOutcome::Submit(line)) + } Err(ReadlineError::Interrupted) => { let has_input = !self.current_line().is_empty(); self.finish_interrupted_read()?; diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index c4ba812a64..7923fa0d46 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -2,6 +2,13 @@ dead_code, unused_imports, unused_variables, + clippy::doc_markdown, + clippy::len_zero, + clippy::manual_string_new, + clippy::match_same_arms, + clippy::result_large_err, + clippy::too_many_lines, + clippy::uninlined_format_args, clippy::unneeded_struct_pattern, clippy::unnecessary_wraps, clippy::unused_self @@ -9,6 +16,7 @@ mod init; mod input; mod render; +mod setup_wizard; use std::collections::BTreeSet; use std::env; @@ -406,6 +414,7 @@ fn run() -> Result<(), Box> { CliAction::Acp { output_format } => print_acp_status(output_format)?, CliAction::State { output_format } => run_worker_state(output_format)?, CliAction::Init { output_format } => run_init(output_format)?, + CliAction::Setup { .. } => setup_wizard::run_setup_wizard()?, // #146: dispatch pure-local introspection. Text mode uses existing // render_config_report/render_diff_report; JSON mode uses the // corresponding _json helpers already exposed for resume sessions. @@ -564,6 +573,9 @@ enum CliAction { Help { output_format: CliOutputFormat, }, + Setup { + output_format: CliOutputFormat, + }, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -1099,6 +1111,7 @@ fn parse_single_word_command_alias( "sandbox" => Some(Ok(CliAction::Sandbox { output_format })), "doctor" => Some(Ok(CliAction::Doctor { output_format })), "state" => Some(Ok(CliAction::State { output_format })), + "setup" => Some(Ok(CliAction::Setup { output_format })), // #146: let `config` and `diff` fall through to parse_subcommand // where they are wired as pure-local introspection, instead of // producing the "is a slash command" guidance. Zero-arg cases @@ -1120,6 +1133,7 @@ fn bare_slash_command_guidance(command_name: &str) -> Option { | "init" | "prompt" | "export" + | "setup" ) { return None; } @@ -1572,7 +1586,8 @@ fn config_permission_mode_for_current_dir() -> Option { fn config_model_for_current_dir() -> Option { let cwd = env::current_dir().ok()?; let loader = ConfigLoader::default_for(&cwd); - loader.load().ok()?.model().map(ToOwned::to_owned) + let config = loader.load().ok()?; + config.model().map(ToOwned::to_owned).or_else(|| config.provider().model().map(ToOwned::to_owned)) } fn resolve_repl_model(cli_model: String) -> String { @@ -3463,7 +3478,8 @@ fn run_resume_command( | SlashCommand::Ide { .. } | SlashCommand::Tag { .. } | SlashCommand::OutputStyle { .. } - | SlashCommand::AddDir { .. } => Err("unsupported resumed slash command".into()), + | SlashCommand::AddDir { .. } + | SlashCommand::Setup => Err("unsupported resumed slash command".into()), } } @@ -3615,6 +3631,16 @@ fn run_repl( cli.record_prompt_history(&trimmed); cli.run_turn(&trimmed)?; } + input::ReadOutcome::ProviderSwap => { + // Ctrl+P triggered — launch setup wizard and hot-swap model + setup_wizard::run_setup_wizard()?; + let cwd = std::env::current_dir().unwrap_or_default(); + let config = runtime::ConfigLoader::default_for(&cwd).load().ok(); + if let Some(new_model) = config.as_ref().and_then(|c| c.provider().model().map(str::to_string)) { + cli.set_model(Some(new_model))?; + } + println!("{}", format_connected_line(&cli.model)); + } input::ReadOutcome::Cancel => {} input::ReadOutcome::Exit => { cli.persist_session()?; @@ -4466,6 +4492,16 @@ impl LiveCli { run_init(CliOutputFormat::Text)?; false } + SlashCommand::Setup => { + setup_wizard::run_setup_wizard()?; + // Reload the model from config after wizard saves + let cwd = std::env::current_dir().unwrap_or_default(); + let config = runtime::ConfigLoader::default_for(&cwd).load().ok(); + if let Some(new_model) = config.as_ref().and_then(|c| c.provider().model().map(str::to_string)) { + self.set_model(Some(new_model))?; + } + false + } SlashCommand::Diff => { Self::print_diff()?; false diff --git a/rust/crates/rusty-claude-cli/src/setup_wizard.rs b/rust/crates/rusty-claude-cli/src/setup_wizard.rs new file mode 100644 index 0000000000..e1513f712b --- /dev/null +++ b/rust/crates/rusty-claude-cli/src/setup_wizard.rs @@ -0,0 +1,226 @@ +use std::io::{self, IsTerminal, Write}; + +use runtime::{save_user_provider_settings, ConfigLoader, RuntimeProviderConfig}; + +const PROVIDERS: &[(&str, &str, &str)] = &[ + ("1", "Anthropic", "anthropic"), + ("2", "xAI / Grok", "xai"), + ("3", "OpenAI", "openai"), + ("4", "DashScope (Qwen/Kimi)", "dashscope"), + ("5", "Custom (OpenAI-compat)", "openai"), +]; + +const PROVIDER_MODELS: &[(&str, &[&str])] = &[ + ("anthropic", &["opus", "sonnet", "haiku"]), + ("xai", &["grok", "grok-mini", "grok-2"]), + ("openai", &["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano"]), + ("dashscope", &["qwen-plus", "qwen-max", "kimi"]), +]; + +const DEFAULT_BASE_URLS: &[(&str, &str)] = &[ + ("anthropic", "https://api.anthropic.com"), + ("xai", "https://api.x.ai/v1"), + ("openai", "https://api.openai.com/v1"), + ("dashscope", "https://dashscope.aliyuncs.com/compatible-mode/v1"), +]; + +const API_KEY_ENV_VARS: &[(&str, &str)] = &[ + ("anthropic", "ANTHROPIC_API_KEY"), + ("xai", "XAI_API_KEY"), + ("openai", "OPENAI_API_KEY"), + ("dashscope", "DASHSCOPE_API_KEY"), +]; + +pub fn run_setup_wizard() -> Result<(), Box> { + if !io::stdin().is_terminal() { + return Err("setup wizard requires an interactive terminal".into()); + } + + let current = load_current_provider_config(); + + println!(); + println!(" \x1b[1mClaw Code Setup Wizard\x1b[0m"); + println!(" Configure your provider, API key, and model."); + println!(" Press Enter to keep current value.\n"); + + let kind = prompt_provider(¤t)?; + let api_key = prompt_api_key(&kind, ¤t)?; + let base_url = prompt_base_url(&kind, ¤t)?; + let model = prompt_model(&kind, ¤t)?; + + save_user_provider_settings( + &kind, + &api_key, + base_url.as_deref(), + model.as_deref(), + )?; + + println!(); + println!(" \x1b[32mProvider saved to ~/.claw/settings.json\x1b[0m"); + println!(" Run \x1b[1m/model {}\x1b[0m or restart claw to activate.", model.as_deref().unwrap_or(&kind)); + println!(); + + Ok(()) +} + +fn load_current_provider_config() -> RuntimeProviderConfig { + let cwd = std::env::current_dir().unwrap_or_default(); + ConfigLoader::default_for(&cwd) + .load() + .map(|c| c.provider().clone()) + .unwrap_or_default() +} + +fn prompt_provider(current: &RuntimeProviderConfig) -> Result> { + let current_kind = current.kind().unwrap_or("anthropic"); + println!(" \x1b[1mProvider\x1b[0m"); + for (num, label, kind) in PROVIDERS { + let marker = if *kind == current_kind { " (current)" } else { "" }; + println!(" [{num}] {label}{marker}"); + } + let default = PROVIDERS + .iter() + .position(|(_, _, k)| *k == current_kind) + .map_or_else(|| "1".to_string(), |i| (i + 1).to_string()); + + let input = read_line(&format!(" Select provider [{default}]: "))?; + let choice = if input.trim().is_empty() { + default + } else { + input.trim().to_string() + }; + + let kind = PROVIDERS + .iter() + .find(|(num, _, _)| *num == choice) + .map(|(_, _, kind)| *kind) + .ok_or_else(|| format!("invalid provider choice: {choice}"))?; + + Ok(kind.to_string()) +} + +fn prompt_api_key( + kind: &str, + current: &RuntimeProviderConfig, +) -> Result> { + let env_var = API_KEY_ENV_VARS + .iter() + .find(|(k, _)| *k == kind) + .map_or("API_KEY", |(_, v)| *v); + + let current_key = current.api_key(); + let hint = match current_key { + Some(key) if !key.is_empty() => { + let masked = if key.len() > 4 { + format!("****{}", &key[key.len() - 4..]) + } else { + "****".to_string() + }; + format!("[{masked}]") + } + _ => "(none)".to_string(), + }; + + // Check if env var is already set + let env_set = std::env::var(env_var) + .ok() + .is_some_and(|v| !v.is_empty()); + if env_set { + println!(" {env_var} is set in environment (will take priority over stored key)"); + } + + let input = read_line(&format!(" API key ({env_var}) {hint}: "))?; + let key = if input.trim().is_empty() { + current_key.unwrap_or("").to_string() + } else { + input.trim().to_string() + }; + + if key.is_empty() && !env_set { + eprintln!(" \x1b[33mWarning: no API key configured. Set {env_var} or re-run setup.\x1b[0m"); + } + + Ok(key) +} + +fn prompt_base_url( + kind: &str, + current: &RuntimeProviderConfig, +) -> Result, Box> { + let default_url = DEFAULT_BASE_URLS + .iter() + .find(|(k, _)| *k == kind) + .map_or("", |(_, v)| *v); + + let current_url = current.base_url().unwrap_or(default_url); + let display = if current_url.is_empty() { + default_url.to_string() + } else { + current_url.to_string() + }; + + // Check if the relevant env var is already set + let env_var = match kind { + "anthropic" => "ANTHROPIC_BASE_URL", + "xai" => "XAI_BASE_URL", + "openai" => "OPENAI_BASE_URL", + "dashscope" => "DASHSCOPE_BASE_URL", + _ => "BASE_URL", + }; + let env_set = std::env::var(env_var) + .ok() + .is_some_and(|v| !v.is_empty()); + if env_set { + println!(" {env_var} is set in environment (will take priority over stored URL)"); + } + + let input = read_line(&format!(" Base URL [{display}]: "))?; + if input.trim().is_empty() { + if current_url == default_url || current_url.is_empty() { + Ok(None) + } else { + Ok(Some(current_url.to_string())) + } + } else { + Ok(Some(input.trim().to_string())) + } +} + +fn prompt_model( + kind: &str, + current: &RuntimeProviderConfig, +) -> Result, Box> { + let empty: &[&str] = &[]; + let aliases = PROVIDER_MODELS + .iter() + .find(|(k, _)| *k == kind) + .map_or(empty, |(_, models)| *models); + + let current_model = current.model().unwrap_or(aliases.first().copied().unwrap_or("")); + + println!(" \x1b[1mModel\x1b[0m"); + if !aliases.is_empty() { + println!(" Common: {}", aliases.join(", ")); + } + println!(" Or enter any model name (e.g. openai/gpt-4.1-mini for custom routing)"); + + let input = read_line(&format!(" Model [{current_model}]: "))?; + if input.trim().is_empty() { + if current_model.is_empty() { + Ok(None) + } else { + Ok(Some(current_model.to_string())) + } + } else { + Ok(Some(input.trim().to_string())) + } +} + +fn read_line(prompt: &str) -> Result> { + let mut stdout = io::stdout(); + write!(stdout, "{prompt}")?; + stdout.flush()?; + let mut buffer = String::new(); + io::stdin().read_line(&mut buffer)?; + Ok(buffer) +}