diff --git a/lua/ai_commit_msg.lua b/lua/ai_commit_msg.lua
index 893623f..8921817 100644
--- a/lua/ai_commit_msg.lua
+++ b/lua/ai_commit_msg.lua
@@ -64,6 +64,20 @@ function M.format_cost(cost_info, format)
   end
 end
 
+-- Small model patterns: auto-select SMALL_MODEL_SYSTEM_PROMPT when the user
+-- hasn't explicitly overridden system_prompt and the model name matches.
+local SMALL_MODEL_PATTERNS = { "haiku", "nano", "lite" }
+
+local function is_small_model(model)
+  local model_lower = model:lower()
+  for _, pattern in ipairs(SMALL_MODEL_PATTERNS) do
+    if model_lower:find(pattern, 1, true) then
+      return true
+    end
+  end
+  return false
+end
+
 -- Get the active provider configuration
 function M.get_active_provider_config()
   local provider_name = M.config.provider
@@ -73,10 +87,17 @@ function M.get_active_provider_config()
     error("No configuration found for provider: " .. tostring(provider_name))
   end
 
-  -- Return a merged config with provider-specific settings
   local active_config = vim.tbl_deep_extend("force", {}, provider_config)
   active_config.provider = provider_name
 
+  -- Auto-select small model prompt if the user hasn't overridden system_prompt
+  local default_provider = config_mod.default.providers[provider_name]
+  local uses_default_prompt = default_provider
+    and active_config.system_prompt == default_provider.system_prompt
+  if uses_default_prompt and active_config.model and is_small_model(active_config.model) then
+    active_config.system_prompt = require("ai_commit_msg.prompts").SMALL_MODEL_SYSTEM_PROMPT
+  end
+
   return active_config
 end
 
diff --git a/lua/ai_commit_msg/config.lua b/lua/ai_commit_msg/config.lua
index 7cb59f7..dadfca9 100644
--- a/lua/ai_commit_msg/config.lua
+++ b/lua/ai_commit_msg/config.lua
@@ -21,7 +21,7 @@ local DEFAULT_SYSTEM_PROMPT = require("ai_commit_msg.prompts").DEFAULT_SYSTEM_PR
 
 ---@class AiCommitMsgConfig
 ---@field enabled boolean Whether to enable the plugin
----@field provider string AI provider to use ("openai", "anthropic", or "gemini") 
+---@field provider string AI provider to use ("openai", "anthropic", "gemini", "copilot", or "claude_code")
 ---@field providers table Provider-specific configurations
 ---@field auto_push_prompt boolean Whether to prompt for push after commit
 ---@field pull_before_push { enabled: boolean, args: string[] } Whether and how to run `git pull` before pushing
@@ -113,6 +113,12 @@ M.default = {
       system_prompt = DEFAULT_SYSTEM_PROMPT,
       pricing = {},
     },
+    claude_code = {
+      model = "sonnet",
+      prompt = DEFAULT_PROMPT,
+      system_prompt = DEFAULT_SYSTEM_PROMPT,
+      pricing = {},
+    },
   },
 }
 
diff --git a/lua/ai_commit_msg/prompts.lua b/lua/ai_commit_msg/prompts.lua
index 93db4d7..34a33c3 100644
--- a/lua/ai_commit_msg/prompts.lua
+++ b/lua/ai_commit_msg/prompts.lua
@@ -211,4 +211,43 @@ Examples:
 - refactor(router): inline trivial helper
 ]]
 
+-- Optimized for small/fast models (Haiku, GPT-4.1-nano, Flash-Lite).
+-- Shorter, flatter structure. Explicit body rules to avoid terse one-liners.
+M.SMALL_MODEL_SYSTEM_PROMPT = [[
+Generate ONE Conventional Commit message from the git diff below.
+Output plain text only. No markdown, no quotes, no commentary.
+
+FORMAT:
+type(scope): description
+
+- bullet explaining what and why
+- bullet for behavior change or secondary impact
+
+TYPES: feat fix docs style refactor perf test build ci chore revert
+
+RULES:
+1. Header <= 72 chars. Imperative mood ("add" not "added"). No period.
+2. Scope: lowercase, parenthesized, pick the primary area changed.
+3. Include 2-4 body bullets for any non-trivial change. Each bullet starts with "- ".
+4. Bullets explain WHY, not just WHAT. Group related changes.
+5. Skip body ONLY for single-line trivial changes (typo, formatting).
+6. Only describe lines that actually changed (+/- in the diff).
+7. Do NOT list files. Do NOT repeat the header in the body.
+8. For breaking changes: add ! after scope and a BREAKING CHANGE: footer.
+9. One commit only. No extras.
+
+EXAMPLES:
+
+feat(auth): add JWT refresh token rotation
+- prevent token reuse after refresh to mitigate replay attacks
+- store token family ID to detect stolen refresh tokens
+- update middleware to validate token lineage
+
+fix(ui): correct dropdown z-index in modal overlay
+- modal content was rendered behind the navigation bar
+- set explicit stacking context on the modal container
+
+chore(deps): bump eslint from 8.56 to 9.0
+]]
+
 return M
diff --git a/lua/ai_commit_msg/providers/claude_code.lua b/lua/ai_commit_msg/providers/claude_code.lua
new file mode 100644
index 0000000..7d94089
--- /dev/null
+++ b/lua/ai_commit_msg/providers/claude_code.lua
@@ -0,0 +1,93 @@
+local M = {}
+
+function M.call_api(config, diff, callback)
+  if vim.fn.executable("claude") ~= 1 then
+    callback(false, "claude CLI not found. Install Claude Code: https://docs.anthropic.com/en/docs/claude-code")
+    return
+  end
+
+  local prompt
+  if config.prompt:find("{diff}", 1, true) then
+    local before, after = config.prompt:match("^(.*)%{diff%}(.*)$")
+    if before and after then
+      prompt = before .. diff .. after
+    else
+      prompt = config.prompt .. "\n\n" .. diff
+    end
+  else
+    prompt = config.prompt .. "\n\n" .. diff
+  end
+
+  vim.schedule(function()
+    require("ai_commit_msg.config").notify(
+      "ai-commit-msg.nvim: Prompt length: " .. #prompt .. " chars",
+      vim.log.levels.DEBUG
+    )
+  end)
+
+  local cmd = {
+    "claude",
+    "-p",
+    "--output-format",
+    "json",
+    "--max-turns",
+    "1",
+    "--system-prompt",
+    config.system_prompt,
+  }
+
+  if config.model then
+    table.insert(cmd, "--model")
+    table.insert(cmd, config.model)
+  end
+
+  vim.system(cmd, { stdin = prompt }, function(res)
+    if res.code ~= 0 then
+      callback(false, "claude CLI failed: " .. (res.stderr or "Unknown error"))
+      return
+    end
+
+    local ok, response = pcall(vim.json.decode, res.stdout)
+    if not ok then
+      callback(false, "Failed to parse claude CLI response: " .. tostring(response))
+      return
+    end
+
+    if response.is_error then
+      callback(false, "Claude Code error: " .. (response.result or "Unknown error"))
+      return
+    end
+
+    vim.schedule(function()
+      require("ai_commit_msg.config").notify(
+        "ai-commit-msg.nvim: Full CLI response: " .. vim.inspect(response),
+        vim.log.levels.DEBUG
+      )
+    end)
+
+    if response.result then
+      local commit_msg = response.result
+      commit_msg = commit_msg:gsub("^```%w*\n", ""):gsub("\n```$", ""):gsub("^`", ""):gsub("`$", "")
+      commit_msg = vim.trim(commit_msg)
+
+      -- Aggregate token usage for cost reporting; cache creation/read
+      -- tokens count toward input in Anthropic's usage accounting.
+      local usage
+      local u = response.usage
+      if u then
+        usage = {
+          input_tokens = (u.input_tokens or 0)
+            + (u.cache_creation_input_tokens or 0)
+            + (u.cache_read_input_tokens or 0),
+          output_tokens = u.output_tokens or 0,
+        }
+      end
+
+      callback(true, commit_msg, usage)
+    else
+      callback(false, "Unexpected claude CLI response format")
+    end
+  end)
+end
+
+return M
diff --git a/lua/ai_commit_msg/providers/init.lua b/lua/ai_commit_msg/providers/init.lua
index 25611f1..24c3f5e 100644
--- a/lua/ai_commit_msg/providers/init.lua
+++ b/lua/ai_commit_msg/providers/init.lua
@@ -1,18 +1,20 @@
 local M = {}
 
+local PROVIDERS = {
+  openai = "ai_commit_msg.providers.openai",
+  anthropic = "ai_commit_msg.providers.anthropic",
+  gemini = "ai_commit_msg.providers.gemini",
+  copilot = "ai_commit_msg.providers.copilot",
+  claude_code = "ai_commit_msg.providers.claude_code",
+}
+
 function M.get_provider(config)
-  local provider_name = config.provider
-  if provider_name == "openai" then
-    return require("ai_commit_msg.providers.openai")
-  elseif provider_name == "anthropic" then
-    return require("ai_commit_msg.providers.anthropic")
-  elseif provider_name == "gemini" then
-    return require("ai_commit_msg.providers.gemini")
-  elseif provider_name == "copilot" then
-    return require("ai_commit_msg.providers.copilot")
-  else
-    error("Unsupported provider: " .. tostring(provider_name))
+  local module = PROVIDERS[config.provider]
+  if not module then
+    error("Unsupported provider: " ..
tostring(config.provider)) end + return require(module) end function M.call_api(config, diff, callback) diff --git a/spec/claude_code_provider_spec.lua b/spec/claude_code_provider_spec.lua new file mode 100644 index 0000000..90ac6eb --- /dev/null +++ b/spec/claude_code_provider_spec.lua @@ -0,0 +1,157 @@ +-- Minimal vim mock for running outside nvim +if not rawget(_G, "vim") then + _G.vim = { + fn = { executable = function() return 1 end }, + system = function() end, + schedule = function(fn) fn() end, + json = { + encode = function(t) return require("dkjson").encode(t) end, + decode = function(s) return require("dkjson").decode(s) end, + }, + trim = function(s) return s:match("^%s*(.-)%s*$") end, + inspect = function(t) return tostring(t) end, + log = { levels = { DEBUG = 1, INFO = 2, WARN = 3, ERROR = 4 } }, + } +end + +describe("claude_code provider", function() + local claude_code_provider + local original_executable = vim.fn.executable + local original_vim_system = vim.system + + local default_config = { + model = "sonnet", + prompt = "{diff}", + system_prompt = "Generate a commit message", + } + + local function mock_cli_response(response) + vim.system = function(cmd, opts, callback) + callback({ + code = 0, + stdout = type(response) == "string" and response or vim.json.encode(response), + }) + return cmd + end + end + + before_each(function() + package.loaded["ai_commit_msg.providers.claude_code"] = nil + package.loaded["ai_commit_msg.config"] = nil + package.preload["ai_commit_msg.config"] = function() + return { notify = function() end } + end + claude_code_provider = require("ai_commit_msg.providers.claude_code") + vim.fn.executable = function() return 1 end + end) + + after_each(function() + vim.fn.executable = original_executable + vim.system = original_vim_system + end) + + describe("call_api", function() + it("calls callback with error when claude CLI not found", function() + vim.fn.executable = function() return 0 end + + local result_success, 
result_message + claude_code_provider.call_api({}, "test diff", function(success, message) + result_success = success + result_message = message + end) + + assert.is_false(result_success) + assert.truthy(result_message:find("claude CLI not found")) + end) + + it("passes model flag from config", function() + local captured_cmd + vim.system = function(cmd, opts, callback) + captured_cmd = cmd + callback({ + code = 0, + stdout = '{"type":"result","is_error":false,"result":"feat: test commit","usage":{"input_tokens":10,"output_tokens":5}}', + }) + end + + claude_code_provider.call_api(default_config, "test diff", function() end) + + local found_model = false + for i, arg in ipairs(captured_cmd) do + if arg == "--model" and captured_cmd[i + 1] == "sonnet" then + found_model = true + break + end + end + assert.is_true(found_model) + end) + + it("pipes prompt via stdin and system prompt via flag", function() + local captured_cmd, captured_opts + vim.system = function(cmd, opts, callback) + captured_cmd = cmd + captured_opts = opts + callback({ + code = 0, + stdout = '{"type":"result","is_error":false,"result":"feat: test commit","usage":{"input_tokens":10,"output_tokens":5}}', + }) + end + + local config = { model = "haiku", prompt = "{diff}", system_prompt = "Generate a commit message" } + claude_code_provider.call_api(config, "my test diff", function() end) + + assert.truthy(captured_opts.stdin) + assert.truthy(captured_opts.stdin:find("my test diff")) + + local found_system_prompt = false + for i, arg in ipairs(captured_cmd) do + if arg == "--system-prompt" and captured_cmd[i + 1] == "Generate a commit message" then + found_system_prompt = true + break + end + end + assert.is_true(found_system_prompt) + end) + + it("parses successful response and extracts usage", function() + mock_cli_response({ + type = "result", + is_error = false, + result = "feat: add new feature", + usage = { + input_tokens = 100, + cache_creation_input_tokens = 50, + cache_read_input_tokens = 
200, + output_tokens = 20, + }, + }) + + local result_success, result_message, result_usage + local config = { model = "haiku", prompt = "{diff}", system_prompt = "test" } + claude_code_provider.call_api(config, "test diff", function(success, message, usage) + result_success = success + result_message = message + result_usage = usage + end) + + assert.is_true(result_success) + assert.equals("feat: add new feature", result_message) + assert.equals(350, result_usage.input_tokens) -- 100 + 50 + 200 + assert.equals(20, result_usage.output_tokens) + end) + + it("handles error response from claude CLI", function() + mock_cli_response('{"type":"result","is_error":true,"result":"Something went wrong"}') + + local result_success, result_message + local config = { model = "haiku", prompt = "{diff}", system_prompt = "test" } + claude_code_provider.call_api(config, "test diff", function(success, message) + result_success = success + result_message = message + end) + + assert.is_false(result_success) + assert.truthy(result_message:find("Something went wrong")) + end) + end) +end)