diff --git a/README.md b/README.md index c272b3de..9b79039e 100644 --- a/README.md +++ b/README.md @@ -259,6 +259,8 @@ require('opencode').setup({ hooks = { on_file_edited = nil, -- Called after a file is edited by opencode. on_session_loaded = nil, -- Called after a session is loaded. + on_done_thinking = nil, -- Called when opencode finishes thinking (all jobs complete). + on_permission_requested = nil, -- Called when a permission request is issued. }, }) ``` @@ -594,12 +596,14 @@ The plugin defines several highlight groups that can be customized to match your The `prompt_guard` configuration option allows you to control when prompts can be sent to Opencode. This is useful for preventing accidental or unauthorized AI interactions in certain contexts. -## 🪝Custom user hooks +## 🪝 Custom user hooks You can define custom functions to be called at specific events in Opencode: - `on_file_edited`: Called after a file is edited by Opencode. - `on_session_loaded`: Called after a session is loaded. +- `on_done_thinking`: Called when Opencode finishes thinking (all user jobs complete). +- `on_permission_requested`: Called when a permission request is issued. ```lua require('opencode').setup({ @@ -612,6 +616,14 @@ require('opencode').setup({ -- Custom logic after a session is loaded print("Session loaded: " .. session_name) end, + on_done_thinking = function() + -- Custom logic when thinking is done + print("Done thinking!") + end, + on_permission_requested = function() + -- Custom logic when a permission is requested + print("Permission requested!") + end, }, }) ``` diff --git a/lua/opencode/config.lua b/lua/opencode/config.lua index 9d633a75..e1724f9d 100644 --- a/lua/opencode/config.lua +++ b/lua/opencode/config.lua @@ -191,6 +191,8 @@ M.defaults = { hooks = { on_file_edited = nil, on_session_loaded = nil, + on_done_thinking = nil, + on_permission_requested = nil, }, } diff --git a/lua/opencode/core.lua b/lua/opencode/core.lua index b9100867..34ea8c25 100644 --- a/lua/opencode/core.lua +++ b/lua/opencode/core.lua @@ -147,8 +147,14 @@ function M.send_message(prompt, opts) params.parts = context.format_message(prompt, opts.context) M.before_run(opts) + -- Capture the session ID to ensure we track the message count for the correct session + local session_id = state.active_session.id + local sent_message_count = vim.deepcopy(state.user_message_count) + sent_message_count[session_id] = (sent_message_count[session_id] or 0) + 1 + state.user_message_count = sent_message_count + state.api_client - :create_message(state.active_session.id, params) + :create_message(session_id, params) :and_then(function(response) if not response or not response.info or not response.parts then -- fall back to full render. incremental render is handled @@ -156,6 +162,10 @@ function M.send_message(prompt, opts) ui.render_output() end + local received_message_count = vim.deepcopy(state.user_message_count) + received_message_count[response.info.sessionID] = (received_message_count[response.info.sessionID] ~= nil) and (received_message_count[response.info.sessionID] - 1) or 0 + state.user_message_count = received_message_count + M.after_run(prompt) end) :catch(function(err) @@ -367,6 +377,29 @@ function M.initialize_current_model() return state.current_model end +function M._on_user_message_count_change(_, new, old) + if config.hooks and config.hooks.on_done_thinking then + local all_sessions = session.get_all_workspace_sessions() or {} + local done_sessions = vim.tbl_filter(function(s) + local msg_count = new[s.id] or 0 + local old_msg_count = (old and old[s.id]) or 0 + return msg_count == 0 and old_msg_count > 0 + end, all_sessions) + + for _, done_session in ipairs(done_sessions) do + pcall(config.hooks.on_done_thinking, done_session) + end + end +end + +function M._on_current_permission_change(_, new, old) + local permission_requested = old == nil and new ~= nil + if config.hooks and config.hooks.on_permission_requested and permission_requested then + local local_session = session.get_by_id(state.active_session.id) or {} + pcall(config.hooks.on_permission_requested, local_session) + end +end + --- Handle clipboard image data by saving it to a file and adding it to context --- @return boolean success True if image was successfully handled function M.paste_image_from_clipboard() @@ -375,6 +408,8 @@ end function M.setup() state.subscribe('opencode_server', on_opencode_server) + state.subscribe('user_message_count', M._on_user_message_count_change) + state.subscribe('current_permission', M._on_current_permission_change) vim.schedule(function() M.opencode_ok() diff --git a/lua/opencode/server_job.lua b/lua/opencode/server_job.lua index 35a0a086..fa922c22 100644 --- a/lua/opencode/server_job.lua +++ b/lua/opencode/server_job.lua @@ -26,6 +26,7 @@ end --- @return Promise promise A promise that resolves with the result or rejects with an error function M.call_api(url, method, body) local call_promise = Promise.new() + state.job_count = state.job_count + 1 local request_entry = { nil, call_promise } diff --git a/lua/opencode/state.lua b/lua/opencode/state.lua index 2263f5cc..8d571cb8 100644 --- a/lua/opencode/state.lua +++ b/lua/opencode/state.lua @@ -32,6 +32,7 @@ ---@field cost number ---@field tokens_count number ---@field job_count number +---@field user_message_count table ---@field opencode_server OpencodeServer|nil ---@field api_client OpencodeApiClient ---@field event_manager EventManager|nil @@ -80,6 +81,7 @@ local _state = { tokens_count = 0, -- job job_count = 0, + user_message_count = {}, opencode_server = nil, api_client = nil, event_manager = nil, diff --git a/lua/opencode/types.lua b/lua/opencode/types.lua index 7b76e9df..1c254ab7 100644 --- a/lua/opencode/types.lua +++ b/lua/opencode/types.lua @@ -150,6 +150,8 @@ ---@class OpencodeHooks ---@field on_file_edited? fun(file: string): nil ---@field on_session_loaded? fun(session: Session): nil +---@field on_done_thinking? fun(session: Session): nil +---@field on_permission_requested? fun(session: Session): nil ---@class OpencodeProviders ---@field [string] string[] diff --git a/tests/unit/hooks_spec.lua b/tests/unit/hooks_spec.lua index 3b41d7b3..4c8b1eb6 100644 --- a/tests/unit/hooks_spec.lua +++ b/tests/unit/hooks_spec.lua @@ -1,6 +1,7 @@ local renderer = require('opencode.ui.renderer') local config = require('opencode.config') local state = require('opencode.state') +local core = require('opencode.core') local helpers = require('tests.helpers') local ui = require('opencode.ui.ui') @@ -10,6 +11,8 @@ describe('hooks', function() config.hooks = { on_file_edited = nil, on_session_loaded = nil, + on_done_thinking = nil, + on_permission_requested = nil, } end) @@ -20,6 +23,8 @@ describe('hooks', function() config.hooks = { on_file_edited = nil, on_session_loaded = nil, + on_done_thinking = nil, + on_permission_requested = nil, } end) @@ -107,4 +112,117 @@ describe('hooks', function() end) end) end) + + describe('on_done_thinking', function() + it('should call hook when thinking is done', function() + local called = false + local called_session = nil + + config.hooks.on_done_thinking = function(session) + called = true + called_session = session + end + + -- Mock session.get_all_workspace_sessions to return our test session + local session_module = require('opencode.session') + local original_get_all = session_module.get_all_workspace_sessions + session_module.get_all_workspace_sessions = function() + return { { id = 'test-session', title = 'Test' } } + end + + state.subscribe('user_message_count', core._on_user_message_count_change) + + -- Simulate job count change from 1 to 0 (done thinking) for a specific session + state.active_session = { id = 'test-session', title = 'Test' } + state.user_message_count = { ['test-session'] = 1 } + state.user_message_count = { ['test-session'] = 0 } + + -- Wait for async notification + vim.wait(100, function() return called end) + + -- Restore original function + session_module.get_all_workspace_sessions = original_get_all + state.unsubscribe('user_message_count', core._on_user_message_count_change) + + assert.is_true(called) + assert.are.equal(called_session.id, 'test-session') + end) + + it('should not error when hook is nil', function() + config.hooks.on_done_thinking = nil + state.active_session = { id = 'test-session', title = 'Test' } + state.user_message_count = { ['test-session'] = 1 } + assert.has_no.errors(function() + state.user_message_count = { ['test-session'] = 0 } + end) + end) + + it('should not crash when hook throws error', function() + config.hooks.on_done_thinking = function() + error('test error') + end + + state.active_session = { id = 'test-session', title = 'Test' } + state.user_message_count = { ['test-session'] = 1 } + assert.has_no.errors(function() + state.user_message_count = { ['test-session'] = 0 } + end) + end) + end) + + describe('on_permission_requested', function() + it('should call hook when permission is requested', function() + local called = false + local called_session = nil + + config.hooks.on_permission_requested = function(session) + called = true + called_session = session + end + + -- Mock session.get_by_id to return our test session + local session_module = require('opencode.session') + local original_get_by_id = session_module.get_by_id + session_module.get_by_id = function(id) + return { id = id, title = 'Test' } + end + + -- Set up the subscription manually + state.subscribe('current_permission', core._on_current_permission_change) + + -- Simulate permission change from nil to a value + state.active_session = { id = 'test-session', title = 'Test' } + state.current_permission = nil + state.current_permission = { tool = 'test_tool', action = 'read' } + + -- Wait for async notification + vim.wait(100, function() return called end) + + -- Restore original function + session_module.get_by_id = original_get_by_id + state.unsubscribe('current_permission', core._on_current_permission_change) + + assert.is_true(called) + assert.are.equal(called_session.id, 'test-session') + end) + + it('should not error when hook is nil', function() + config.hooks.on_permission_requested = nil + state.current_permission = nil + assert.has_no.errors(function() + state.current_permission = { tool = 'test_tool', action = 'read' } + end) + end) + + it('should not crash when hook throws error', function() + config.hooks.on_permission_requested = function() + error('test error') + end + + state.current_permission = nil + assert.has_no.errors(function() + state.current_permission = { tool = 'test_tool', action = 'read' } + end) + end) + end) end)