diff --git a/lua/opencode/config.lua b/lua/opencode/config.lua index ab0052be..89b6fd9d 100644 --- a/lua/opencode/config.lua +++ b/lua/opencode/config.lua @@ -182,6 +182,10 @@ M.defaults = { show_ids = false, }, prompt_guard = nil, + hooks = { + on_file_edited = nil, + on_session_loaded = nil, + }, } M.values = vim.deepcopy(M.defaults) diff --git a/lua/opencode/config_file.lua b/lua/opencode/config_file.lua index 6cd0ccf7..5b373ab7 100644 --- a/lua/opencode/config_file.lua +++ b/lua/opencode/config_file.lua @@ -58,10 +58,12 @@ end function M.get_model_info(provider, model) local config_file = require('opencode.config_file') - local providers = config_file.get_opencode_providers() or {} + local providers_response = config_file.get_opencode_providers() + local providers = providers_response and providers_response.providers or {} + local filtered_providers = vim.tbl_filter(function(p) return p.id == provider - end, providers.providers) + end, providers) if #filtered_providers == 0 then return nil diff --git a/lua/opencode/types.lua b/lua/opencode/types.lua index 7fa451e9..31ad0055 100644 --- a/lua/opencode/types.lua +++ b/lua/opencode/types.lua @@ -135,8 +135,12 @@ ---@field capture_streamed_events boolean ---@field show_ids boolean ---- @class OpencodeProviders ---- @field [string] string[] +---@class OpencodeHooks +---@field on_file_edited? fun(file: string): nil +---@field on_session_loaded? fun(session: Session): nil + +---@class OpencodeProviders +---@field [string] string[] ---@class OpencodeConfigModule ---@field defaults OpencodeConfig @@ -156,6 +160,7 @@ ---@field context OpencodeContextConfig ---@field debug OpencodeDebugConfig ---@field prompt_guard? fun(mentioned_files: string[]): boolean +---@field hooks OpencodeHooks ---@field legacy_commands boolean ---@class MessagePartState diff --git a/lua/opencode/ui/renderer.lua b/lua/opencode/ui/renderer.lua index 5c077217..917134ac 100644 --- a/lua/opencode/ui/renderer.lua +++ b/lua/opencode/ui/renderer.lua @@ -166,6 +166,10 @@ function M._render_full_session_data(session_data) M._set_model_from_messages() end M.scroll_to_bottom() + + if config.hooks and config.hooks.on_session_loaded then + pcall(config.hooks.on_session_loaded, state.active_session) + end end ---Render lines as the entire output buffer @@ -805,8 +809,11 @@ function M.on_permission_replied(properties) end end -function M.on_file_edited(_) +function M.on_file_edited(properties) vim.cmd('checktime') + if config.hooks and config.hooks.on_file_edited then + pcall(config.hooks.on_file_edited, properties.file) + end end ---@param properties RestorePointCreatedEvent diff --git a/tests/manual/renderer_replay.lua b/tests/manual/renderer_replay.lua index 879465ea..60f3b454 100644 --- a/tests/manual/renderer_replay.lua +++ b/tests/manual/renderer_replay.lua @@ -4,13 +4,14 @@ local helpers = require('tests.helpers') local output_window = require('opencode.ui.output_window') local config = require('opencode.config') -local M = {} - -M.events = {} -M.current_index = 0 -M.stop = false -M.last_loaded_file = nil -M.headless_mode = false +local M = { + events = {}, + event_index = 0, -- which events we have dispatched up to + events_received = 0, -- how many events we have received, just used for logging + stop = false, + last_loaded_file = nil, + headless_mode = false, +} function M.load_events(file_path) file_path = file_path or 'tests/data/simple-session.json' @@ -42,6 +43,7 @@ function M.load_events(file_path) end function M.setup_windows(opts) + require('opencode.ui.highlight').setup() helpers.replay_setup() vim.schedule(function() @@ -64,9 +66,9 @@ function M.replay_next(steps) steps = tonumber(steps) or 1 for _ = 1, steps do - if M.current_index < #M.events then - M.current_index = M.current_index + 1 - helpers.replay_event(M.events[M.current_index]) + if M.event_index < #M.events then + M.event_index = M.event_index + 1 + helpers.replay_event(M.events[M.event_index]) else vim.notify('No more events to replay', vim.log.levels.WARN) return @@ -100,7 +102,7 @@ function M.replay_all(delay_ms) -- will call renderer local function tick() M.replay_next() - if M.current_index >= #M.events or M.stop then + if M.event_index >= #M.events or M.stop then state.job_count = 0 if M.headless_mode then @@ -121,7 +123,8 @@ function M.replay_stop() end function M.reset() - M.current_index = 0 + M.event_index = 0 + M.events_received = 0 M.clear() end @@ -129,7 +132,7 @@ function M.show_status() local status = string.format( 'Replay Status:\n Events loaded: %d\n Current index: %d\n Playing: %s', #M.events, - M.current_index, + M.event_index, not M.stop ) vim.notify(status, vim.log.levels.INFO) @@ -345,7 +348,8 @@ function M.start(opts) -- NOTE: the index numbers will be incorrect when event collapsing happens local log_event = function(type, event) - local index = M.current_index + M.events_received = M.events_received + 1 + local index = M.events_received local count = #M.events local id = event.info and event.info.id or event.part and event.part.id @@ -357,33 +361,26 @@ function M.start(opts) vim.notify('Event ' .. index .. '/' .. count .. ': ' .. type .. ' ' .. id, vim.log.levels.INFO) end - state.event_manager:subscribe('session.updated', function(event) - log_event('session.updated', event) - end) - state.event_manager:subscribe('session.compacted', function(event) - log_event('session.compacted', event) - end) - state.event_manager:subscribe('session.error', function(event) - log_event('session.error', event) - end) - state.event_manager:subscribe('message.updated', function(event) - log_event('message.updated', event) - end) - state.event_manager:subscribe('message.removed', function(event) - log_event('message.removed', event) - end) - state.event_manager:subscribe('message.part.updated', function(event) - log_event('message.part.updated', event) - end) - state.event_manager:subscribe('message.removed', function(event) - log_event('message.removed', event) - end) - state.event_manager:subscribe('permission.updated', function(event) - log_event('permission.updated', event) - end) - state.event_manager:subscribe('permission.replied', function(event) - log_event('permission.replied', event) - end) + local events = { + 'session.updated', + 'session.compacted', + 'session.error', + 'session.idle', + 'message.updated', + 'message.removed', + 'message.part.updated', + 'message.removed', + 'permission.updated', + 'permission.replied', + 'file.edited', + 'server.connected', + } + + for _, event_name in ipairs(events) do + state.event_manager:subscribe(event_name, function(event) + log_event(event_name, event) + end) + end end return M diff --git a/tests/unit/hooks_spec.lua b/tests/unit/hooks_spec.lua new file mode 100644 index 00000000..3b41d7b3 --- /dev/null +++ b/tests/unit/hooks_spec.lua @@ -0,0 +1,110 @@ +local renderer = require('opencode.ui.renderer') +local config = require('opencode.config') +local state = require('opencode.state') +local helpers = require('tests.helpers') +local ui = require('opencode.ui.ui') + +describe('hooks', function() + before_each(function() + helpers.replay_setup() + config.hooks = { + on_file_edited = nil, + on_session_loaded = nil, + } + end) + + after_each(function() + if state.windows then + ui.close_windows(state.windows) + end + config.hooks = { + on_file_edited = nil, + on_session_loaded = nil, + } + end) + + describe('on_file_edited', function() + it('should call hook when file is edited', function() + local called = false + local file_path = nil + + config.hooks.on_file_edited = function(file) + called = true + file_path = file + end + + local test_event = { file = '/test/file.lua' } + renderer.on_file_edited(test_event) + + assert.is_true(called) + assert.are.equal('/test/file.lua', file_path) + end) + + it('should not error when hook is nil', function() + config.hooks.on_file_edited = nil + + local test_event = { file = '/test/file.lua' } + assert.has_no.errors(function() + renderer.on_file_edited(test_event) + end) + end) + + it('should not crash when hook throws error', function() + config.hooks.on_file_edited = function() + error('test error') + end + + local test_event = { file = '/test/file.lua' } + assert.has_no.errors(function() + renderer.on_file_edited(test_event) + end) + end) + end) + + describe('on_session_loaded', function() + it('should call hook when session is loaded', function() + local called = false + local session_data = nil + + config.hooks.on_session_loaded = function(session) + called = true + session_data = session + end + + local events = helpers.load_test_data('tests/data/simple-session.json') + state.active_session = helpers.get_session_from_events(events, true) + local loaded_session = helpers.load_session_from_events(events) + + renderer._render_full_session_data(loaded_session) + + assert.is_true(called) + assert.are.same(state.active_session, session_data) + end) + + it('should not error when hook is nil', function() + config.hooks.on_session_loaded = nil + + local events = helpers.load_test_data('tests/data/simple-session.json') + state.active_session = helpers.get_session_from_events(events, true) + local loaded_session = helpers.load_session_from_events(events) + + assert.has_no.errors(function() + renderer._render_full_session_data(loaded_session) + end) + end) + + it('should not crash when hook throws error', function() + config.hooks.on_session_loaded = function() + error('test error') + end + + local events = helpers.load_test_data('tests/data/simple-session.json') + state.active_session = helpers.get_session_from_events(events, true) + local loaded_session = helpers.load_session_from_events(events) + + assert.has_no.errors(function() + renderer._render_full_session_data(loaded_session) + end) + end) + end) +end)