Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lua/opencode/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ M.defaults = {
show_ids = false,
},
prompt_guard = nil,
hooks = {
on_file_edited = nil,
on_session_loaded = nil,
},
}

M.values = vim.deepcopy(M.defaults)
Expand Down
6 changes: 4 additions & 2 deletions lua/opencode/config_file.lua
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ end

function M.get_model_info(provider, model)
local config_file = require('opencode.config_file')
local providers = config_file.get_opencode_providers() or {}
local providers_response = config_file.get_opencode_providers()
local providers = providers_response and providers_response.providers or {}

local filtered_providers = vim.tbl_filter(function(p)
return p.id == provider
end, providers.providers)
end, providers)

if #filtered_providers == 0 then
return nil
Expand Down
9 changes: 7 additions & 2 deletions lua/opencode/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,12 @@
---@field capture_streamed_events boolean
---@field show_ids boolean

--- @class OpencodeProviders
--- @field [string] string[]
---@class OpencodeHooks
---@field on_file_edited? fun(file: string): nil
---@field on_session_loaded? fun(session: Session): nil

---@class OpencodeProviders
---@field [string] string[]

---@class OpencodeConfigModule
---@field defaults OpencodeConfig
Expand All @@ -156,6 +160,7 @@
---@field context OpencodeContextConfig
---@field debug OpencodeDebugConfig
---@field prompt_guard? fun(mentioned_files: string[]): boolean
---@field hooks OpencodeHooks
---@field legacy_commands boolean

---@class MessagePartState
Expand Down
9 changes: 8 additions & 1 deletion lua/opencode/ui/renderer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ function M._render_full_session_data(session_data)
M._set_model_from_messages()
end
M.scroll_to_bottom()

if config.hooks and config.hooks.on_session_loaded then
pcall(config.hooks.on_session_loaded, state.active_session)
end
end

---Render lines as the entire output buffer
Expand Down Expand Up @@ -805,8 +809,11 @@ function M.on_permission_replied(properties)
end
end

function M.on_file_edited(_)
function M.on_file_edited(properties)
vim.cmd('checktime')
if config.hooks and config.hooks.on_file_edited then
pcall(config.hooks.on_file_edited, properties.file)
end
end

---@param properties RestorePointCreatedEvent
Expand Down
79 changes: 38 additions & 41 deletions tests/manual/renderer_replay.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ local helpers = require('tests.helpers')
local output_window = require('opencode.ui.output_window')
local config = require('opencode.config')

local M = {}

M.events = {}
M.current_index = 0
M.stop = false
M.last_loaded_file = nil
M.headless_mode = false
local M = {
events = {},
event_index = 0, -- which events we have dispatched up to
events_received = 0, -- how many events we have received, just used for logging
stop = false,
last_loaded_file = nil,
headless_mode = false,
}

function M.load_events(file_path)
file_path = file_path or 'tests/data/simple-session.json'
Expand Down Expand Up @@ -42,6 +43,7 @@ function M.load_events(file_path)
end

function M.setup_windows(opts)
require('opencode.ui.highlight').setup()
helpers.replay_setup()

vim.schedule(function()
Expand All @@ -64,9 +66,9 @@ function M.replay_next(steps)
steps = tonumber(steps) or 1

for _ = 1, steps do
if M.current_index < #M.events then
M.current_index = M.current_index + 1
helpers.replay_event(M.events[M.current_index])
if M.event_index < #M.events then
M.event_index = M.event_index + 1
helpers.replay_event(M.events[M.event_index])
else
vim.notify('No more events to replay', vim.log.levels.WARN)
return
Expand Down Expand Up @@ -100,7 +102,7 @@ function M.replay_all(delay_ms)
-- will call renderer
local function tick()
M.replay_next()
if M.current_index >= #M.events or M.stop then
if M.event_index >= #M.events or M.stop then
state.job_count = 0

if M.headless_mode then
Expand All @@ -121,15 +123,16 @@ function M.replay_stop()
end

function M.reset()
M.current_index = 0
M.event_index = 0
M.events_received = 0
M.clear()
end

function M.show_status()
local status = string.format(
'Replay Status:\n Events loaded: %d\n Current index: %d\n Playing: %s',
#M.events,
M.current_index,
M.event_index,
not M.stop
)
vim.notify(status, vim.log.levels.INFO)
Expand Down Expand Up @@ -345,7 +348,8 @@ function M.start(opts)

-- NOTE: the index numbers will be incorrect when event collapsing happens
local log_event = function(type, event)
local index = M.current_index
M.events_received = M.events_received + 1
local index = M.events_received
local count = #M.events
local id = event.info and event.info.id
or event.part and event.part.id
Expand All @@ -357,33 +361,26 @@ function M.start(opts)
vim.notify('Event ' .. index .. '/' .. count .. ': ' .. type .. ' ' .. id, vim.log.levels.INFO)
end

state.event_manager:subscribe('session.updated', function(event)
log_event('session.updated', event)
end)
state.event_manager:subscribe('session.compacted', function(event)
log_event('session.compacted', event)
end)
state.event_manager:subscribe('session.error', function(event)
log_event('session.error', event)
end)
state.event_manager:subscribe('message.updated', function(event)
log_event('message.updated', event)
end)
state.event_manager:subscribe('message.removed', function(event)
log_event('message.removed', event)
end)
state.event_manager:subscribe('message.part.updated', function(event)
log_event('message.part.updated', event)
end)
state.event_manager:subscribe('message.removed', function(event)
log_event('message.removed', event)
end)
state.event_manager:subscribe('permission.updated', function(event)
log_event('permission.updated', event)
end)
state.event_manager:subscribe('permission.replied', function(event)
log_event('permission.replied', event)
end)
local events = {
'session.updated',
'session.compacted',
'session.error',
'session.idle',
'message.updated',
'message.removed',
'message.part.updated',
'message.removed',
'permission.updated',
'permission.replied',
'file.edited',
'server.connected',
}

for _, event_name in ipairs(events) do
state.event_manager:subscribe(event_name, function(event)
log_event(event_name, event)
end)
end
end

return M
110 changes: 110 additions & 0 deletions tests/unit/hooks_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
local renderer = require('opencode.ui.renderer')
local config = require('opencode.config')
local state = require('opencode.state')
local helpers = require('tests.helpers')
local ui = require('opencode.ui.ui')

describe('hooks', function()
before_each(function()
helpers.replay_setup()
config.hooks = {
on_file_edited = nil,
on_session_loaded = nil,
}
end)

after_each(function()
if state.windows then
ui.close_windows(state.windows)
end
config.hooks = {
on_file_edited = nil,
on_session_loaded = nil,
}
end)

describe('on_file_edited', function()
it('should call hook when file is edited', function()
local called = false
local file_path = nil

config.hooks.on_file_edited = function(file)
called = true
file_path = file
end

local test_event = { file = '/test/file.lua' }
renderer.on_file_edited(test_event)

assert.is_true(called)
assert.are.equal('/test/file.lua', file_path)
end)

it('should not error when hook is nil', function()
config.hooks.on_file_edited = nil

local test_event = { file = '/test/file.lua' }
assert.has_no.errors(function()
renderer.on_file_edited(test_event)
end)
end)

it('should not crash when hook throws error', function()
config.hooks.on_file_edited = function()
error('test error')
end

local test_event = { file = '/test/file.lua' }
assert.has_no.errors(function()
renderer.on_file_edited(test_event)
end)
end)
end)

describe('on_session_loaded', function()
it('should call hook when session is loaded', function()
local called = false
local session_data = nil

config.hooks.on_session_loaded = function(session)
called = true
session_data = session
end

local events = helpers.load_test_data('tests/data/simple-session.json')
state.active_session = helpers.get_session_from_events(events, true)
local loaded_session = helpers.load_session_from_events(events)

renderer._render_full_session_data(loaded_session)

assert.is_true(called)
assert.are.same(state.active_session, session_data)
end)

it('should not error when hook is nil', function()
config.hooks.on_session_loaded = nil

local events = helpers.load_test_data('tests/data/simple-session.json')
state.active_session = helpers.get_session_from_events(events, true)
local loaded_session = helpers.load_session_from_events(events)

assert.has_no.errors(function()
renderer._render_full_session_data(loaded_session)
end)
end)

it('should not crash when hook throws error', function()
config.hooks.on_session_loaded = function()
error('test error')
end

local events = helpers.load_test_data('tests/data/simple-session.json')
state.active_session = helpers.get_session_from_events(events, true)
local loaded_session = helpers.load_session_from_events(events)

assert.has_no.errors(function()
renderer._render_full_session_data(loaded_session)
end)
end)
end)
end)