diff --git a/Cargo.toml b/Cargo.toml index 0e0dba0..b48c1cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ghcc" -version = "0.1.0" +version = "0.2.0" edition = "2024" rust-version = "1.85" description = "Fast conventional commit messages using GitHub Copilot" diff --git a/README.md b/README.md index 681f9ff..a913326 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,18 @@ # ghcc -ghcc is a CLI tool for generating conventional commit messages with GitHub Copilot. +Fast CLI tool for generating conventional commit messages with GitHub Copilot. ## Features - **Conventional commits** - Generates conventional commit style messages ([spec](https://www.conventionalcommits.org/)) - **Fast** - Single command, no prompts required -- **Commit styles** - Single-line, detailed with bullets, or let the AI decide +- **Commit styles** - Single-line, detailed body to provide context, or let the AI decide - **Git hook integration** - Pre-fills your editor so you can check and edit before committing - **Uses GitHub Copilot** - Just log in with your existing account. No API keys needed. -- **Model selection** - Choose from your enabled Copilot models (default: `gpt-5-mini`) +- **Model selection** - Choose from your enabled Copilot models (default: `gpt-4.1`) + +> [!TIP] +> `gpt-4.1` (default) seems to perform best of the free models. For a bump in quality, any of the non-Codex GPT-5 models (`gpt-5`, `gpt-5.1`, `gpt-5.2`) perform well but will consume premium requests. ## Quickstart @@ -56,10 +59,10 @@ ghcc is a CLI tool for generating conventional commit messages with GitHub Copil ```sh ghcc # Single-line commit message (default) -ghcc -d # Detailed: subject + bullet-point body +ghcc -d # Detailed: subject + paragraph body to provide context ghcc -a # Auto: AI decides format based on complexity ghcc status # Check authentication status -ghcc models # List and select models (default: gpt-5-mini) +ghcc models # List and select models (default: gpt-4.1) ghcc hook install # Install git hook for current repo ghcc hook uninstall ``` diff --git a/src/auth.rs b/src/auth.rs index f23075c..607c861 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -32,6 +32,8 @@ pub struct CopilotAuth { pub machine_id: Option, #[serde(default)] pub max_prompt_tokens: Option, + #[serde(default)] + pub supported_endpoints: Option>, } #[derive(Debug, Deserialize)] @@ -73,7 +75,7 @@ pub enum AuthError { impl std::fmt::Display for AuthError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - AuthError::NotFound => write!(f, "Credentials not found"), + AuthError::NotFound => write!(f, "Credentials not found. Try running 'ghcc login'"), AuthError::Network(msg) => write!(f, "Network error: {}", msg), AuthError::Parse(msg) => write!(f, "Parse error: {}", msg), AuthError::FileSystem(msg) => write!(f, "File system error: {}", msg), @@ -272,7 +274,8 @@ pub fn login() -> Result { api_endpoint: copilot_resp.endpoints.get("api").cloned(), model: None, machine_id: Some(Uuid::new_v4().to_string()), - max_prompt_tokens: None, // Set when user selects a model + max_prompt_tokens: None, // Set when user selects a model + supported_endpoints: None, // Set when user selects a model }; Ok(auth) @@ -284,6 +287,7 @@ pub fn refresh_token( current_model: Option, current_machine_id: Option, current_max_prompt_tokens: Option, + current_supported_endpoints: Option>, ) -> Result { let agent = create_agent(); @@ -308,6 +312,7 @@ pub fn refresh_token( model: current_model, machine_id: current_machine_id, max_prompt_tokens: current_max_prompt_tokens, + supported_endpoints: current_supported_endpoints, }) } @@ -316,11 +321,18 @@ pub fn get_valid_auth() -> Result { let mut auth = read_auth()?; if is_expired(&auth) || auth.api_endpoint.is_none() { - // preserve current model, machine_id, and max_prompt_tokens during refresh + // preserve current model, machine_id, max_prompt_tokens, and supported_endpoints during refresh let model = auth.model.clone(); let machine_id = auth.machine_id.clone(); let max_prompt_tokens = auth.max_prompt_tokens; - auth = refresh_token(&auth.refresh, model, machine_id, max_prompt_tokens)?; + let supported_endpoints = auth.supported_endpoints.clone(); + auth = refresh_token( + &auth.refresh, + model, + machine_id, + max_prompt_tokens, + supported_endpoints, + )?; save_auth(&auth)?; } @@ -367,6 +379,7 @@ mod tests { model: None, machine_id: None, max_prompt_tokens: None, + supported_endpoints: None, } } diff --git a/src/copilot.rs b/src/copilot.rs index 5a27b2f..e676240 100644 --- a/src/copilot.rs +++ b/src/copilot.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use std::io::{BufRead, BufReader}; use uuid::Uuid; -pub const DEFAULT_MODEL: &str = "gpt-5-mini"; +pub const DEFAULT_MODEL: &str = "gpt-4.1"; pub const DEFAULT_MAX_PROMPT_TOKENS: u64 = 128_000; // Rough estimate: 1 token ≈ 4 bytes for code @@ -35,6 +35,8 @@ struct ChatRequest { n: i32, stream: bool, intent: bool, + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_effort: Option, } #[derive(Debug, Deserialize)] @@ -52,6 +54,16 @@ struct StreamResponse { choices: Vec, } +// Structs for /responses API +#[derive(Debug, Serialize)] +struct ResponsesRequest { + model: String, + input: String, + stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + instructions: Option, +} + #[derive(Debug, Deserialize)] pub struct ModelsResponse { pub data: Vec, @@ -66,6 +78,10 @@ pub struct Model { pub model_picker_enabled: bool, #[serde(default)] pub capabilities: Option, + #[serde(default)] + pub policy: Option, + #[serde(default)] + pub supported_endpoints: Option>, } #[derive(Debug, Deserialize, Clone)] @@ -80,6 +96,12 @@ pub struct ModelLimits { pub max_prompt_tokens: Option, } +#[derive(Debug, Deserialize, Clone)] +pub struct ModelPolicy { + #[serde(default)] + pub state: Option, +} + #[derive(Debug)] pub enum CopilotError { Auth(auth::AuthError), @@ -108,6 +130,9 @@ impl From for CopilotError { impl From for CopilotError { fn from(e: ureq::Error) -> Self { let msg = match &e { + ureq::Error::StatusCode(400) => { + "Request failed (HTTP 400). Check if the model is enabled at https://github.com/settings/copilot/features".to_string() + } ureq::Error::StatusCode(401) => "Authentication failed. Try 'ghcc login'".to_string(), ureq::Error::StatusCode(403) => { "Access denied. Is your Copilot subscription active?".to_string() @@ -148,11 +173,16 @@ pub fn list_models(auth: &CopilotAuth) -> Result, CopilotError> { .body_mut() .read_json()?; - // Filter only enabled models + // Filter only enabled models (must have model_picker_enabled AND policy.state == "enabled") let enabled_models = response .data .into_iter() - .filter(|m| m.model_picker_enabled) + .filter(|m| { + m.model_picker_enabled + && m.policy + .as_ref() + .is_some_and(|p| p.state.as_deref() == Some("enabled")) + }) .collect(); Ok(enabled_models) @@ -214,15 +244,26 @@ fn lowercase_scope(message: &str) -> String { } /// Base prompt template shared across all commit styles -const BASE_PROMPT: &str = "Generate a conventional commit message for this diff. +const BASE_PROMPT: &str = + "You are a senior developer writing a commit message. Be precise and concise. -Rules: -- Format: type(scope): description -- Scope: noun for affected area. Omit if broad. -- Breaking changes: append ! to type. -- Keep subject under 72 characters. -- Describe intent, not implementation. -- Output ONLY the commit message."; +Conventional commit format: type(scope): subject + +Pick the right type: +- feat = new capability for users +- fix = bug was broken, now fixed +- refactor = code change, same behavior +- perf = optimization +- style = formatting +- ci = CI/CD pipelines +- chore = everything else (deps, config) +- docs/test/build = obvious + +Scope = main component affected (omit if unclear) +Subject = what you did, imperative, max 50 chars +Breaking change = add ! before colon + +Just the commit message, nothing else."; /// Build the prompt for generating a commit message pub(crate) fn build_prompt( @@ -240,17 +281,15 @@ pub(crate) fn build_prompt( let style_instruction = match style { CommitStyle::SingleLine => "Output ONLY a single-line subject.", CommitStyle::Detailed => { - "Output format: -- Subject line (under 72 chars) + "Include a short body: +- Subject line (under 72 chars). Don't make it more generic just because a body follows. - Blank line -- Bullet list of key changes, grouped by feature (- prefix)" +- Brief paragraph explaining what changed and why, if not obvious from the subject." } CommitStyle::Auto => { - "Include a body (blank line + bullet list) ONLY if: -- The change is too complex to capture in the subject alone -- Breaking changes need explanation - -Otherwise output ONLY the subject line." + "Include a body (blank line + brief paragraph) ONLY if the change is too complex for the subject alone or needs explanation. +Otherwise output ONLY the subject line. +If a body is included, keep the subject concrete and specific." } }; @@ -272,6 +311,30 @@ pub fn generate_commit_message( diff: &str, diff_stat: &str, style: CommitStyle, +) -> Result { + // Determine which endpoint to use based on model's supported_endpoints + let use_responses = auth + .supported_endpoints + .as_ref() + .map(|endpoints| { + // Use /responses if it's the only supported endpoint (chat/completions not available) + !endpoints.iter().any(|e| e == "/chat/completions") + && endpoints.iter().any(|e| e == "/responses") + }) + .unwrap_or(false); + + if use_responses { + generate_via_responses(auth, diff, diff_stat, style) + } else { + generate_via_chat_completions(auth, diff, diff_stat, style) + } +} + +fn generate_via_chat_completions( + auth: &CopilotAuth, + diff: &str, + diff_stat: &str, + style: CommitStyle, ) -> Result { let endpoint = auth .api_endpoint @@ -311,6 +374,14 @@ pub fn generate_commit_message( let prompt = build_prompt(diff, diff_stat, style, is_truncated); + // Default to "none" for GPT-5.x reasoning models (faster, still good quality) + // Don't send for non-GPT-5 models - they reject the parameter with HTTP 400 + let reasoning_effort = if model.starts_with("gpt-5") && !model.contains("codex") { + Some("none".to_string()) + } else { + None + }; + let request = ChatRequest { messages: vec![ ChatMessage { @@ -329,6 +400,7 @@ pub fn generate_commit_message( n: 1, stream: true, intent: true, + reasoning_effort, }; let session_id = Uuid::new_v4().to_string(); @@ -393,6 +465,117 @@ pub fn generate_commit_message( Ok(message) } +fn generate_via_responses( + auth: &CopilotAuth, + diff: &str, + diff_stat: &str, + style: CommitStyle, +) -> Result { + let endpoint = auth + .api_endpoint + .as_deref() + .ok_or_else(|| CopilotError::Api("No API endpoint found in auth".into()))?; + + let url = format!("{}/responses", endpoint); + let agent = auth::create_agent(); + + let model = auth.model.as_deref().unwrap_or(DEFAULT_MODEL); + + // Get max diff size based on model's context limit + let max_diff_bytes = get_max_diff_bytes(auth.max_prompt_tokens); + if max_diff_bytes == 0 { + return Err(CopilotError::Api( + "Model context limit too small for commit generation".into(), + )); + } + + // Truncate diff if too large (at a valid UTF-8 boundary) + let is_truncated = diff.len() > max_diff_bytes; + let diff = if is_truncated { + eprintln!( + "Warning: Diff is large ({} bytes), truncating to ~{} bytes", + diff.len(), + max_diff_bytes + ); + let mut end = max_diff_bytes; + while end > 0 && !diff.is_char_boundary(end) { + end -= 1; + } + &diff[..end] + } else { + diff + }; + + let prompt = build_prompt(diff, diff_stat, style, is_truncated); + + let request = ResponsesRequest { + model: model.to_string(), + input: prompt, + stream: true, + instructions: Some( + "You are an expert software engineer. Output only the raw commit message. No explanations, markdown, or extra text.".to_string() + ), + }; + + let session_id = Uuid::new_v4().to_string(); + let machine_id = auth + .machine_id + .clone() + .unwrap_or_else(|| Uuid::new_v4().to_string()); + + let response = agent + .post(&url) + .header("Authorization", format!("Bearer {}", auth.access)) + .header("Copilot-Integration-Id", "vscode-chat") + .header("Editor-Version", "vscode/1.95.0") + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream") + .header("x-request-id", &session_id) + .header("vscode-sessionid", &session_id) + .header("vscode-machineid", &machine_id) + .send_json(&request)?; + + let mut reader = BufReader::new(response.into_body().into_reader()); + let mut line = String::new(); + let mut full_message = String::new(); + let mut chunks_received = 0; + + // Parse streaming response from /responses endpoint + // Events look like: "event: response.output_text.delta\ndata: {\"delta\":\"text\",...}" + while reader.read_line(&mut line)? > 0 { + if let Some(data) = line.strip_prefix("data: ") { + let data = data.trim(); + + // Parse JSON and extract delta text from response.output_text.delta events + if let Ok(json) = serde_json::from_str::(data) { + // Check if this is a text delta event + if let Some(delta) = json.get("delta").and_then(|d| d.as_str()) { + chunks_received += 1; + full_message.push_str(delta); + } + } + } + line.clear(); + } + + if chunks_received == 0 { + return Err(CopilotError::Api( + "No valid response chunks from Copilot".into(), + )); + } + + let message = full_message.trim(); + + // Strip markdown code blocks if present (AI sometimes wraps output in ```) + let message = strip_markdown_code_block(message); + let message = lowercase_scope(&message); + + if message.is_empty() { + return Err(CopilotError::Api("Empty response from Copilot".into())); + } + Ok(message) +} + #[cfg(test)] mod tests { use super::*; @@ -426,8 +609,7 @@ mod tests { let prompt = build_prompt(diff, diff_stat, CommitStyle::SingleLine, false); assert!(prompt.contains("single-line")); - assert!(prompt.contains("72 characters")); - assert!(prompt.contains("type(scope): description")); + assert!(prompt.contains("type(scope):")); assert!(prompt.contains(diff)); assert!(prompt.contains(diff_stat)); } @@ -438,9 +620,9 @@ mod tests { let diff_stat = "file.rs | 10 ++++++++++"; let prompt = build_prompt(diff, diff_stat, CommitStyle::Detailed, false); - assert!(prompt.contains("Bullet list")); + assert!(prompt.contains("Brief paragraph")); assert!(prompt.contains("Blank line")); - assert!(prompt.contains("- prefix")); + assert!(prompt.contains("what changed and why")); assert!(prompt.contains(diff)); assert!(prompt.contains(diff_stat)); } @@ -451,8 +633,8 @@ mod tests { let diff_stat = "file.rs | 10 ++++++++++"; let prompt = build_prompt(diff, diff_stat, CommitStyle::Auto, false); - assert!(prompt.contains("too complex to capture")); - assert!(prompt.contains("Breaking changes")); + assert!(prompt.contains("too complex for the subject alone")); + assert!(prompt.contains("Breaking change")); assert!(prompt.contains("ONLY the subject line")); assert!(prompt.contains(diff)); assert!(prompt.contains(diff_stat)); @@ -469,10 +651,9 @@ mod tests { CommitStyle::Auto, ] { let prompt = build_prompt(diff, diff_stat, style, false); - assert!(prompt.contains("type(scope): description")); - assert!(prompt.contains("72 characters")); - assert!(prompt.contains("intent")); - assert!(prompt.contains("Output ONLY")); + assert!(prompt.contains("type(scope):")); + assert!(prompt.contains("Breaking change")); + assert!(prompt.contains("commit message")); } } diff --git a/src/git.rs b/src/git.rs index bab4393..4bfaa40 100644 --- a/src/git.rs +++ b/src/git.rs @@ -84,7 +84,7 @@ fn read_ignore_patterns() -> Vec { pub fn get_diff() -> Result { let ignore_patterns = read_ignore_patterns(); - let mut args = vec!["diff", "--cached", "--"]; + let mut args = vec!["diff", "--cached", "-U12", "--"]; // Add pathspecs: start with "." to include all, then exclude patterns args.push("."); @@ -140,6 +140,101 @@ pub fn has_staged_changes() -> Result { Ok(!diff.is_empty()) } +/// Read diff content from a file or stdin (if path is "-") +/// Only available in debug builds for testing purposes. +#[cfg(debug_assertions)] +pub fn read_diff_from_file(path: &str) -> Result { + use std::io::Read; + + if path == "-" { + let mut buf = String::new(); + std::io::stdin() + .read_to_string(&mut buf) + .map_err(|e| GitError::Git(format!("Failed to read stdin: {}", e)))?; + Ok(buf) + } else { + std::fs::read_to_string(path) + .map_err(|e| GitError::Git(format!("Failed to read {}: {}", path, e))) + } +} + +/// Derive a diff stat summary from diff content (similar to git diff --stat) +/// Only available in debug builds for testing purposes. +#[cfg(debug_assertions)] +pub fn derive_diff_stat(diff: &str) -> String { + let mut files: Vec<(String, usize, usize)> = Vec::new(); + let mut current_file: Option = None; + let mut current_ins = 0usize; + let mut current_del = 0usize; + + for line in diff.lines() { + if line.starts_with("diff --git ") { + // Save previous file stats + if let Some(file) = current_file.take() { + files.push((file, current_ins, current_del)); + } + // Extract filename from "diff --git a/path b/path" + if let Some(b_part) = line.split(" b/").nth(1) { + current_file = Some(b_part.to_string()); + current_ins = 0; + current_del = 0; + } + } else if current_file.is_some() { + if line.starts_with('+') && !line.starts_with("+++") { + current_ins += 1; + } else if line.starts_with('-') && !line.starts_with("---") { + current_del += 1; + } + } + } + + // Don't forget the last file + if let Some(file) = current_file { + files.push((file, current_ins, current_del)); + } + + if files.is_empty() { + return String::new(); + } + + let mut lines: Vec = Vec::new(); + let mut total_ins = 0usize; + let mut total_del = 0usize; + + // Find max filename length for alignment + let max_name_len = files.iter().map(|(f, _, _)| f.len()).max().unwrap_or(0); + + for (file, ins, del) in &files { + total_ins += ins; + total_del += del; + let total = ins + del; + // Truncate the +/- bar to reasonable length + let bar_len = std::cmp::min(total, 50); + let ins_bar = std::cmp::min(*ins, bar_len); + let del_bar = bar_len.saturating_sub(ins_bar).min(*del); + let bar = format!("{}{}", "+".repeat(ins_bar), "-".repeat(del_bar)); + lines.push(format!( + " {:width$} | {:>4} {}", + file, + total, + bar, + width = max_name_len + )); + } + + lines.push(format!( + " {} file{} changed, {} insertion{}(+), {} deletion{}(-)", + files.len(), + if files.len() == 1 { "" } else { "s" }, + total_ins, + if total_ins == 1 { "" } else { "s" }, + total_del, + if total_del == 1 { "" } else { "s" } + )); + + lines.join("\n") +} + /// Check if this is the initial commit (no commits yet; HEAD does not exist) pub fn is_initial_commit() -> Result { let output = Command::new("git") @@ -357,4 +452,63 @@ mod tests { // Two commits - should NOT be treated as initial commit assert!(!is_initial_commit_in_dir(dir.path()).unwrap()); } + + #[cfg(debug_assertions)] + #[test] + fn test_derive_diff_stat_single_file() { + let diff = r#"diff --git a/README.md b/README.md +index 1234567..abcdefg 100644 +--- a/README.md ++++ b/README.md +@@ -10,7 +10,7 @@ A CLI tool for generating commit messages. + ## Installation + +-cargo instal ghcc ++cargo install ghcc + + ## Usage"#; + + let stat = derive_diff_stat(diff); + assert!(stat.contains("README.md")); + assert!(stat.contains("1 file")); + assert!(stat.contains("1 insertion")); + assert!(stat.contains("1 deletion")); + } + + #[cfg(debug_assertions)] + #[test] + fn test_derive_diff_stat_multiple_files() { + let diff = r#"diff --git a/src/main.rs b/src/main.rs +index 1234567..abcdefg 100644 +--- a/src/main.rs ++++ b/src/main.rs +@@ -1,3 +1,5 @@ ++use std::io; ++use std::fs; + fn main() { + println!("Hello"); + } +diff --git a/src/lib.rs b/src/lib.rs +index 1234567..abcdefg 100644 +--- a/src/lib.rs ++++ b/src/lib.rs +@@ -1,5 +1,4 @@ +-fn old_function() { +- // removed +-} ++fn new_function() {} +"#; + + let stat = derive_diff_stat(diff); + assert!(stat.contains("src/main.rs")); + assert!(stat.contains("src/lib.rs")); + assert!(stat.contains("2 files changed")); + } + + #[cfg(debug_assertions)] + #[test] + fn test_derive_diff_stat_empty_diff() { + let stat = derive_diff_stat(""); + assert!(stat.is_empty()); + } } diff --git a/src/main.rs b/src/main.rs index 84d7410..6a25b1e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,8 +22,13 @@ fn print_usage() { eprintln!(" ghcc hook uninstall Remove git prepare-commit-msg hook"); eprintln!(); eprintln!("Options:"); - eprintln!(" -d, --detailed Include bullet-point body (default: single-line)"); + eprintln!(" -d, --detailed Include detailed body (default: single-line)"); eprintln!(" -a, --auto Let AI decide format (experimental)"); + #[cfg(debug_assertions)] + { + eprintln!(" --diff-file Read diff from file instead of git (use - for stdin)"); + eprintln!(" --model Use specific model for this request"); + } eprintln!(" -h, --help Show this help message"); eprintln!(" -V, --version Show version"); } @@ -104,6 +109,8 @@ fn cmd_models() -> Result<(), Box> { ); } eprintln!(); + eprintln!("Missing a model? Enable it at https://github.com/settings/copilot/features"); + eprintln!(); let current = auth.model.as_deref().unwrap_or(copilot::DEFAULT_MODEL); eprintln!("Current: {}", current); @@ -128,6 +135,7 @@ fn cmd_models() -> Result<(), Box> { .as_ref() .and_then(|c| c.limits.as_ref()) .and_then(|l| l.max_prompt_tokens); + auth.supported_endpoints = selected.supported_endpoints.clone(); auth::save_auth(&auth)?; eprintln!("Model set to: {}", selected.id); } else { @@ -278,6 +286,53 @@ fn cmd_hook(msg_file: &str, style: copilot::CommitStyle) -> Result<(), Box, + model_override: Option<&str>, +) -> Result<(), Box> { + let mut auth = auth::get_valid_auth()?; + + // Override model if specified (doesn't persist to disk) + if let Some(model) = model_override { + auth.model = Some(model.to_string()); + + // Infer supported_endpoints from model name + // Codex models only support /responses endpoint (not /chat/completions) + if model.contains("codex") { + auth.supported_endpoints = Some(vec!["/responses".to_string()]); + } + } + + let (diff, diff_stat) = if let Some(path) = diff_file { + let diff = git::read_diff_from_file(path)?; + let diff_stat = git::derive_diff_stat(&diff); + (diff, diff_stat) + } else { + if !git::has_staged_changes()? { + eprintln!("No staged changes found. Please stage your changes using 'git add'."); + return Ok(()); + } + + // Short-circuit for initial commit + if git::is_initial_commit()? { + println!("Initial commit"); + return Ok(()); + } + + (git::get_diff()?, git::get_diff_stat()?) + }; + + let message = copilot::generate_commit_message(&auth, &diff, &diff_stat, style)?; + + // Output to stdout for piping: `git commit -m "$(ghcc)"` + println!("{}", message); + + Ok(()) +} + +#[cfg(not(debug_assertions))] fn cmd_generate(style: copilot::CommitStyle) -> Result<(), Box> { let auth = auth::get_valid_auth()?; @@ -314,8 +369,51 @@ fn main() -> ExitCode { copilot::CommitStyle::SingleLine }; + // Debug-only: Check for --diff-file + #[cfg(debug_assertions)] + let diff_file: Option = if let Some(pos) = args.iter().position(|a| a == "--diff-file") + { + match args.get(pos + 1) { + Some(path) if !path.starts_with('-') || path == "-" => Some(path.clone()), + _ => { + eprintln!("Error: --diff-file requires a path argument (use - for stdin)"); + return ExitCode::from(1); + } + } + } else { + None + }; + + // Debug-only: Check for --model + #[cfg(debug_assertions)] + let model_override: Option = if let Some(pos) = args.iter().position(|a| a == "--model") + { + match args.get(pos + 1) { + Some(name) if !name.starts_with('-') => Some(name.clone()), + _ => { + eprintln!("Error: --model requires a model name"); + return ExitCode::from(1); + } + } + } else { + None + }; + // Check for --hook (internal, called by git hook) if let Some(pos) = args.iter().position(|a| a == "--hook") { + // Debug-only: --diff-file and --model are incompatible with --hook + #[cfg(debug_assertions)] + { + if diff_file.is_some() { + eprintln!("Error: --diff-file cannot be used with --hook"); + return ExitCode::from(1); + } + if model_override.is_some() { + eprintln!("Error: --model cannot be used with --hook"); + return ExitCode::from(1); + } + } + if let Some(msg_file) = args.get(pos + 1) { return match cmd_hook(msg_file, style) { Ok(()) => ExitCode::SUCCESS, @@ -331,6 +429,27 @@ fn main() -> ExitCode { } // Get command (first non-flag argument after program name) + #[cfg(debug_assertions)] + let cmd = args.iter().skip(1).find(|a| { + if a.starts_with('-') { + return false; + } + // Check if this arg is the value for --diff-file + if let Some(pos) = args.iter().position(|x| x == "--diff-file") { + if args.get(pos + 1) == Some(*a) { + return false; + } + } + // Check if this arg is the value for --model + if let Some(pos) = args.iter().position(|x| x == "--model") { + if args.get(pos + 1) == Some(*a) { + return false; + } + } + true + }); + + #[cfg(not(debug_assertions))] let cmd = args.iter().skip(1).find(|a| !a.starts_with('-')); let result = match cmd.map(|s| s.as_str()) { @@ -366,6 +485,9 @@ fn main() -> ExitCode { print_usage(); return ExitCode::from(1); } + #[cfg(debug_assertions)] + None => cmd_generate(style, diff_file.as_deref(), model_override.as_deref()), + #[cfg(not(debug_assertions))] None => cmd_generate(style), };