Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion lua/ai_commit_msg.lua
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ function M.format_cost(cost_info, format)
end
end

-- Small model patterns: auto-select SMALL_MODEL_SYSTEM_PROMPT when the user
-- hasn't explicitly overridden system_prompt and the model name matches.
local SMALL_MODEL_PATTERNS = { "haiku", "nano", "lite" }

local function is_small_model(model)
local model_lower = model:lower()
for _, pattern in ipairs(SMALL_MODEL_PATTERNS) do
if model_lower:find(pattern, 1, true) then
return true
end
end
return false
end

-- Get the active provider configuration
function M.get_active_provider_config()
local provider_name = M.config.provider
Expand All @@ -73,10 +87,17 @@ function M.get_active_provider_config()
error("No configuration found for provider: " .. tostring(provider_name))
end

-- Return a merged config with provider-specific settings
local active_config = vim.tbl_deep_extend("force", {}, provider_config)
active_config.provider = provider_name

-- Auto-select small model prompt if the user hasn't overridden system_prompt
local default_provider = config_mod.default.providers[provider_name]
local uses_default_prompt = default_provider
and active_config.system_prompt == default_provider.system_prompt
if uses_default_prompt and active_config.model and is_small_model(active_config.model) then
active_config.system_prompt = require("ai_commit_msg.prompts").SMALL_MODEL_SYSTEM_PROMPT
end

return active_config
end

Expand Down
8 changes: 7 additions & 1 deletion lua/ai_commit_msg/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ local DEFAULT_SYSTEM_PROMPT = require("ai_commit_msg.prompts").DEFAULT_SYSTEM_PR

---@class AiCommitMsgConfig
---@field enabled boolean Whether to enable the plugin
---@field provider string AI provider to use ("openai", "anthropic", or "gemini")
---@field provider string AI provider to use ("openai", "anthropic", "gemini", "copilot", or "claude_code")
---@field providers table<string, ProviderConfig> Provider-specific configurations
---@field auto_push_prompt boolean Whether to prompt for push after commit
---@field pull_before_push { enabled: boolean, args: string[] } Whether and how to run `git pull` before pushing
Expand Down Expand Up @@ -113,6 +113,12 @@ M.default = {
system_prompt = DEFAULT_SYSTEM_PROMPT,
pricing = {},
},
claude_code = {
model = "sonnet",
prompt = DEFAULT_PROMPT,
system_prompt = DEFAULT_SYSTEM_PROMPT,
pricing = {},
},
},
}

Expand Down
39 changes: 39 additions & 0 deletions lua/ai_commit_msg/prompts.lua
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,43 @@ Examples:
- refactor(router): inline trivial helper
]]

-- Optimized for small/fast models (Haiku, GPT-4.1-nano, Flash-Lite).
-- Shorter, flatter structure. Explicit body rules to avoid terse one-liners.
M.SMALL_MODEL_SYSTEM_PROMPT = [[
Generate ONE Conventional Commit message from the git diff below.
Output plain text only. No markdown, no quotes, no commentary.

FORMAT:
<type>(scope): description

- bullet explaining what and why
- bullet for behavior change or secondary impact

TYPES: feat fix docs style refactor perf test build ci chore revert

RULES:
1. Header <= 72 chars. Imperative mood ("add" not "added"). No period.
2. Scope: lowercase, parenthesized, pick the primary area changed.
3. Include 2-4 body bullets for any non-trivial change. Each bullet starts with "- ".
4. Bullets explain WHY, not just WHAT. Group related changes.
5. Skip body ONLY for single-line trivial changes (typo, formatting).
6. Only describe lines that actually changed (+/- in the diff).
7. Do NOT list files. Do NOT repeat the header in the body.
8. For breaking changes: add ! after scope and a BREAKING CHANGE: footer.
9. One commit only. No extras.

EXAMPLES:

feat(auth): add JWT refresh token rotation
- prevent token reuse after refresh to mitigate replay attacks
- store token family ID to detect stolen refresh tokens
- update middleware to validate token lineage

fix(ui): correct dropdown z-index in modal overlay
- modal content was rendered behind the navigation bar
- set explicit stacking context on the modal container

chore(deps): bump eslint from 8.56 to 9.0
]]

return M
80 changes: 80 additions & 0 deletions lua/ai_commit_msg/providers/claude_code.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
local M = {}

function M.call_api(config, diff, callback)
if vim.fn.executable("claude") ~= 1 then
callback(false, "claude CLI not found. Install Claude Code: https://docs.anthropic.com/en/docs/claude-code")
return
end

local prompt
if config.prompt:find("{diff}", 1, true) then
local before, after = config.prompt:match("^(.*)%{diff%}(.*)$")
if before and after then
prompt = before .. diff .. after
else
prompt = config.prompt .. "\n\n" .. diff
end
else
prompt = config.prompt .. "\n\n" .. diff
end

vim.schedule(function()
require("ai_commit_msg.config").notify(
"ai-commit-msg.nvim: Prompt length: " .. #prompt .. " chars",
vim.log.levels.DEBUG
)
end)

local cmd = {
"claude",
"-p",
"--output-format",
"json",
"--max-turns",
"1",
"--system-prompt",
config.system_prompt,
}

if config.model then
table.insert(cmd, "--model")
table.insert(cmd, config.model)
end

vim.system(cmd, { stdin = prompt }, function(res)
if res.code ~= 0 then
callback(false, "claude CLI failed: " .. (res.stderr or "Unknown error"))
return
end

local ok, response = pcall(vim.json.decode, res.stdout)
if not ok then
callback(false, "Failed to parse claude CLI response: " .. tostring(response))
return
end

if response.is_error then
callback(false, "Claude Code error: " .. (response.result or "Unknown error"))
return
end

vim.schedule(function()
require("ai_commit_msg.config").notify(
"ai-commit-msg.nvim: Full CLI response: " .. vim.inspect(response),
vim.log.levels.DEBUG
)
end)

if response.result then
local commit_msg = response.result
commit_msg = commit_msg:gsub("^```%w*\n", ""):gsub("\n```$", ""):gsub("^`", ""):gsub("`$", "")
commit_msg = vim.trim(commit_msg)

callback(true, commit_msg, nil)
else
callback(false, "Unexpected claude CLI response format")
end
end)
end

return M
23 changes: 12 additions & 11 deletions lua/ai_commit_msg/providers/init.lua
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
local M = {}

local PROVIDERS = {
openai = "ai_commit_msg.providers.openai",
anthropic = "ai_commit_msg.providers.anthropic",
gemini = "ai_commit_msg.providers.gemini",
copilot = "ai_commit_msg.providers.copilot",
claude_code = "ai_commit_msg.providers.claude_code",
}

function M.get_provider(config)
local provider_name = config.provider
if provider_name == "openai" then
return require("ai_commit_msg.providers.openai")
elseif provider_name == "anthropic" then
return require("ai_commit_msg.providers.anthropic")
elseif provider_name == "gemini" then
return require("ai_commit_msg.providers.gemini")
elseif provider_name == "copilot" then
return require("ai_commit_msg.providers.copilot")
else
error("Unsupported provider: " .. tostring(provider_name))
local module = PROVIDERS[config.provider]
if not module then
error("Unsupported provider: " .. tostring(config.provider))
end
return require(module)
end

function M.call_api(config, diff, callback)
Expand Down
157 changes: 157 additions & 0 deletions spec/claude_code_provider_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
-- Minimal vim mock for running outside nvim
if not rawget(_G, "vim") then
_G.vim = {
fn = { executable = function() return 1 end },
system = function() end,
schedule = function(fn) fn() end,
json = {
encode = function(t) return require("dkjson").encode(t) end,
decode = function(s) return require("dkjson").decode(s) end,
},
trim = function(s) return s:match("^%s*(.-)%s*$") end,
inspect = function(t) return tostring(t) end,
log = { levels = { DEBUG = 1, INFO = 2, WARN = 3, ERROR = 4 } },
}
end

describe("claude_code provider", function()
local claude_code_provider
local original_executable = vim.fn.executable
local original_vim_system = vim.system

local default_config = {
model = "sonnet",
prompt = "{diff}",
system_prompt = "Generate a commit message",
}

local function mock_cli_response(response)
vim.system = function(cmd, opts, callback)
callback({
code = 0,
stdout = type(response) == "string" and response or vim.json.encode(response),
})
return cmd
end
end

before_each(function()
package.loaded["ai_commit_msg.providers.claude_code"] = nil
package.loaded["ai_commit_msg.config"] = nil
package.preload["ai_commit_msg.config"] = function()
return { notify = function() end }
end
claude_code_provider = require("ai_commit_msg.providers.claude_code")
vim.fn.executable = function() return 1 end
end)

after_each(function()
vim.fn.executable = original_executable
vim.system = original_vim_system
end)

describe("call_api", function()
it("calls callback with error when claude CLI not found", function()
vim.fn.executable = function() return 0 end

local result_success, result_message
claude_code_provider.call_api({}, "test diff", function(success, message)
result_success = success
result_message = message
end)

assert.is_false(result_success)
assert.truthy(result_message:find("claude CLI not found"))
end)

it("passes model flag from config", function()
local captured_cmd
vim.system = function(cmd, opts, callback)
captured_cmd = cmd
callback({
code = 0,
stdout = '{"type":"result","is_error":false,"result":"feat: test commit","usage":{"input_tokens":10,"output_tokens":5}}',
})
end

claude_code_provider.call_api(default_config, "test diff", function() end)

local found_model = false
for i, arg in ipairs(captured_cmd) do
if arg == "--model" and captured_cmd[i + 1] == "sonnet" then
found_model = true
break
end
end
assert.is_true(found_model)
end)

it("pipes prompt via stdin and system prompt via flag", function()
local captured_cmd, captured_opts
vim.system = function(cmd, opts, callback)
captured_cmd = cmd
captured_opts = opts
callback({
code = 0,
stdout = '{"type":"result","is_error":false,"result":"feat: test commit","usage":{"input_tokens":10,"output_tokens":5}}',
})
end

local config = { model = "haiku", prompt = "{diff}", system_prompt = "Generate a commit message" }
claude_code_provider.call_api(config, "my test diff", function() end)

assert.truthy(captured_opts.stdin)
assert.truthy(captured_opts.stdin:find("my test diff"))

local found_system_prompt = false
for i, arg in ipairs(captured_cmd) do
if arg == "--system-prompt" and captured_cmd[i + 1] == "Generate a commit message" then
found_system_prompt = true
break
end
end
assert.is_true(found_system_prompt)
end)

it("parses successful response and extracts usage", function()
mock_cli_response({
type = "result",
is_error = false,
result = "feat: add new feature",
usage = {
input_tokens = 100,
cache_creation_input_tokens = 50,
cache_read_input_tokens = 200,
output_tokens = 20,
},
})

local result_success, result_message, result_usage
local config = { model = "haiku", prompt = "{diff}", system_prompt = "test" }
claude_code_provider.call_api(config, "test diff", function(success, message, usage)
result_success = success
result_message = message
result_usage = usage
end)

assert.is_true(result_success)
assert.equals("feat: add new feature", result_message)
assert.equals(350, result_usage.input_tokens) -- 100 + 50 + 200
assert.equals(20, result_usage.output_tokens)
end)

it("handles error response from claude CLI", function()
mock_cli_response('{"type":"result","is_error":true,"result":"Something went wrong"}')

local result_success, result_message
local config = { model = "haiku", prompt = "{diff}", system_prompt = "test" }
claude_code_provider.call_api(config, "test diff", function(success, message)
result_success = success
result_message = message
end)

assert.is_false(result_success)
assert.truthy(result_message:find("Something went wrong"))
end)
end)
end)