diff --git a/lua/opencode/state.lua b/lua/opencode/state.lua deleted file mode 100644 index 3a61de36..00000000 --- a/lua/opencode/state.lua +++ /dev/null @@ -1 +0,0 @@ -return require('opencode.state.init') --[[@as OpencodeState]] diff --git a/lua/opencode/ui/debug_helper.lua b/lua/opencode/ui/debug_helper.lua index f1c2b301..eda59565 100644 --- a/lua/opencode/ui/debug_helper.lua +++ b/lua/opencode/ui/debug_helper.lua @@ -29,7 +29,7 @@ function M.debug_output() end function M.debug_message() - local renderer = require('opencode.ui.renderer') + local render_state = require('opencode.ui.renderer.ctx').render_state if not state.windows or not state.windows.output_win then vim.notify('Output window not available', vim.log.levels.WARN) return @@ -38,7 +38,7 @@ function M.debug_message() -- Search backwards from current line to find nearest message for line = current_line, 1, -1 do - local message_data = renderer._render_state:get_message_at_line(line) + local message_data = render_state:get_message_at_line(line) if message_data and message_data.message then M.open_json_file(message_data.message) return diff --git a/lua/opencode/ui/output_window.lua b/lua/opencode/ui/output_window.lua index 8c9271b6..d9f8cc5e 100644 --- a/lua/opencode/ui/output_window.lua +++ b/lua/opencode/ui/output_window.lua @@ -4,6 +4,37 @@ local config = require('opencode.config') local M = {} M.namespace = vim.api.nvim_create_namespace('opencode_output') +local _update_depth = 0 +local _update_buf = nil + +---Begin a batch of buffer writes — toggle modifiable once for the whole batch. +---Returns true if the batch was opened (buffer is valid). Must be paired with end_update(). +---@return boolean +function M.begin_update() + local windows = state.windows + if not windows or not windows.output_buf then + return false + end + if _update_depth == 0 then + _update_buf = windows.output_buf + vim.api.nvim_set_option_value('modifiable', true, { buf = _update_buf }) + end + _update_depth = _update_depth + 1 + return true +end + +---End a batch started by begin_update(). +function M.end_update() + if _update_depth == 0 then + return + end + _update_depth = _update_depth - 1 + if _update_depth == 0 and _update_buf then + vim.api.nvim_set_option_value('modifiable', false, { buf = _update_buf }) + _update_buf = nil + end +end + function M.create_buf() local output_buf = vim.api.nvim_create_buf(false, true) local filetype = config.ui.output.filetype or 'opencode_output' @@ -165,12 +196,11 @@ function M.update_dimensions(windows) end function M.get_buf_line_count() - if not M.buffer_valid() then + local windows = state.windows + if not windows or not windows.output_buf or not vim.api.nvim_buf_is_valid(windows.output_buf) then return 0 end - ---@cast state.windows { output_buf: integer } - - return vim.api.nvim_buf_line_count(state.windows.output_buf) + return vim.api.nvim_buf_line_count(windows.output_buf) end ---Set the output buffer contents @@ -178,17 +208,22 @@ end ---@param start_line? integer The starting line to set, defaults to 0 ---@param end_line? integer The last line to set, defaults to -1 function M.set_lines(lines, start_line, end_line) - if not M.buffer_valid() then + local windows = state.windows + if not windows or not windows.output_buf or not vim.api.nvim_buf_is_valid(windows.output_buf) then return end - ---@cast state.windows { output_buf: integer } + local buf = windows.output_buf start_line = start_line or 0 end_line = end_line or -1 - vim.api.nvim_set_option_value('modifiable', true, { buf = state.windows.output_buf }) - vim.api.nvim_buf_set_lines(state.windows.output_buf, start_line, end_line, false, lines) - vim.api.nvim_set_option_value('modifiable', false, { buf = state.windows.output_buf }) + if _update_depth == 0 then + vim.api.nvim_set_option_value('modifiable', true, { buf = buf }) + vim.api.nvim_buf_set_lines(buf, start_line, end_line, false, lines) + vim.api.nvim_set_option_value('modifiable', false, { buf = buf }) + else + vim.api.nvim_buf_set_lines(buf, start_line, end_line, false, lines) + end end ---Clear output buf extmarks @@ -196,35 +231,32 @@ end ---@param end_line? integer Line to clear until, defaults to -1 ---@param clear_all? boolean If true, clears all extmarks in the buffer function M.clear_extmarks(start_line, end_line, clear_all) - if not M.buffer_valid() then + local windows = state.windows + if not windows or not windows.output_buf or not vim.api.nvim_buf_is_valid(windows.output_buf) then return end - ---@cast state.windows { output_buf: integer } start_line = start_line or 0 end_line = end_line or -1 - pcall( - vim.api.nvim_buf_clear_namespace, - state.windows.output_buf, - clear_all and -1 or M.namespace, - start_line, - end_line - ) + pcall(vim.api.nvim_buf_clear_namespace, windows.output_buf, clear_all and -1 or M.namespace, start_line, end_line) end ---Apply extmarks to the output buffer ---@param extmarks table Extmarks indexed by line ---@param line_offset? integer Line offset to apply to extmarks, defaults to 0 function M.set_extmarks(extmarks, line_offset) - if not M.buffer_valid() or not extmarks or type(extmarks) ~= 'table' then + if not extmarks or type(extmarks) ~= 'table' then + return + end + local windows = state.windows + if not windows or not windows.output_buf or not vim.api.nvim_buf_is_valid(windows.output_buf) then return end - ---@cast state.windows { output_buf: integer } line_offset = line_offset or 0 - local output_buf = state.windows.output_buf + local output_buf = windows.output_buf for line_idx, marks in pairs(extmarks) do for _, mark in ipairs(marks) do diff --git a/lua/opencode/ui/permission_window.lua b/lua/opencode/ui/permission_window.lua index ff421139..af01fab8 100644 --- a/lua/opencode/ui/permission_window.lua +++ b/lua/opencode/ui/permission_window.lua @@ -216,7 +216,7 @@ function M._setup_dialog() end M._processing = true - require('opencode.ui.renderer').render_permissions_display() + require('opencode.ui.renderer.events').render_permissions_display() M._clear_dialog() local api = require('opencode.api') @@ -236,7 +236,7 @@ function M._setup_dialog() end local function on_navigate() - require('opencode.ui.renderer').render_permissions_display() + require('opencode.ui.renderer.events').render_permissions_display() end local function get_option_count() diff --git a/lua/opencode/ui/question_window.lua b/lua/opencode/ui/question_window.lua index 215dde8b..11505e66 100644 --- a/lua/opencode/ui/question_window.lua +++ b/lua/opencode/ui/question_window.lua @@ -13,11 +13,11 @@ M._answering = false M._dialog = nil local function render_question() - require('opencode.ui.renderer').render_question_display() + require('opencode.ui.renderer.events').render_question_display() end local function clear_question() - require('opencode.ui.renderer').clear_question_display() + require('opencode.ui.renderer.events').clear_question_display() end ---@param question_request OpencodeQuestionRequest diff --git a/lua/opencode/ui/render_state.lua b/lua/opencode/ui/render_state.lua index a65382a1..38a61226 100644 --- a/lua/opencode/ui/render_state.lua +++ b/lua/opencode/ui/render_state.lua @@ -1,5 +1,3 @@ -local state = require('opencode.state') - ---@class RenderedMessage ---@field message OpencodeMessage Direct reference to message in state.messages ---@field line_start integer? Line where message header starts @@ -11,16 +9,16 @@ local state = require('opencode.state') ---@field line_start integer? Line where part starts ---@field line_end integer? Line where part ends ---@field actions table[] Actions associated with this part - ----@class LineIndex ----@field line_to_part table Maps line number -> part ID ----@field line_to_message table Maps line number -> message ID +---@field has_extmarks boolean? Whether the part currently has extmarks applied ---@class RenderState ---@field _messages table Message ID -> rendered message ---@field _parts table Part ID -> rendered part ----@field _line_index LineIndex Line number -> ID mappings ----@field _line_index_valid boolean Whether line index is up to date +---@field _part_ranges {[1]: integer, [2]: integer, [3]: string}[] Sorted [line_start, line_end, part_id] for binary search +---@field _message_ranges {[1]: integer, [2]: integer, [3]: string}[] Sorted [line_start, line_end, message_id] for binary search +---@field _ranges_valid boolean Whether range arrays are sorted and up-to-date +---@field _max_line_end integer +---@field _max_line_end_valid boolean local RenderState = {} RenderState.__index = RenderState @@ -34,14 +32,44 @@ end function RenderState:reset() self._messages = {} self._parts = {} - self._line_index = { - line_to_part = {}, - line_to_message = {}, - } + self._part_ranges = {} + self._message_ranges = {} + self._ranges_valid = false + self._max_line_end = 0 + self._max_line_end_valid = true self._child_session_parts = {} + self._child_session_parts_index = {} -- session_id -> part_id -> list_index self._child_session_task_parts = {} self._task_part_child_sessions = {} - self._line_index_valid = false + self._snapshot_id_index = {} -- snapshot_id -> OpencodeMessagePart +end + +function RenderState:_recompute_max_line_end() + local max_line_end = 0 + + for _, msg_data in pairs(self._messages) do + if msg_data.line_end and msg_data.line_end > max_line_end then + max_line_end = msg_data.line_end + end + end + + for _, part_data in pairs(self._parts) do + if part_data.line_end and part_data.line_end > max_line_end then + max_line_end = part_data.line_end + end + end + + self._max_line_end = max_line_end + self._max_line_end_valid = true + return max_line_end +end + +---@return integer +function RenderState:_get_max_line_end() + if not self._max_line_end_valid then + return self:_recompute_max_line_end() + end + return self._max_line_end end ---@param part OpencodeMessagePart? @@ -50,10 +78,8 @@ local function get_child_session_id_for_task_part(part) if not part or part.tool ~= 'task' then return nil end - local part_state = part.state local metadata = part_state and part_state.metadata - return metadata and metadata.sessionId or nil end @@ -63,11 +89,9 @@ function RenderState:_clear_task_part_child_session(part_id) if not child_session_id then return end - if self._child_session_task_parts[child_session_id] == part_id then self._child_session_task_parts[child_session_id] = nil end - self._task_part_child_sessions[part_id] = nil end @@ -75,82 +99,139 @@ end ---@param part OpencodeMessagePart function RenderState:_index_task_part_child_session(part_id, part) self:_clear_task_part_child_session(part_id) - local child_session_id = get_child_session_id_for_task_part(part) if not child_session_id then return end - self._child_session_task_parts[child_session_id] = part_id self._task_part_child_sessions[part_id] = child_session_id end ----Get parts for a child session +---@param ranges {[1]: integer, [2]: integer, [3]: string}[] +---@param line integer +---@return string? +local function range_lookup(ranges, line) + local lo, hi = 1, #ranges + while lo <= hi do + local mid = math.floor((lo + hi) / 2) + local r = ranges[mid] + if line < r[1] then + hi = mid - 1 + elseif line > r[2] then + lo = mid + 1 + else + return r[3] + end + end + return nil +end + +function RenderState:_rebuild_ranges() + local part_ranges = {} + for part_id, part_data in pairs(self._parts) do + if part_data.line_start and part_data.line_end then + part_ranges[#part_ranges + 1] = { part_data.line_start, part_data.line_end, part_id } + end + end + table.sort(part_ranges, function(a, b) + return a[1] < b[1] + end) + self._part_ranges = part_ranges + + local msg_ranges = {} + for msg_id, msg_data in pairs(self._messages) do + if msg_data.line_start and msg_data.line_end then + msg_ranges[#msg_ranges + 1] = { msg_data.line_start, msg_data.line_end, msg_id } + end + end + table.sort(msg_ranges, function(a, b) + return a[1] < b[1] + end) + self._message_ranges = msg_ranges + + self._ranges_valid = true +end + +function RenderState:_ensure_ranges() + if not self._ranges_valid then + self:_rebuild_ranges() + end +end + ---@param session_id string ----@return OpencodeMessagePart[]?|nil +---@return OpencodeMessagePart[]? function RenderState:get_child_session_parts(session_id) if not session_id then return nil end - return self._child_session_parts and self._child_session_parts[session_id] + return self._child_session_parts[session_id] end ----Get the owning task part for a child session ---@param session_id string ---@return string? function RenderState:get_task_part_by_child_session(session_id) if not session_id then return nil end - - return self._child_session_task_parts and self._child_session_task_parts[session_id] + return self._child_session_task_parts[session_id] end ----Upsert a part associated with a child session ---@param session_id string ---@param part OpencodeMessagePart function RenderState:upsert_child_session_part(session_id, part) if not session_id or not part or not part.id then return end - self._child_session_parts = self._child_session_parts or {} - local session_parts = self._child_session_parts[session_id] or {} - local found = false - for i, existing in ipairs(session_parts) do - if existing.id == part.id then - session_parts[i] = part - found = true - break - end + + local session_parts = self._child_session_parts[session_id] + if not session_parts then + session_parts = {} + self._child_session_parts[session_id] = session_parts + self._child_session_parts_index[session_id] = {} end - if not found then - table.insert(session_parts, part) + + local idx = self._child_session_parts_index[session_id][part.id] + if idx then + session_parts[idx] = part + else + session_parts[#session_parts + 1] = part + self._child_session_parts_index[session_id][part.id] = #session_parts end - self._child_session_parts[session_id] = session_parts end ----Get message render data by ID ----@param message_id string Message ID +---@param message_id string ---@return RenderedMessage? function RenderState:get_message(message_id) return self._messages[message_id] end ----Get part render data by ID ----@param part_id string Part ID +---@param line integer 1-indexed +---@return RenderedMessage? +function RenderState:get_message_at_line(line) + self:_ensure_ranges() + local msg_id = range_lookup(self._message_ranges, line) + return msg_id and self._messages[msg_id] or nil +end + +---@param part_id string ---@return RenderedPart? function RenderState:get_part(part_id) return self._parts[part_id] end ----Get part ID by call ID and message ID ----@param call_id string Call ID ----@param message_id string Message ID to check the parts of ----@return string? part_id Part ID if found +---@param line integer 1-indexed +---@return RenderedPart? +function RenderState:get_part_at_line(line) + self:_ensure_ranges() + local part_id = range_lookup(self._part_ranges, line) + return part_id and self._parts[part_id] or nil +end + +---@param call_id string +---@param message_id string +---@return string? function RenderState:get_part_by_call_id(call_id, message_id) local rendered_message = self._messages[message_id] - -- There aren't a lot of parts per message and call_id lookups aren't very common so - -- a little iteration is fine if rendered_message and rendered_message.message and rendered_message.message.parts then for _, part in ipairs(rendered_message.message.parts) do if part.callID == call_id then @@ -161,57 +242,17 @@ function RenderState:get_part_by_call_id(call_id, message_id) return nil end ----Get part ID by snapshot_id and message ID ----@param snapshot_id string Call ID ----@return OpencodeMessagePart? part Part if found +---@param snapshot_id string +---@return OpencodeMessagePart? function RenderState:get_part_by_snapshot_id(snapshot_id) - for _, rendered_message in pairs(self._messages or {}) do - for _, part in ipairs(rendered_message.message.parts or {}) do - if part.type == 'patch' and part.hash == snapshot_id then - return part - end - end - end - return nil -end - ----Ensure line index is up to date -function RenderState:_ensure_line_index() - if not self._line_index_valid then - self:_rebuild_line_index() - end -end - ----Get part at specific line ----@param line integer Line number (1-indexed) ----@return RenderedPart? -function RenderState:get_part_at_line(line) - self:_ensure_line_index() - local part_id = self._line_index.line_to_part[line] - if not part_id then - return nil - end - return self._parts[part_id] -end - ----Get message at specific line ----@param line integer Line number (1-indexed) ----@return RenderedMessage? -function RenderState:get_message_at_line(line) - self:_ensure_line_index() - local message_id = self._line_index.line_to_message[line] - if not message_id then - return nil - end - return self._messages[message_id] + return self._snapshot_id_index[snapshot_id] end ----Get actions at specific line ----@param line integer Line number (0-indexed) ----@return table[] List of actions at that line +---@param line integer +---@return table[] function RenderState:get_actions_at_line(line) - self:_ensure_line_index() - local part_id = self._line_index.line_to_part[line] + self:_ensure_ranges() + local part_id = range_lookup(self._part_ranges, line) if not part_id then return {} end @@ -224,107 +265,167 @@ function RenderState:get_actions_at_line(line) local actions = {} for _, action in ipairs(part_data.actions) do if action.range and action.range.from <= line and action.range.to >= line then - table.insert(actions, action) + actions[#actions + 1] = action end end return actions end ----Set or update message render data ----@param message OpencodeMessage Direct reference to message ----@param line_start integer? Line where message header starts ----@param line_end integer? Line where message header ends +---@param part_id string +---@param actions table[] +---@param offset? integer Line offset to apply to action line numbers +function RenderState:add_actions(part_id, actions, offset) + local part_data = self._parts[part_id] + if not part_data then + return + end + offset = offset or 0 + for _, action in ipairs(actions) do + if offset ~= 0 then + if action.display_line then + action.display_line = action.display_line + offset + end + if action.range then + action.range.from = action.range.from + offset + action.range.to = action.range.to + offset + end + end + part_data.actions[#part_data.actions + 1] = action + end +end + +---@param part_id string +function RenderState:clear_actions(part_id) + local part_data = self._parts[part_id] + if part_data then + part_data.actions = {} + end +end + +---@return table[] +function RenderState:get_all_actions() + local all_actions = {} + for _, part_data in pairs(self._parts) do + if part_data.actions then + for _, action in ipairs(part_data.actions) do + all_actions[#all_actions + 1] = action + end + end + end + return all_actions +end + +---@param message OpencodeMessage +---@param line_start integer? +---@param line_end integer? function RenderState:set_message(message, line_start, line_end) if not message or not message.info or not message.info.id then return end local message_id = message.info.id - if not self._messages[message_id] then + local existing = self._messages[message_id] + if not existing then self._messages[message_id] = { message = message, line_start = line_start, line_end = line_end, } else - local msg_data = self._messages[message_id] - msg_data.message = message + existing.message = message if line_start then - msg_data.line_start = line_start + existing.line_start = line_start end if line_end then - msg_data.line_end = line_end + existing.line_end = line_end end end if line_start and line_end then - self._line_index_valid = false + self._ranges_valid = false + if self._max_line_end_valid and line_end > self._max_line_end then + self._max_line_end = line_end + end end end ----Set or update part render data ----@param part OpencodeMessagePart Direct reference to part (must include id/messageID) ----@param line_start integer? Line where part starts ----@param line_end integer? Line where part ends +---@param part OpencodeMessagePart +---@param line_start integer? +---@param line_end integer? function RenderState:set_part(part, line_start, line_end) if not part or not part.id then return end - - -- Allow special parts (like permissions) without messageID local part_id = part.id local message_id = part.messageID or 'special' - if not self._parts[part_id] then + local existing = self._parts[part_id] + if not existing then self._parts[part_id] = { part = part, message_id = message_id, line_start = line_start, line_end = line_end, actions = {}, + has_extmarks = false, } else - local render_part = self._parts[part_id] - render_part.part = part + existing.part = part if message_id then - render_part.message_id = message_id + existing.message_id = message_id end if line_start then - render_part.line_start = line_start + existing.line_start = line_start end if line_end then - render_part.line_end = line_end + existing.line_end = line_end end end if line_start and line_end then - self._line_index_valid = false + self._ranges_valid = false + if self._max_line_end_valid and line_end > self._max_line_end then + self._max_line_end = line_end + end + end + + if part.type == 'patch' and part.hash then + self._snapshot_id_index[part.hash] = part end self:_index_task_part_child_session(part_id, part) end ----Update part line positions and shift subsequent content ----@param part_id string Part ID ----@param new_line_start integer New start line ----@param new_line_end integer New end line ----@return boolean success +---@param part_id string +---@param new_line_start integer +---@param new_line_end integer +---@return boolean function RenderState:update_part_lines(part_id, new_line_start, new_line_end) local part_data = self._parts[part_id] if not part_data or not part_data.line_start or not part_data.line_end then return false end - local old_line_start = part_data.line_start + if part_data.line_start == new_line_start and part_data.line_end == new_line_end then + return true + end + local old_line_end = part_data.line_end - local old_line_count = old_line_end - old_line_start + 1 + local old_line_count = old_line_end - part_data.line_start + 1 local new_line_count = new_line_end - new_line_start + 1 local delta = new_line_count - old_line_count part_data.line_start = new_line_start part_data.line_end = new_line_end + self._ranges_valid = false - self._line_index_valid = false + if self._max_line_end_valid then + if old_line_end == self._max_line_end and new_line_end < old_line_end then + self._max_line_end_valid = false + elseif new_line_end > self._max_line_end then + self._max_line_end = new_line_end + end + end if delta ~= 0 then self:shift_all(old_line_end + 1, delta) @@ -333,9 +434,8 @@ function RenderState:update_part_lines(part_id, new_line_start, new_line_end) return true end ----Update part data reference ----@param part_ref OpencodeMessagePart New part reference (must include id) ----@return RenderedPart? part The rendered part +---@param part_ref OpencodeMessagePart +---@return RenderedPart? function RenderState:update_part_data(part_ref) if not part_ref or not part_ref.id then return @@ -344,79 +444,28 @@ function RenderState:update_part_data(part_ref) if not rendered_part then return end - rendered_part.part = part_ref - self:_index_task_part_child_session(part_ref.id, part_ref) - return rendered_part -end ----Helper to update action line numbers ----@param action table Action to update ----@param delta integer Line offset to apply -local function shift_action_lines(action, delta) - if action.display_line then - action.display_line = action.display_line + delta - end - if action.range then - action.range.from = action.range.from + delta - action.range.to = action.range.to + delta - end -end - ----Add actions to a part ----@param part_id string Part ID ----@param actions table[] Actions to add ----@param offset? integer Optional line offset to apply to actions -function RenderState:add_actions(part_id, actions, offset) - local part_data = self._parts[part_id] - if not part_data then - return + if part_ref.type == 'patch' and part_ref.hash then + self._snapshot_id_index[part_ref.hash] = part_ref end - offset = offset or 0 - - for _, action in ipairs(actions) do - if offset ~= 0 then - shift_action_lines(action, offset) - end - table.insert(part_data.actions, action) - end -end - ----Clear actions for a part ----@param part_id string Part ID -function RenderState:clear_actions(part_id) - local part_data = self._parts[part_id] - if not part_data then - return - end - - part_data.actions = {} -end - ----Get all actions from all parts ----@return table[] List of all actions -function RenderState:get_all_actions() - local all_actions = {} - for _, part_data in pairs(self._parts) do - if part_data.actions then - for _, action in ipairs(part_data.actions) do - table.insert(all_actions, action) - end - end - end - return all_actions + self:_index_task_part_child_session(part_ref.id, part_ref) + return rendered_part end ----Remove part and shift subsequent content ----@param part_id string Part ID ----@return boolean success +---@param part_id string +---@return boolean function RenderState:remove_part(part_id) local part_data = self._parts[part_id] if not part_data then return false end + if part_data.part and part_data.part.type == 'patch' and part_data.part.hash then + self._snapshot_id_index[part_data.part.hash] = nil + end + self:_clear_task_part_child_session(part_id) if not part_data.line_start or not part_data.line_end then @@ -428,16 +477,17 @@ function RenderState:remove_part(part_id) local shift_from = part_data.line_end + 1 self._parts[part_id] = nil - self._line_index_valid = false + self._ranges_valid = false + if self._max_line_end_valid and part_data.line_end == self._max_line_end then + self._max_line_end_valid = false + end self:shift_all(shift_from, -line_count) - return true end ----Remove message (header only, not parts) ----@param message_id string Message ID ----@return boolean success +---@param message_id string +---@return boolean function RenderState:remove_message(message_id) local msg_data = self._messages[message_id] if not msg_data or not msg_data.line_start or not msg_data.line_end then @@ -448,97 +498,61 @@ function RenderState:remove_message(message_id) local shift_from = msg_data.line_end + 1 self._messages[message_id] = nil - self._line_index_valid = false + self._ranges_valid = false + if self._max_line_end_valid and msg_data.line_end == self._max_line_end then + self._max_line_end_valid = false + end self:shift_all(shift_from, -line_count) - return true end ----Shift all content starting from a line by delta ----Optimized to scan in reverse order and exit early ----@param from_line integer Line number to start shifting from ----@param delta integer Number of lines to shift (positive or negative) +local function shift_action(action, delta) + if action.display_line then + action.display_line = action.display_line + delta + end + if action.range then + action.range.from = action.range.from + delta + action.range.to = action.range.to + delta + end +end + function RenderState:shift_all(from_line, delta) - if delta == 0 or not state.messages then + if delta == 0 then return end - local found_content_before_from_line = false - local anything_shifted = false - - for i = #state.messages, 1, -1 do - local message = state.messages[i] or {} + if from_line > self:_get_max_line_end() then + return + end - local msg_id = message.info and message.info.id - if msg_id then - local rendered_msg = self._messages[msg_id] - if rendered_msg and rendered_msg.line_start and rendered_msg.line_end then - if rendered_msg.line_start >= from_line then - rendered_msg.line_start = rendered_msg.line_start + delta - rendered_msg.line_end = rendered_msg.line_end + delta - anything_shifted = true - elseif rendered_msg.line_end < from_line then - found_content_before_from_line = true - end - end - end + local shifted = false - if message.parts then - for j = #message.parts, 1, -1 do - local part = message.parts[j] - if part.id then - local rendered_part = self._parts[part.id] - if rendered_part and rendered_part.line_start and rendered_part.line_end then - if rendered_part.line_start >= from_line then - rendered_part.line_start = rendered_part.line_start + delta - rendered_part.line_end = rendered_part.line_end + delta - anything_shifted = true - - if rendered_part.actions then - for _, action in ipairs(rendered_part.actions) do - shift_action_lines(action, delta) - end - end - elseif rendered_part.line_end < from_line then - found_content_before_from_line = true - end - end - end - end + for _, msg_data in pairs(self._messages) do + if msg_data.line_start and msg_data.line_start >= from_line then + msg_data.line_start = msg_data.line_start + delta + msg_data.line_end = msg_data.line_end + delta + shifted = true end - - if found_content_before_from_line then - break - end - end - - if anything_shifted then - self._line_index_valid = false end -end - ----Rebuild line index from current state -function RenderState:_rebuild_line_index() - self._line_index.line_to_part = {} - self._line_index.line_to_message = {} - for msg_id, msg_data in pairs(self._messages) do - if msg_data.line_start and msg_data.line_end then - for line = msg_data.line_start, msg_data.line_end do - self._line_index.line_to_message[line] = msg_id + for _, part_data in pairs(self._parts) do + if part_data.line_start and part_data.line_start >= from_line then + part_data.line_start = part_data.line_start + delta + part_data.line_end = part_data.line_end + delta + shifted = true + for _, action in ipairs(part_data.actions) do + shift_action(action, delta) end end end - for part_id, part_data in pairs(self._parts) do - if part_data.line_start and part_data.line_end then - for line = part_data.line_start, part_data.line_end do - self._line_index.line_to_part[line] = part_id - end + if shifted then + self._ranges_valid = false + if self._max_line_end_valid then + self._max_line_end = self._max_line_end + delta end end - self._line_index_valid = true end return RenderState diff --git a/lua/opencode/ui/renderer.lua b/lua/opencode/ui/renderer.lua index f28d3f0f..690768ab 100644 --- a/lua/opencode/ui/renderer.lua +++ b/lua/opencode/ui/renderer.lua @@ -4,28 +4,24 @@ local formatter = require('opencode.ui.formatter') local output_window = require('opencode.ui.output_window') local permission_window = require('opencode.ui.permission_window') local Promise = require('opencode.promise') -local RenderState = require('opencode.ui.render_state') +local ctx = require('opencode.ui.renderer.ctx') +local buf = require('opencode.ui.renderer.buffer') +local events = require('opencode.ui.renderer.events') -local M = { - _prev_line_count = 0, - _render_state = RenderState.new(), - _last_part_formatted = { - part_id = nil, - formatted_data = nil --[[@as Output|nil]], - }, -} +local M = {} + +-- Expose event handlers on M so tests can call them directly and subscriptions +-- can be stubbed cleanly (e.g. stub(renderer, '_render_full_session_data')) +M.on_session_updated = events.on_session_updated local trigger_on_data_rendered = require('opencode.util').debounce(function() local cb_type = type(config.ui.output.rendering.on_data_rendered) - if cb_type == 'boolean' then return end - if not state.windows or not state.windows.output_buf or not state.windows.output_win then return end - if cb_type == 'function' then pcall(config.ui.output.rendering.on_data_rendered, state.windows.output_buf, state.windows.output_win) elseif vim.fn.exists(':RenderMarkdown') > 0 then @@ -35,12 +31,9 @@ local trigger_on_data_rendered = require('opencode.util').debounce(function() end end, config.ui.output.rendering.markdown_debounce_ms or 250) ----Reset renderer state +---Reset all renderer state and clear the output buffer function M.reset() - M._prev_line_count = 0 - M._render_state:reset() - M._last_part_formatted = { part_id = nil, formatted_data = nil } - + ctx:reset() output_window.clear() local permissions = state.pending_permissions or {} @@ -55,8 +48,14 @@ function M.reset() trigger_on_data_rendered() end ----Set up event subscriptions ----@param subscribe? boolean false to unsubscribe +---Unsubscribe from all events and reset +function M.teardown() + M.setup_subscriptions(false) + M.reset() +end + +---Subscribe to (or unsubscribe from) all renderer events +---@param subscribe? boolean false to unsubscribe (default true) function M.setup_subscriptions(subscribe) subscribe = subscribe == nil and true or subscribe @@ -72,26 +71,26 @@ function M.setup_subscriptions(subscribe) return end - local event_subscriptions = { - { 'session.updated', M.on_session_updated }, - { 'session.compacted', M.on_session_compacted }, - { 'session.error', M.on_session_error }, - { 'message.updated', M.on_message_updated }, - { 'message.removed', M.on_message_removed }, - { 'message.part.updated', M.on_part_updated }, - { 'message.part.removed', M.on_part_removed }, - { 'permission.updated', M.on_permission_updated }, - { 'permission.asked', M.on_permission_updated }, - { 'permission.replied', M.on_permission_replied }, - { 'question.asked', M.on_question_asked }, - { 'question.replied', M.clear_question_display }, - { 'question.rejected', M.clear_question_display }, - { 'file.edited', M.on_file_edited }, - { 'custom.restore_point.created', M.on_restore_points }, - { 'custom.emit_events.finished', M.on_emit_events_finished }, + local subs = { + { 'session.updated', events.on_session_updated }, + { 'session.compacted', events.on_session_compacted }, + { 'session.error', events.on_session_error }, + { 'message.updated', events.on_message_updated }, + { 'message.removed', events.on_message_removed }, + { 'message.part.updated', events.on_part_updated }, + { 'message.part.removed', events.on_part_removed }, + { 'permission.updated', events.on_permission_updated }, + { 'permission.asked', events.on_permission_updated }, + { 'permission.replied', events.on_permission_replied }, + { 'question.asked', events.on_question_asked }, + { 'question.replied', events.clear_question_display }, + { 'question.rejected', events.clear_question_display }, + { 'file.edited', events.on_file_edited }, + { 'custom.restore_point.created', events.on_restore_points }, + { 'custom.emit_events.finished', M.on_emit_events_finished }, } - for _, sub in ipairs(event_subscriptions) do + for _, sub in ipairs(subs) do if subscribe then state.event_manager:subscribe(sub[1], sub[2]) else @@ -100,35 +99,39 @@ function M.setup_subscriptions(subscribe) end end ----Clean up and teardown renderer. Unsubscribes from all events -function M.teardown() - M.setup_subscriptions(false) - M.reset() -end - ----Fetch full session messages from server ----@return Promise Promise resolving to list of OpencodeMessage +---Fetch all messages for the active session from the server +---@return Promise local function fetch_session() local session = state.active_session - if not session or not session or session == '' then + if not session or session == '' then return Promise.new():resolve(nil) end - state.renderer.set_last_user_message(nil) return require('opencode.session').get_messages(session) end ----Request all of the session data from the opencode server and render it ----@return Promise -function M.render_full_session() - if not output_window.mounted() or not state.api_client then - return Promise.new():resolve(nil) +---Set the current model/mode from the most recent assistant message +local function set_model_and_mode_from_messages() + if not state.messages then + return end - - return fetch_session():and_then(M._render_full_session_data) + for i = #state.messages, 1, -1 do + local message = state.messages[i] + if message and message.info and message.info.modelID and message.info.providerID then + state.model.set_model(message.info.providerID .. '/' .. message.info.modelID) + if message.info.mode then + state.model.set_mode(message.info.mode) + end + return + end + end + require('opencode.core').initialize_current_model() end -function M._render_full_session_data(session_data, prev_revert, revert) +---Render all messages and parts from session_data into the output buffer +---Called after a full session fetch or when revert state changes +---@param session_data OpencodeMessage[] +function M._render_full_session_data(session_data) M.reset() if not state.active_session or not state.messages then @@ -136,30 +139,26 @@ function M._render_full_session_data(session_data, prev_revert, revert) end local revert_index = nil - - -- if we're loading a session and there's no currently selected model, set it - -- from the messages local set_mode_from_messages = not state.current_model for i, msg in ipairs(session_data) do if state.active_session.revert and state.active_session.revert.messageID == msg.info.id then revert_index = i end - - M.on_message_updated({ info = msg.info }, revert_index) - + events.on_message_updated({ info = msg.info }, revert_index) for _, part in ipairs(msg.parts or {}) do - M.on_part_updated({ part = part }, revert_index) + events.on_part_updated({ part = part }, revert_index) end end if revert_index then - M._write_formatted_data(formatter._format_revert_message(state.messages, revert_index)) + buf.write_formatted_data(formatter._format_revert_message(state.messages, revert_index)) end if set_mode_from_messages then - M._set_model_and_mode_from_messages() + set_model_and_mode_from_messages() end + M.scroll_to_bottom(true) if config.hooks and config.hooks.on_session_loaded then @@ -167,150 +166,38 @@ function M._render_full_session_data(session_data, prev_revert, revert) end end ----Append permissions display as a fake part at the end -function M.render_permissions_display() - local permissions = permission_window.get_all_permissions() - if not permissions or #permissions == 0 then - M._remove_part_from_buffer('permission-display-part') - M._remove_message_from_buffer('permission-display-message') - return - end - local fake_message = { - info = { - id = 'permission-display-message', - sessionID = state.active_session and state.active_session.id or '', - role = 'system', - }, - parts = {}, - } - M.on_message_updated(fake_message --[[@as OpencodeMessage]]) - - local fake_part = { - id = 'permission-display-part', - messageID = 'permission-display-message', - sessionID = state.active_session and state.active_session.id or '', - type = 'permissions-display', - } - - M.on_part_updated({ part = fake_part }) - M.scroll_to_bottom(true) -end - -function M.clear_question_display() - local config_module = require('opencode.config') - local use_vim_ui = config_module.ui.questions and config_module.ui.questions.use_vim_ui_select - - if use_vim_ui then - -- When using vim.ui.select, there's nothing to clear from the buffer - local question_window = require('opencode.ui.question_window') - question_window.clear_question() - return - end - - local question_window = require('opencode.ui.question_window') - question_window.clear_question() - M._remove_part_from_buffer('question-display-part') - M._remove_message_from_buffer('question-display-message') -end - ----Render question display as a fake part -function M.render_question_display() - local use_vim_ui = config.ui.questions and config.ui.questions.use_vim_ui_select - - if use_vim_ui then - -- When using vim.ui.select, we don't render anything in the buffer - return - end - - local question_window = require('opencode.ui.question_window') - - local current_question = question_window._current_question - - if not question_window.has_question() or not current_question or not current_question.id then - M._remove_part_from_buffer('question-display-part') - M._remove_message_from_buffer('question-display-message') - return +---Fetch the active session from the server and render it +---@return Promise +function M.render_full_session() + if not output_window.mounted() or not state.api_client then + return Promise.new():resolve(nil) end - - local message_id = 'question-display-message' - local part_id = 'question-display-part' - - local fake_message = { - info = { - id = message_id, - sessionID = state.active_session and state.active_session.id or '', - role = 'system', - }, - parts = {}, - } - M.on_message_updated(fake_message --[[@as OpencodeMessage]]) - - local fake_part = { - id = part_id, - messageID = message_id, - sessionID = state.active_session and state.active_session.id or '', - type = 'questions-display', - } - - M.on_part_updated({ part = fake_part }) - M.scroll_to_bottom(true) + return fetch_session():and_then(M._render_full_session_data) end ----Render lines as the entire output buffer ----@param lines any +---Replace the entire output buffer with the given lines +---@param lines string[] function M.render_lines(lines) local output = require('opencode.ui.output'):new() output.lines = lines M.render_output(output) end ----Sets the entire output buffer based on output_data ----@param output_data Output Output object from formatter +---Replace the entire output buffer with formatted output data +---@param output_data Output function M.render_output(output_data) if not output_window.buffer_valid() then return end - - local lines = output_data.lines or {} - - output_window.set_lines(lines) + output_window.set_lines(output_data.lines or {}) output_window.clear_extmarks() output_window.set_extmarks(output_data.extmarks) M.scroll_to_bottom() end ----Called when EventManager has finished emitting a batch of events -function M.on_emit_events_finished() - M.scroll_to_bottom() -end - ----Find the most recently used model from the messages -function M._set_model_and_mode_from_messages() - if not state.messages then - return - end - - for i = #state.messages, 1, -1 do - local message = state.messages[i] - - if message and message.info then - if message.info.modelID and message.info.providerID then - state.model.set_model(message.info.providerID .. '/' .. message.info.modelID) - if message.info.mode then - state.model.set_mode(message.info.mode) - end - return - end - end - end - - -- we didn't find a model from any messages, set it to the default - require('opencode.core').initialize_current_model() -end - ----Auto-scroll to bottom if user was already at bottom ----Respects cursor position if user has scrolled up ----@param force? boolean If true, scroll regardless of current position +---Scroll the output window to the bottom. +---Respects the user's scroll position unless force=true or conditions allow it. +---@param force? boolean function M.scroll_to_bottom(force) local windows = state.windows local output_win = windows and windows.output_win @@ -319,7 +206,6 @@ function M.scroll_to_bottom(force) if not output_buf or not output_win then return end - if not vim.api.nvim_win_is_valid(output_win) then return end @@ -329,48 +215,18 @@ function M.scroll_to_bottom(force) return end - local prev_line_count = M._prev_line_count or 0 - - ---@cast line_count integer - M._prev_line_count = line_count + local prev_line_count = ctx.prev_line_count + ctx.prev_line_count = line_count trigger_on_data_rendered() - local scroll_conditions = { - { - name = 'force', - test = function() - return force == true - end, - }, - { - name = 'first_render', - test = function() - return prev_line_count == 0 - end, - }, - { - name = 'always_scroll', - test = function() - return config.ui.output.always_scroll_to_bottom - end, - }, - { - name = 'cursor_at_bottom', - test = function() - local ok_cursor, cursor = pcall(vim.api.nvim_win_get_cursor, output_win) - return ok_cursor and cursor and (cursor[1] >= prev_line_count or cursor[1] >= line_count) - end, - }, - } - - local should_scroll = false - for _, condition in ipairs(scroll_conditions) do - if condition.test() then - should_scroll = true - break - end - end + local should_scroll = force + or prev_line_count == 0 + or config.ui.output.always_scroll_to_bottom + or (function() + local ok_cursor, cursor = pcall(vim.api.nvim_win_get_cursor, output_win) + return ok_cursor and cursor and (cursor[1] >= prev_line_count or cursor[1] >= line_count) + end)() if should_scroll then vim.api.nvim_win_set_cursor(output_win, { line_count, 0 }) @@ -380,837 +236,43 @@ function M.scroll_to_bottom(force) end end ----Write data to output_buf, including normal text and extmarks ----@param formatted_data Output Formatted data as Output object ----@param part_id? string Optional part ID to store actions ----@param start_line? integer Optional line to insert at (shifts content down). If nil, appends to end of buffer. ----@return {line_start: integer, line_end: integer}? Range where data was written -function M._write_formatted_data(formatted_data, part_id, start_line) - if not state.windows or not state.windows.output_buf then - return - end - - local buf = state.windows.output_buf - local is_insertion = start_line ~= nil - local target_line = start_line or output_window.get_buf_line_count() - local new_lines = formatted_data.lines - local extmarks = formatted_data.extmarks - - if #new_lines == 0 or not buf then - return nil - end - - if is_insertion then - output_window.set_lines(new_lines, target_line, target_line) - else - local extra_newline = vim.tbl_extend('keep', {}, new_lines) - table.insert(extra_newline, '') - target_line = target_line - 1 - output_window.set_lines(extra_newline, target_line) - end - - -- update actions and extmarks after the insertion because that may - -- adjust target_line (e.g. when we we're replacing the double newline at - -- the end) - - if part_id and formatted_data.actions then - M._render_state:add_actions(part_id, formatted_data.actions, target_line) - end - - output_window.set_extmarks(extmarks, target_line) - - return { - line_start = target_line, - line_end = target_line + #new_lines - 1, - } -end - ----Insert new part, either at end of buffer or in the middle for out-of-order parts ----@param part_id string Part ID ----@param formatted_data Output Formatted data as Output object ----@return boolean Success status -function M._insert_part_to_buffer(part_id, formatted_data) - local cached = M._render_state:get_part(part_id) - if not cached then - return false - end - - if #formatted_data.lines == 0 then - return true - end - - local is_current_message = state.current_message - and state.current_message.info - and state.current_message.info.id == cached.message_id - - if is_current_message then - -- NOTE: we're inserting a part for the current message, just add it to the end - - local range = M._write_formatted_data(formatted_data, part_id) - if not range then - return false - end - - M._render_state:set_part(cached.part, range.line_start, range.line_end) - - M._last_part_formatted = { part_id = part_id, formatted_data = formatted_data } - - return true - end - - -- NOTE: We're inserting a part for the first time for a previous message. We need to find - -- the insertion line (after the last part of this message or after the message header if - -- no parts). - local insertion_line = M._get_insertion_point_for_part(part_id, cached.message_id) - if not insertion_line then - return false - end - - local range = M._write_formatted_data(formatted_data, part_id, insertion_line) - if not range then - return false - end - - local line_count = #formatted_data.lines - M._render_state:shift_all(insertion_line, line_count) - - M._render_state:set_part(cached.part, range.line_start, range.line_end) - - return true -end - ----Replace existing part in buffer ----Adjusts line positions of subsequent parts if line count changes ----@param part_id string Part ID ----@param formatted_data Output Formatted data as Output object ----@return boolean Success status -function M._replace_part_in_buffer(part_id, formatted_data) - local cached = M._render_state:get_part(part_id) - if not cached or not cached.line_start or not cached.line_end then - return false - end - - local new_lines = formatted_data.lines - local new_line_count = #new_lines - - local old_formatted = M._last_part_formatted - local can_optimize = old_formatted - and old_formatted.part_id == part_id - and old_formatted.formatted_data - and old_formatted.formatted_data.lines - - local lines_to_write = new_lines - local write_start_line = cached.line_start - - if can_optimize then - -- NOTE: This is an optimization to only replace the lines that are different - -- if we're replacing the most recently formatted part. - - ---@cast old_formatted { formatted_data: { lines: string[] } } - local old_lines = old_formatted.formatted_data.lines - local first_diff_line = nil - - -- Find the first line that's different - for i = 1, math.min(#old_lines, new_line_count) do - if old_lines[i] ~= new_lines[i] then - first_diff_line = i - break - end - end - - if not first_diff_line and new_line_count > #old_lines then - -- The old lines all matched but maybe there are more new lines - first_diff_line = #old_lines + 1 - end - - if first_diff_line then - lines_to_write = vim.list_slice(new_lines, first_diff_line, new_line_count) - write_start_line = cached.line_start + first_diff_line - 1 - elseif new_line_count == #old_lines then - -- Nothing was different, so we're done - M._last_part_formatted = { part_id = part_id, formatted_data = formatted_data } - return true - end - end - - M._render_state:clear_actions(part_id) - - output_window.clear_extmarks(cached.line_start - 1, cached.line_end + 1) - output_window.set_lines(lines_to_write, write_start_line, cached.line_end + 1) - - local new_line_end = cached.line_start + new_line_count - 1 - - output_window.set_extmarks(formatted_data.extmarks, cached.line_start) - - if formatted_data.actions then - M._render_state:add_actions(part_id, formatted_data.actions, cached.line_start + 1) - end - - M._render_state:update_part_lines(part_id, cached.line_start, new_line_end) - - M._last_part_formatted = { part_id = part_id, formatted_data = formatted_data } - - return true -end - ----Remove part from buffer and adjust subsequent line positions ----@param part_id string Part ID -function M._remove_part_from_buffer(part_id) - local cached = M._render_state:get_part(part_id) - if not cached or not cached.line_start or not cached.line_end then - return - end - - if not state.windows or not state.windows.output_buf then - return - end - - output_window.clear_extmarks(cached.line_start - 1, cached.line_end) - output_window.set_lines({}, cached.line_start - 1, cached.line_end) - - M._render_state:remove_part(part_id) -end - ----Remove message header from buffer and adjust subsequent line positions ----@param message_id string Message ID -function M._remove_message_from_buffer(message_id) - local cached = M._render_state:get_message(message_id) - if not cached or not cached.line_start or not cached.line_end then - return - end - - if not state.windows or not state.windows.output_buf then - return - end - - if cached.line_start == 0 and cached.line_end == 0 then - return - end - output_window.clear_extmarks(cached.line_start - 1, cached.line_end) - output_window.set_lines({}, cached.line_start - 1, cached.line_end) - - M._render_state:remove_message(message_id) -end - ----Adds a message (most likely just a header) to the buffer ----@param message OpencodeMessage Message to add -function M._add_message_to_buffer(message) - local header_data = formatter.format_message_header(message) - local range = M._write_formatted_data(header_data) - - if range then - M._render_state:set_message(message, range.line_start, range.line_end) - end - - if message.info.role == 'user' then - M.scroll_to_bottom(true) - end -end - ----Replace existing message header in buffer ----@param message_id string Message ID ----@param formatted_data Output Formatted header as Output object ----@return boolean Success status -function M._replace_message_in_buffer(message_id, formatted_data) - local cached = M._render_state:get_message(message_id) - if not cached or not cached.line_start or not cached.line_end then - return false - end - - local new_lines = formatted_data.lines - local new_line_count = #new_lines - - output_window.clear_extmarks(cached.line_start, cached.line_end + 1) - output_window.set_lines(new_lines, cached.line_start, cached.line_end + 1) - output_window.set_extmarks(formatted_data.extmarks, cached.line_start) - - local old_line_end = cached.line_end - local new_line_end = cached.line_start + new_line_count - 1 - - M._render_state:set_message(cached.message, cached.line_start, new_line_end) - - local delta = new_line_end - old_line_end - if delta ~= 0 then - M._render_state:shift_all(old_line_end + 1, delta) - end - - return true -end - ----Event handler for message.updated events ----Creates new message or updates existing message info ----@param message {info: MessageInfo} Event properties ----@param revert_index? integer Revert index in session, if applicable -function M.on_message_updated(message, revert_index) - if not state.active_session or not state.messages then - return - end - - local msg = message --[[@as OpencodeMessage]] - if not msg or not msg.info or not msg.info.id or not msg.info.sessionID then - return - end - - if state.active_session.id ~= msg.info.sessionID then - ---@TODO This is probably a child session message, handle differently? - -- vim.notify('Session id does not match, discarding message: ' .. vim.inspect(message), vim.log.levels.WARN) - return - end - - local rendered_message = M._render_state:get_message(msg.info.id) - local found_msg = rendered_message and rendered_message.message - - if revert_index then - if not found_msg then - table.insert(state.messages, msg) - end - M._render_state:set_message(msg, 0, 0) - return - end - - if found_msg then - local error_changed = not vim.deep_equal(found_msg.info.error, msg.info.error) - - found_msg.info = msg.info - - --- NOTE: error handling is a bit messy because errors come in on messages - --- but we want to display the error at the end. In this case, we an error - --- was added to this message. We find the last part and re-render it to - --- display the message. If there are no parts, we'll re-render the message - - if error_changed and not revert_index then - local last_part_id = M._get_last_part_for_message(found_msg) - if last_part_id then - M._rerender_part(last_part_id) - else - local header_data = formatter.format_message_header(found_msg) - M._replace_message_in_buffer(msg.info.id, header_data) - end - end - else - table.insert(state.messages, msg) - - M._add_message_to_buffer(msg) - - state.renderer.set_current_message(msg) - if message.info.role == 'user' then - state.renderer.set_last_user_message(msg) - end - end - - M._update_stats_from_message(msg) -end - ----Event handler for message.part.updated events ----Inserts new parts or replaces existing parts in buffer ----@param properties {part: OpencodeMessagePart} Event properties ----@param revert_index? integer Revert index in session, if applicable -function M.on_part_updated(properties, revert_index) - if not properties or not properties.part or not state.active_session then - return - end - - local part = properties.part - if not part.id or not part.messageID or not part.sessionID then - return - end - - if state.active_session.id ~= part.sessionID then - if part.tool or part.type == 'tool' then - M._render_state:upsert_child_session_part(part.sessionID, part) - - M._rerender_task_tool_for_child_session(part.sessionID) - end - return - end - - local rendered_message = M._render_state:get_message(part.messageID) - if not rendered_message or not rendered_message.message then - vim.notify('Could not find message for part: ' .. vim.inspect(part), vim.log.levels.WARN) - return - end - - local message = rendered_message.message - - message.parts = message.parts or {} - - local part_data = M._render_state:get_part(part.id) - local is_new_part = not part_data - - local prev_last_part_id = M._get_last_part_for_message(message) - local is_last_part = is_new_part or (prev_last_part_id == part.id) - - if is_new_part then - table.insert(message.parts, part) - else - for i = #message.parts, 1, -1 do - if message.parts[i].id == part.id then - message.parts[i] = part - break - end - end - end - - if part.type == 'step-start' or part.type == 'step-finish' then - return - end - - if is_new_part then - M._render_state:set_part(part) - else - local rendered_part = M._render_state:update_part_data(part) - -- NOTE: This isn't the first time we've seen the part but we haven't rendered it previously - -- so try and render it this time by setting is_new_part = true (otherwise we'd call - -- _replace_message_in_buffer and it wouldn't do anything because the part hasn't been rendered) - if not rendered_part or (not rendered_part.line_start and not rendered_part.line_end) then - is_new_part = true - end - end - - local formatted = formatter.format_part(part, message, is_last_part, function(session_id) - return M._render_state:get_child_session_parts(session_id) - end) - - if part.callID and state.pending_permissions then - for _, permission in ipairs(state.pending_permissions) do - local tool = permission.tool - local perm_callID = tool and tool.callID or permission.callID - local perm_messageID = tool and tool.messageID or permission.messageID - - if perm_callID == part.callID and perm_messageID == part.messageID then - require('opencode.ui.permission_window').update_permission_from_part(permission.id, part) - break - end - end - end - - if revert_index and is_new_part then - return - end - - if is_new_part then - M._insert_part_to_buffer(part.id, formatted) - - if message.info.error then - --- NOTE: More error display code. As mentioned above, errors come in on messages - --- but we want to display them after parts so we tack the error onto the last - --- part. When a part is added and there's an error, we need to rerender - --- previous last part so it doesn't also display the message. If there was no previous - --- part, then we need to rerender the header so it doesn't display the error - - if not prev_last_part_id then - -- no previous part, we're the first part, re-render the message header - -- so it doesn't also display the error - local header_data = formatter.format_message_header(message) - M._replace_message_in_buffer(part.messageID, header_data) - elseif prev_last_part_id ~= part.id then - M._rerender_part(prev_last_part_id) - end - end - else - M._replace_part_in_buffer(part.id, formatted) - end - - if (part.type == 'file' or part.type == 'agent') and part.source then - -- we have a mention, we need to rerender the early part to highlight - -- the mention. - local text_part_id = M._find_text_part_for_message(message) - if text_part_id then - M._rerender_part(text_part_id) - end - end -end - ----Event handler for message.part.removed events ----@param properties {sessionID: string, messageID: string, partID: string} Event properties -function M.on_part_removed(properties) - if not properties then - return - end - - local part_id = properties.partID - if not part_id then - return - end - - local cached = M._render_state:get_part(part_id) - if cached and cached.message_id then - local rendered_message = M._render_state:get_message(cached.message_id) - if rendered_message and rendered_message.message then - local message = rendered_message.message - if message.parts then - for i, part in ipairs(message.parts) do - if part.id == part_id then - table.remove(message.parts, i) - break - end - end - end - end - end - - M._remove_part_from_buffer(part_id) -end - ----Event handler for message.removed events ----Removes message and all its parts from buffer ----@param properties {sessionID: string, messageID: string} Event properties -function M.on_message_removed(properties) - if not properties or not state.messages then - return - end - - local message_id = properties.messageID - if not message_id then - return - end - - local rendered_message = M._render_state:get_message(message_id) - if not rendered_message or not rendered_message.message then - return - end - - local message = rendered_message.message - for _, part in ipairs(message.parts or {}) do - if part.id then - M._remove_part_from_buffer(part.id) - end - end - - M._remove_message_from_buffer(message_id) - - for i, msg in ipairs(state.messages or {}) do - if msg.info.id == message_id then - table.remove(state.messages, i) - break - end - end -end - ----Event handler for session.compacted events ----@param properties {sessionID: string} Event properties -function M.on_session_compacted(properties) - vim.notify('Session has been compacted') -end - ----Event handler for session.updated events ----@param properties {info: Session} -function M.on_session_updated(properties) - if not properties or not properties.info or not state.active_session then - return - end - - local updated_session = properties.info - if not updated_session.id or updated_session.id ~= state.active_session.id then - return - end - - local current_session = state.active_session - local revert_changed = not vim.deep_equal(current_session.revert, updated_session.revert) - local previous_title = current_session.title - - if not vim.deep_equal(current_session, updated_session) then - -- NOTE: we set the session without emitting a change event because we don't want to trigger another rerender. - state.store.set_raw('active_session', updated_session) - end - - if revert_changed then - M._render_full_session_data(state.messages) - end -end - ----Event handler for session.error events ----@param properties {sessionID: string, error: table} Event properties -function M.on_session_error(properties) - if not properties or not properties.error then - return - end - - -- NOTE: we're handling message errors so session errors seem duplicative - if config.debug.enabled then - vim.notify('Session error: ' .. vim.inspect(properties.error)) - end -end - ----Event handler for permission.updated events ----Re-renders part that requires permission and adds to permission window ----@param permission OpencodePermission Event properties -function M.on_permission_updated(permission) - local tool = permission.tool - - ---@TODO this is for backward compatibility, remove later - local callID = tool and tool.callID or permission.callID - local messageID = tool and tool.messageID or permission.messageID - - if not permission or not messageID or not callID then - return - end - - -- Add permission to pending queue - if not state.pending_permissions then - state.renderer.set_pending_permissions({}) - end - - -- Check if permission already exists in queue - local existing_index = nil - for i, existing in ipairs(state.pending_permissions) do - if existing.id == permission.id then - existing_index = i - break - end - end - - state.renderer.update_pending_permissions(function(permissions) - if existing_index then - permissions[existing_index] = permission - else - table.insert(permissions, permission) - end - end) - - permission_window.add_permission(permission) - - M.render_permissions_display() - - M._rerender_part('permission-display-part') - M.scroll_to_bottom(true) -end - ----Event handler for permission.replied events ----Re-renders part after permission is resolved and removes from window ----@param properties {sessionID: string, permissionID?: string,requestID?: string, response: string}|{} Event properties -function M.on_permission_replied(properties) - if not properties then - return - end - - local permission_id = properties.permissionID or properties.requestID - - if permission_id then - permission_window.remove_permission(permission_id) - state.renderer.set_pending_permissions(vim.deepcopy(permission_window.get_all_permissions())) - if #state.pending_permissions == 0 then - M._remove_part_from_buffer('permission-display-part') - M._remove_message_from_buffer('permission-display-message') - end - M._rerender_part('permission-display-part') - end -end - ----Event handler for question.asked events ----Shows the question picker UI for the user to answer ----@param properties OpencodeQuestionRequest Event properties -function M.on_question_asked(properties) - if not properties or not properties.id or not properties.questions then - return - end - - local question_window = require('opencode.ui.question_window') - question_window.show_question(properties) -end - -function M.on_file_edited(properties) - vim.cmd('checktime') - if config.hooks and config.hooks.on_file_edited then - pcall(config.hooks.on_file_edited, properties.file) - end -end - ----@param properties RestorePointCreatedEvent -function M.on_restore_points(properties) - state.store.append('restore_points', properties.restore_point) - if not properties or not properties.restore_point or not properties.restore_point.from_snapshot_id then - return - end - local part = M._render_state:get_part_by_snapshot_id(properties.restore_point.from_snapshot_id) - if part then - M.on_part_updated({ part = part }) - end -end - ----Find part ID by call ID and message ID ----Useful for finding a part for a permission ----@param call_id string Call ID to search for ----@param message_id string Message ID to check the parts of ----@return string? part_id Part ID if found, nil otherwise -function M._find_part_by_call_id(call_id, message_id) - return M._render_state:get_part_by_call_id(call_id, message_id) -end - ----Find the text part in a message ----@param message OpencodeMessage The message containing the parts ----@return string? text_part_id The ID of the text part -function M._find_text_part_for_message(message) - if not message or not message.parts then - return nil - end - - for _, part in ipairs(message.parts) do - if part.type == 'text' and not part.synthetic then - return part.id - end - end - - return nil -end - ----Find the last part in a message ----@param message OpencodeMessage The message containing the parts ----@return string? last_part_id The ID of the last part -function M._get_last_part_for_message(message) - if not message or not message.parts or #message.parts == 0 then - return nil - end - - for i = #message.parts, 1, -1 do - local part = message.parts[i] - if part.type ~= 'step-start' and part.type ~= 'step-finish' and part.id then - return part.id - end - end - - return nil -end - ----Get insertion point for an out-of-order part ----@param part_id string The part ID to insert ----@param message_id string The message ID the part belongs to ----@return integer? insertion_line The line to insert at (1-indexed), or nil on error -function M._get_insertion_point_for_part(part_id, message_id) - local rendered_message = M._render_state:get_message(message_id) - if not rendered_message or not rendered_message.message then - return nil - end - - local message = rendered_message.message - - local insertion_line = rendered_message.line_end and (rendered_message.line_end + 1) - if not insertion_line then - return nil - end - - local current_part_index = nil - if message.parts then - for i, part in ipairs(message.parts) do - if part.id == part_id then - current_part_index = i - break - end - end - end - - if not current_part_index then - return insertion_line - end - - for i = current_part_index - 1, 1, -1 do - local prev_part = message.parts[i] - if prev_part and prev_part.id then - local prev_rendered = M._render_state:get_part(prev_part.id) - - if prev_rendered and prev_rendered.line_end then - return prev_rendered.line_end + 1 - end - end - end - - return insertion_line -end - ----Find and re-render the task tool part in the active session that owns a given child session ----@param child_session_id string The child session ID to look up -function M._rerender_task_tool_for_child_session(child_session_id) - local part_id = M._render_state:get_task_part_by_child_session(child_session_id) - if not part_id then - return - end - - M._rerender_part(part_id) -end - ----Re-render existing part with current state ----Used for permission updates and other dynamic changes ----@param part_id string Part ID to re-render -function M._rerender_part(part_id) - local cached = M._render_state:get_part(part_id) - if not cached or not cached.part then - return - end - - local part = cached.part - local rendered_message = M._render_state:get_message(cached.message_id) - if not rendered_message or not rendered_message.message then - return - end - - local message = rendered_message.message - local last_part_id = M._get_last_part_for_message(message) - local is_last_part = (last_part_id == part_id) - local formatted = formatter.format_part(part, message, is_last_part, function(session_id) - return M._render_state:get_child_session_parts(session_id) - end) - - M._replace_part_in_buffer(part_id, formatted) -end - ----Event handler for focus changes ----Re-renders part associated with current permission for displaying global shortcuts or buffer-local ones +---Re-render the permission display when focus changes (updates shortcut hints) function M.on_focus_changed() - -- Check permission window first, fallback to state - local current_permission = permission_window.get_all_permissions()[1] - - if not current_permission then + if not permission_window.get_all_permissions()[1] then return end - - M._rerender_part('permission-display-part') + buf.rerender_part('permission-display-part') trigger_on_data_rendered() end +---Re-render when the active session changes function M.on_session_changed(_, new, old) if (old and old.id) == (new and new.id) then return end - M.reset() if new then M.render_full_session() end end ----Get all actions available at a specific line ----@param line integer 0-indexed line number ----@return table[] List of actions available at that line -function M.get_actions_for_line(line) - return M._render_state:get_actions_at_line(line) +---Scroll to bottom after all queued events have been processed +function M.on_emit_events_finished() + M.scroll_to_bottom() end ----Update display stats from a single message ----@param message OpencodeMessage -function M._update_stats_from_message(message) - if not state.current_model and message.info.providerID and message.info.providerID ~= '' then - state.model.set_model(message.info.providerID .. '/' .. message.info.modelID) - end - - local tokens = message.info.tokens - if tokens and tokens.input > 0 and message.info.cost and type(message.info.cost) == 'number' then - state.renderer.set_stats(tokens.input + tokens.output + tokens.cache.read + tokens.cache.write, message.info.cost) - elseif tokens and tokens.input > 0 then - state.renderer.set_tokens_count(tokens.input + tokens.output + tokens.cache.read + tokens.cache.write) - elseif message.info.cost and type(message.info.cost) == 'number' then - state.renderer.set_cost(message.info.cost) - end +---Return all actions available at a given (0-indexed) line +---@param line integer +---@return table[] +function M.get_actions_for_line(line) + return ctx.render_state:get_actions_at_line(line) end ----Get rendered message by ID ----@param message_id string Message ID ----@return RenderedMessage|nil Rendered message or nil if not found +---Return the rendered message record for a given message ID +---@param message_id string +---@return RenderedMessage|nil function M.get_rendered_message(message_id) - local rendered_msg = M._render_state:get_message(message_id) - if rendered_msg then - return rendered_msg - end - return nil + return ctx.render_state:get_message(message_id) or nil end return M diff --git a/lua/opencode/ui/renderer/buffer.lua b/lua/opencode/ui/renderer/buffer.lua new file mode 100644 index 00000000..5e59d77e --- /dev/null +++ b/lua/opencode/ui/renderer/buffer.lua @@ -0,0 +1,402 @@ +local ctx = require('opencode.ui.renderer.ctx') +local state = require('opencode.state') +local formatter = require('opencode.ui.formatter') +local output_window = require('opencode.ui.output_window') + +local M = {} + +local function has_extmarks(extmarks) + return type(extmarks) == 'table' and next(extmarks) ~= nil +end + +local function has_actions(actions) + return type(actions) == 'table' and #actions > 0 +end + +---@param old_lines string[] +---@param new_lines string[] +---@return integer, integer +local function get_shared_prefix_suffix(old_lines, new_lines) + local old_count = #old_lines + local new_count = #new_lines + local prefix = 0 + + while prefix < old_count and prefix < new_count do + if old_lines[prefix + 1] ~= new_lines[prefix + 1] then + break + end + prefix = prefix + 1 + end + + local suffix = 0 + while suffix < (old_count - prefix) and suffix < (new_count - prefix) do + if old_lines[old_count - suffix] ~= new_lines[new_count - suffix] then + break + end + suffix = suffix + 1 + end + + return prefix, suffix +end + +---Find the last renderable part ID in a message (skips step-start/finish) +---@param message OpencodeMessage +---@return string? +function M.get_last_part_for_message(message) + if not message or not message.parts or #message.parts == 0 then + return nil + end + for i = #message.parts, 1, -1 do + local part = message.parts[i] + if part.type ~= 'step-start' and part.type ~= 'step-finish' and part.id then + return part.id + end + end + return nil +end + +---Find the first non-synthetic text part ID in a message +---@param message OpencodeMessage +---@return string? +function M.find_text_part_for_message(message) + if not message or not message.parts then + return nil + end + for _, part in ipairs(message.parts) do + if part.type == 'text' and not part.synthetic then + return part.id + end + end + return nil +end + +---Find part ID by call ID and message ID +---@param call_id string +---@param message_id string +---@return string? +function M.find_part_by_call_id(call_id, message_id) + return ctx.render_state:get_part_by_call_id(call_id, message_id) +end + +---Determine where to insert an out-of-order part (after the last rendered +---sibling, or right after the message header if no siblings are rendered yet) +---@param part_id string +---@param message_id string +---@return integer? +local function get_insertion_point_for_part(part_id, message_id) + local rendered_message = ctx.render_state:get_message(message_id) + if not rendered_message or not rendered_message.message then + return nil + end + + local message = rendered_message.message + local insertion_line = rendered_message.line_end and (rendered_message.line_end + 1) + if not insertion_line then + return nil + end + + local current_part_index = nil + if message.parts then + for i, part in ipairs(message.parts) do + if part.id == part_id then + current_part_index = i + break + end + end + end + + if not current_part_index then + return insertion_line + end + + -- Walk backwards through earlier siblings to find the last rendered one + for i = current_part_index - 1, 1, -1 do + local prev_part = message.parts[i] + if prev_part and prev_part.id then + local prev_rendered = ctx.render_state:get_part(prev_part.id) + if prev_rendered and prev_rendered.line_end then + return prev_rendered.line_end + 1 + end + end + end + + return insertion_line +end + +---Append formatted data to the end of the buffer, or insert at start_line. +---Returns the range of lines written, or nil if nothing was written. +---@param formatted_data Output +---@param part_id? string When provided, actions are registered for this part +---@param start_line? integer When provided, content is inserted here (shifts down) +---@return {line_start: integer, line_end: integer}? +function M.write_formatted_data(formatted_data, part_id, start_line) + if not state.windows or not state.windows.output_buf then + return nil + end + + local new_lines = formatted_data.lines + if #new_lines == 0 then + return nil + end + + local is_insertion = start_line ~= nil + local target_line = start_line or output_window.get_buf_line_count() + + if is_insertion then + output_window.set_lines(new_lines, target_line, target_line) + else + -- Append: temporarily add a separating blank line in-place to avoid an + -- O(n) copy, then restore. set_lines is synchronous so this is safe. + target_line = target_line - 1 + new_lines[#new_lines + 1] = '' + output_window.set_lines(new_lines, target_line) + new_lines[#new_lines] = nil + end + + if part_id and formatted_data.actions then + ctx.render_state:add_actions(part_id, formatted_data.actions, target_line) + end + + if has_extmarks(formatted_data.extmarks) then + output_window.set_extmarks(formatted_data.extmarks, target_line) + local part_data = ctx.render_state:get_part(part_id) + if part_data then + part_data.has_extmarks = true + end + end + + return { line_start = target_line, line_end = target_line + #new_lines - 1 } +end + +---Insert a new part into the buffer. +---Appends if the part belongs to the current message; inserts in-order otherwise. +---@param part_id string +---@param formatted_data Output +---@return boolean +function M.insert_part(part_id, formatted_data) + local cached = ctx.render_state:get_part(part_id) + if not cached then + return false + end + + if #formatted_data.lines == 0 then + return true + end + + local is_current_message = state.current_message + and state.current_message.info + and state.current_message.info.id == cached.message_id + + if is_current_message then + local range = M.write_formatted_data(formatted_data, part_id) + if not range then + return false + end + ctx.render_state:set_part(cached.part, range.line_start, range.line_end) + ctx.last_part_formatted = { part_id = part_id, formatted_data = formatted_data } + return true + end + + -- Out-of-order part: find the correct insertion point + local insertion_line = get_insertion_point_for_part(part_id, cached.message_id) + if not insertion_line then + return false + end + + local range = M.write_formatted_data(formatted_data, part_id, insertion_line) + if not range then + return false + end + + ctx.render_state:shift_all(insertion_line, #formatted_data.lines) + ctx.render_state:set_part(cached.part, range.line_start, range.line_end) + return true +end + +---Replace an existing part in the buffer. +---Only writes lines that differ from the previous render (diff optimisation). +---@param part_id string +---@param formatted_data Output +---@return boolean +function M.replace_part(part_id, formatted_data) + local cached = ctx.render_state:get_part(part_id) + if not cached or not cached.line_start or not cached.line_end then + return false + end + + local new_lines = formatted_data.lines + local new_line_count = #new_lines + local next_has_extmarks = has_extmarks(formatted_data.extmarks) + local had_extmarks = cached.has_extmarks == true + local next_has_actions = has_actions(formatted_data.actions) + local had_actions = cached.actions and #cached.actions > 0 + local old_buf_line_count = output_window.get_buf_line_count() + local was_tail_part = cached.line_end == old_buf_line_count - 1 + + -- Diff optimisation: skip lines that haven't changed since the last render + local old = ctx.last_part_formatted + local lines_to_write = new_lines + local write_start = cached.line_start + local write_end = cached.line_end + 1 + local prefix = 0 + local suffix = 0 + + if old and old.part_id == part_id and old.formatted_data and old.formatted_data.lines then + local old_lines = old.formatted_data.lines + prefix, suffix = get_shared_prefix_suffix(old_lines, new_lines) + + if prefix == #old_lines and prefix == new_line_count then + if not had_extmarks and not next_has_extmarks and not had_actions and not next_has_actions then + ctx.last_part_formatted = { part_id = part_id, formatted_data = formatted_data } + return true + end + end + + local replace_from = prefix + 1 + local replace_to = new_line_count - suffix + lines_to_write = replace_from <= replace_to and vim.list_slice(new_lines, replace_from, replace_to) or {} + write_start = cached.line_start + prefix + write_end = cached.line_end + 1 - suffix + end + + if had_actions or next_has_actions then + ctx.render_state:clear_actions(part_id) + end + + output_window.begin_update() + if had_extmarks or next_has_extmarks then + output_window.clear_extmarks(cached.line_start - 1, cached.line_end + 1) + end + output_window.set_lines(lines_to_write, write_start, write_end) + + local new_line_end = cached.line_start + new_line_count - 1 + if next_has_extmarks then + output_window.set_extmarks(formatted_data.extmarks, cached.line_start) + end + output_window.end_update() + cached.has_extmarks = next_has_extmarks + + if next_has_actions then + ctx.render_state:add_actions(part_id, formatted_data.actions, cached.line_start + 1) + end + + if new_line_end ~= cached.line_end then + if was_tail_part then + ctx.render_state:set_part(cached.part, cached.line_start, new_line_end) + else + ctx.render_state:update_part_lines(part_id, cached.line_start, new_line_end) + end + end + + ctx.last_part_formatted = { part_id = part_id, formatted_data = formatted_data } + return true +end + +---Remove a part and its extmarks from the buffer +---@param part_id string +function M.remove_part(part_id) + local cached = ctx.render_state:get_part(part_id) + if not cached or not cached.line_start or not cached.line_end then + return + end + output_window.begin_update() + output_window.clear_extmarks(cached.line_start - 1, cached.line_end) + output_window.set_lines({}, cached.line_start - 1, cached.line_end) + output_window.end_update() + ctx.render_state:remove_part(part_id) +end + +---Write a message header into the buffer +---@param message OpencodeMessage +function M.add_message(message) + local header_data = formatter.format_message_header(message) + local range = M.write_formatted_data(header_data) + if range then + ctx.render_state:set_message(message, range.line_start, range.line_end) + end +end + +---Replace an existing message header in the buffer +---@param message_id string +---@param formatted_data Output +---@return boolean +function M.replace_message(message_id, formatted_data) + local cached = ctx.render_state:get_message(message_id) + if not cached or not cached.line_start or not cached.line_end then + return false + end + + local new_lines = formatted_data.lines + local new_line_count = #new_lines + + output_window.begin_update() + output_window.clear_extmarks(cached.line_start, cached.line_end + 1) + output_window.set_lines(new_lines, cached.line_start, cached.line_end + 1) + output_window.set_extmarks(formatted_data.extmarks, cached.line_start) + output_window.end_update() + + local old_line_end = cached.line_end + local new_line_end = cached.line_start + new_line_count - 1 + + ctx.render_state:set_message(cached.message, cached.line_start, new_line_end) + + local delta = new_line_end - old_line_end + if delta ~= 0 then + ctx.render_state:shift_all(old_line_end + 1, delta) + end + + return true +end + +---Remove a message header and its extmarks from the buffer +---@param message_id string +function M.remove_message(message_id) + local cached = ctx.render_state:get_message(message_id) + if not cached or not cached.line_start or not cached.line_end then + return + end + if not state.windows or not state.windows.output_buf then + return + end + if cached.line_start == 0 and cached.line_end == 0 then + return + end + output_window.begin_update() + output_window.clear_extmarks(cached.line_start - 1, cached.line_end) + output_window.set_lines({}, cached.line_start - 1, cached.line_end) + output_window.end_update() + ctx.render_state:remove_message(message_id) +end + +---Re-render an existing part using its current data from render_state +---@param part_id string +function M.rerender_part(part_id) + local cached = ctx.render_state:get_part(part_id) + if not cached or not cached.part then + return + end + + local rendered_message = ctx.render_state:get_message(cached.message_id) + if not rendered_message or not rendered_message.message then + return + end + + local message = rendered_message.message + local is_last_part = (M.get_last_part_for_message(message) == part_id) + local formatted = formatter.format_part(cached.part, message, is_last_part, function(session_id) + return ctx.render_state:get_child_session_parts(session_id) + end) + + M.replace_part(part_id, formatted) +end + +---Re-render the task-tool part that owns the given child session +---@param child_session_id string +function M.rerender_task_tool_for_child_session(child_session_id) + local part_id = ctx.render_state:get_task_part_by_child_session(child_session_id) + if part_id then + M.rerender_part(part_id) + end +end + +return M diff --git a/lua/opencode/ui/renderer/ctx.lua b/lua/opencode/ui/renderer/ctx.lua new file mode 100644 index 00000000..7cbeddee --- /dev/null +++ b/lua/opencode/ui/renderer/ctx.lua @@ -0,0 +1,21 @@ +local RenderState = require('opencode.ui.render_state') + +---Shared mutable context for the renderer modules. +---Single instance, shared via Lua's require cache. +---@class RendererCtx +local ctx = { + ---@type RenderState + render_state = RenderState.new(), + ---@type integer + prev_line_count = 0, + ---@type { part_id: string|nil, formatted_data: Output|nil } + last_part_formatted = { part_id = nil, formatted_data = nil }, +} + +function ctx:reset() + self.render_state:reset() + self.prev_line_count = 0 + self.last_part_formatted = { part_id = nil, formatted_data = nil } +end + +return ctx diff --git a/lua/opencode/ui/renderer/events.lua b/lua/opencode/ui/renderer/events.lua new file mode 100644 index 00000000..5ed0f143 --- /dev/null +++ b/lua/opencode/ui/renderer/events.lua @@ -0,0 +1,471 @@ +local state = require('opencode.state') +local config = require('opencode.config') +local formatter = require('opencode.ui.formatter') +local ctx = require('opencode.ui.renderer.ctx') +local buf = require('opencode.ui.renderer.buffer') +local permission_window = require('opencode.ui.permission_window') + +-- Lazy require to avoid circular dependency: renderer.lua <-> events.lua +local function scroll(force) + require('opencode.ui.renderer').scroll_to_bottom(force) +end + +local M = {} + +---Update token/cost stats in state from a message +---@param message OpencodeMessage +local function update_stats(message) + if not state.current_model and message.info.providerID and message.info.providerID ~= '' then + state.model.set_model(message.info.providerID .. '/' .. message.info.modelID) + end + + local tokens = message.info.tokens + if tokens and tokens.input > 0 and message.info.cost and type(message.info.cost) == 'number' then + state.renderer.set_stats(tokens.input + tokens.output + tokens.cache.read + tokens.cache.write, message.info.cost) + elseif tokens and tokens.input > 0 then + state.renderer.set_tokens_count(tokens.input + tokens.output + tokens.cache.read + tokens.cache.write) + elseif message.info.cost and type(message.info.cost) == 'number' then + state.renderer.set_cost(message.info.cost) + end +end + +---Render pending permissions as a synthetic part at the end of the buffer +function M.render_permissions_display() + local permissions = permission_window.get_all_permissions() + if not permissions or #permissions == 0 then + buf.remove_part('permission-display-part') + buf.remove_message('permission-display-message') + return + end + + local fake_message = { + info = { + id = 'permission-display-message', + sessionID = state.active_session and state.active_session.id or '', + role = 'system', + }, + parts = {}, + } + M.on_message_updated(fake_message --[[@as OpencodeMessage]]) + + local fake_part = { + id = 'permission-display-part', + messageID = 'permission-display-message', + sessionID = state.active_session and state.active_session.id or '', + type = 'permissions-display', + } + M.on_part_updated({ part = fake_part }) + scroll(true) +end + +---Render the current question as a synthetic part at the end of the buffer +function M.render_question_display() + local use_vim_ui = config.ui.questions and config.ui.questions.use_vim_ui_select + if use_vim_ui then + return + end + + local question_window = require('opencode.ui.question_window') + local current_question = question_window._current_question + + if not question_window.has_question() or not current_question or not current_question.id then + buf.remove_part('question-display-part') + buf.remove_message('question-display-message') + return + end + + local fake_message = { + info = { + id = 'question-display-message', + sessionID = state.active_session and state.active_session.id or '', + role = 'system', + }, + parts = {}, + } + M.on_message_updated(fake_message --[[@as OpencodeMessage]]) + + local fake_part = { + id = 'question-display-part', + messageID = 'question-display-message', + sessionID = state.active_session and state.active_session.id or '', + type = 'questions-display', + } + M.on_part_updated({ part = fake_part }) + scroll(true) +end + +---Remove the question display from the buffer +function M.clear_question_display() + local use_vim_ui = config.ui.questions and config.ui.questions.use_vim_ui_select + local question_window = require('opencode.ui.question_window') + question_window.clear_question() + + if not use_vim_ui then + buf.remove_part('question-display-part') + buf.remove_message('question-display-message') + end +end + +---Handle message.updated — create the message header or update existing info +---@param message {info: MessageInfo} +---@param revert_index? integer +function M.on_message_updated(message, revert_index) + if not state.active_session or not state.messages then + return + end + + local msg = message --[[@as OpencodeMessage]] + if not msg or not msg.info or not msg.info.id or not msg.info.sessionID then + return + end + + if state.active_session.id ~= msg.info.sessionID then + return + end + + local rendered_message = ctx.render_state:get_message(msg.info.id) + local found_msg = rendered_message and rendered_message.message + + if revert_index then + if not found_msg then + table.insert(state.messages, msg) + end + ctx.render_state:set_message(msg, 0, 0) + return + end + + if found_msg then + local error_changed = not vim.deep_equal(found_msg.info.error, msg.info.error) + found_msg.info = msg.info + + -- Errors arrive on the message but we display them after the last part. + -- Re-render the last part (or the header if there are no parts) so the + -- error appears in the right place. + if error_changed then + local last_part_id = buf.get_last_part_for_message(found_msg) + if last_part_id then + buf.rerender_part(last_part_id) + else + local header_data = formatter.format_message_header(found_msg) + buf.replace_message(msg.info.id, header_data) + end + end + else + table.insert(state.messages, msg) + buf.add_message(msg) + state.renderer.set_current_message(msg) + if message.info.role == 'user' then + state.renderer.set_last_user_message(msg) + scroll(true) + end + end + + update_stats(msg) +end + +---Handle message.removed — remove the message and all its parts from the buffer +---@param properties {sessionID: string, messageID: string} +function M.on_message_removed(properties) + if not properties or not state.messages then + return + end + + local message_id = properties.messageID + if not message_id then + return + end + + local rendered_message = ctx.render_state:get_message(message_id) + if not rendered_message or not rendered_message.message then + return + end + + for _, part in ipairs(rendered_message.message.parts or {}) do + if part.id then + buf.remove_part(part.id) + end + end + + buf.remove_message(message_id) + + for i, msg in ipairs(state.messages or {}) do + if msg.info.id == message_id then + table.remove(state.messages, i) + break + end + end +end + +---Handle message.part.updated — insert or replace a part in the buffer +---@param properties {part: OpencodeMessagePart} +---@param revert_index? integer +function M.on_part_updated(properties, revert_index) + if not properties or not properties.part or not state.active_session then + return + end + + local part = properties.part + if not part.id or not part.messageID or not part.sessionID then + return + end + + -- Child-session parts: update the task-tool display instead + if state.active_session.id ~= part.sessionID then + if part.tool or part.type == 'tool' then + ctx.render_state:upsert_child_session_part(part.sessionID, part) + buf.rerender_task_tool_for_child_session(part.sessionID) + end + return + end + + local rendered_message = ctx.render_state:get_message(part.messageID) + if not rendered_message or not rendered_message.message then + vim.notify('Could not find message for part: ' .. vim.inspect(part), vim.log.levels.WARN) + return + end + + local message = rendered_message.message + message.parts = message.parts or {} + + local part_data = ctx.render_state:get_part(part.id) + local is_new_part = not part_data + + local prev_last_part_id = buf.get_last_part_for_message(message) + local is_last_part = is_new_part or (prev_last_part_id == part.id) + + -- Update the part reference in the message + if is_new_part then + table.insert(message.parts, part) + else + for i = #message.parts, 1, -1 do + if message.parts[i].id == part.id then + message.parts[i] = part + break + end + end + end + + -- step-start / step-finish are bookkeeping only — nothing to render + if part.type == 'step-start' or part.type == 'step-finish' then + return + end + + if is_new_part then + ctx.render_state:set_part(part) + else + local rendered_part = ctx.render_state:update_part_data(part) + -- Part known but never rendered yet — treat as new + if not rendered_part or (not rendered_part.line_start and not rendered_part.line_end) then + is_new_part = true + end + end + + -- Update the permission window if this part has a pending permission + if part.callID and state.pending_permissions then + for _, permission in ipairs(state.pending_permissions) do + local tool = permission.tool + local perm_callID = tool and tool.callID or permission.callID + local perm_messageID = tool and tool.messageID or permission.messageID + if perm_callID == part.callID and perm_messageID == part.messageID then + permission_window.update_permission_from_part(permission.id, part) + break + end + end + end + + if revert_index and is_new_part then + return + end + + local formatted = formatter.format_part(part, message, is_last_part, function(session_id) + return ctx.render_state:get_child_session_parts(session_id) + end) + + if is_new_part then + buf.insert_part(part.id, formatted) + + -- If there's already an error on this message, adjust adjacent parts so + -- the error only appears after the last part. + if message.info.error then + if not prev_last_part_id then + local header_data = formatter.format_message_header(message) + buf.replace_message(part.messageID, header_data) + elseif prev_last_part_id ~= part.id then + buf.rerender_part(prev_last_part_id) + end + end + else + buf.replace_part(part.id, formatted) + end + + -- File / agent mentions: re-render the text part to highlight them + if (part.type == 'file' or part.type == 'agent') and part.source then + local text_part_id = buf.find_text_part_for_message(message) + if text_part_id then + buf.rerender_part(text_part_id) + end + end +end + +---Handle message.part.removed +---@param properties {sessionID: string, messageID: string, partID: string} +function M.on_part_removed(properties) + if not properties then + return + end + + local part_id = properties.partID + if not part_id then + return + end + + -- Remove the part from the in-memory message too + local cached = ctx.render_state:get_part(part_id) + if cached and cached.message_id then + local rendered_message = ctx.render_state:get_message(cached.message_id) + if rendered_message and rendered_message.message and rendered_message.message.parts then + for i, part in ipairs(rendered_message.message.parts) do + if part.id == part_id then + table.remove(rendered_message.message.parts, i) + break + end + end + end + end + + buf.remove_part(part_id) +end + +---Handle session.updated — re-render the full session if the revert state changed +---@param properties {info: Session} +function M.on_session_updated(properties) + if not properties or not properties.info or not state.active_session then + return + end + + local updated_session = properties.info + if not updated_session.id or updated_session.id ~= state.active_session.id then + return + end + + local current_session = state.active_session + local revert_changed = not vim.deep_equal(current_session.revert, updated_session.revert) + + if not vim.deep_equal(current_session, updated_session) then + -- Set without emitting a change event to avoid a double re-render + state.store.set_raw('active_session', updated_session) + end + + if revert_changed then + require('opencode.ui.renderer')._render_full_session_data(state.messages) + end +end + +---Handle session.compacted +function M.on_session_compacted() + vim.notify('Session has been compacted') +end + +---Handle session.error +---@param properties {sessionID: string, error: table} +function M.on_session_error(properties) + if not properties or not properties.error then + return + end + if config.debug.enabled then + vim.notify('Session error: ' .. vim.inspect(properties.error)) + end +end + +---Handle permission.updated / permission.asked +---@param permission OpencodePermission +function M.on_permission_updated(permission) + local tool = permission.tool + local callID = tool and tool.callID or permission.callID + local messageID = tool and tool.messageID or permission.messageID + + if not permission or not messageID or not callID then + return + end + + if not state.pending_permissions then + state.renderer.set_pending_permissions({}) + end + + local existing_index = nil + for i, existing in ipairs(state.pending_permissions) do + if existing.id == permission.id then + existing_index = i + break + end + end + + state.renderer.update_pending_permissions(function(permissions) + if existing_index then + permissions[existing_index] = permission + else + table.insert(permissions, permission) + end + end) + + permission_window.add_permission(permission) + M.render_permissions_display() + buf.rerender_part('permission-display-part') + scroll(true) +end + +---Handle permission.replied — remove the resolved permission and update display +---@param properties {sessionID: string, permissionID?: string, requestID?: string, response: string} +function M.on_permission_replied(properties) + if not properties then + return + end + + local permission_id = properties.permissionID or properties.requestID + if not permission_id then + return + end + + permission_window.remove_permission(permission_id) + state.renderer.set_pending_permissions(vim.deepcopy(permission_window.get_all_permissions())) + + if #state.pending_permissions == 0 then + buf.remove_part('permission-display-part') + buf.remove_message('permission-display-message') + end + + buf.rerender_part('permission-display-part') +end + +---Handle question.asked — show the question picker UI +---@param properties OpencodeQuestionRequest +function M.on_question_asked(properties) + if not properties or not properties.id or not properties.questions then + return + end + require('opencode.ui.question_window').show_question(properties) +end + +---Handle file.edited — reload buffers and fire the hook +---@param properties {file: string} +function M.on_file_edited(properties) + vim.cmd('checktime') + if config.hooks and config.hooks.on_file_edited then + pcall(config.hooks.on_file_edited, properties.file) + end +end + +---Handle custom.restore_point.created +---@param properties RestorePointCreatedEvent +function M.on_restore_points(properties) + state.store.append('restore_points', properties.restore_point) + if not properties or not properties.restore_point or not properties.restore_point.from_snapshot_id then + return + end + local part = ctx.render_state:get_part_by_snapshot_id(properties.restore_point.from_snapshot_id) + if part then + M.on_part_updated({ part = part }) + end +end + +return M diff --git a/tests/helpers.lua b/tests/helpers.lua index 836d9480..1ec1a9c6 100644 --- a/tests/helpers.lua +++ b/tests/helpers.lua @@ -323,11 +323,10 @@ function M.normalize_namespace_ids(extmarks) end function M.capture_output(output_buf, namespace) - local renderer = require('opencode.ui.renderer') return { lines = vim.api.nvim_buf_get_lines(output_buf, 0, -1, false) or {}, extmarks = vim.api.nvim_buf_get_extmarks(output_buf, namespace, 0, -1, { details = true }) or {}, - actions = vim.deepcopy(renderer._render_state:get_all_actions()), + actions = vim.deepcopy(require('opencode.ui.renderer.ctx').render_state:get_all_actions()), } end diff --git a/tests/manual/renderer_replay.lua b/tests/manual/renderer_replay.lua index 9f3852bd..1c74cefa 100644 --- a/tests/manual/renderer_replay.lua +++ b/tests/manual/renderer_replay.lua @@ -194,7 +194,7 @@ function M.save_output(filename) local snapshot = { lines = lines, extmarks = M.normalize_namespace_ids(extmarks), - actions = vim.deepcopy(renderer._render_state:get_all_actions()), + actions = vim.deepcopy(require('opencode.ui.renderer.ctx').render_state:get_all_actions()), timestamp = os.time(), } diff --git a/tests/unit/cursor_tracking_spec.lua b/tests/unit/cursor_tracking_spec.lua index 43793dd6..866c754e 100644 --- a/tests/unit/cursor_tracking_spec.lua +++ b/tests/unit/cursor_tracking_spec.lua @@ -288,6 +288,7 @@ end) describe('renderer.scroll_to_bottom', function() local renderer = require('opencode.ui.renderer') + local ctx = require('opencode.ui.renderer.ctx') local output_window = require('opencode.ui.output_window') local buf, win @@ -309,14 +310,14 @@ describe('renderer.scroll_to_bottom', function() }) state.ui.set_windows({ output_win = win, output_buf = buf }) - renderer._prev_line_count = 50 + ctx.prev_line_count = 50 end) after_each(function() pcall(vim.api.nvim_win_close, win, true) pcall(vim.api.nvim_buf_delete, buf, { force = true }) state.ui.set_windows(nil) - renderer._prev_line_count = 0 + ctx.prev_line_count = 0 output_window.viewport_at_bottom = nil end) @@ -398,7 +399,8 @@ end) describe('renderer._add_message_to_buffer scrolling', function() local renderer = require('opencode.ui.renderer') - local formatter = require('opencode.ui.formatter') + local events = require('opencode.ui.renderer.events') + local ctx = require('opencode.ui.renderer.ctx') local stub = require('luassert.stub') local buf, win @@ -418,8 +420,8 @@ describe('renderer._add_message_to_buffer scrolling', function() state.ui.set_windows({ output_win = win, output_buf = buf }) state.session.set_active({ id = 'test-session' }) state.renderer.set_messages({}) - renderer._prev_line_count = 1 - renderer._render_state:reset() + ctx.prev_line_count = 1 + ctx.render_state:reset() end) after_each(function() @@ -428,8 +430,8 @@ describe('renderer._add_message_to_buffer scrolling', function() state.ui.set_windows(nil) state.session.set_active(nil) state.renderer.set_messages(nil) - renderer._prev_line_count = 0 - renderer._render_state:reset() + ctx.prev_line_count = 0 + ctx.render_state:reset() end) it('scrolls to bottom when user message is added', function() @@ -449,7 +451,7 @@ describe('renderer._add_message_to_buffer scrolling', function() scroll_called_with_force = force == true end) - renderer._add_message_to_buffer(user_message) + events.on_message_updated(user_message) assert.is_true(scroll_called_with_force) assert.stub(renderer.scroll_to_bottom).was_called_with(true) @@ -471,7 +473,7 @@ describe('renderer._add_message_to_buffer scrolling', function() stub(renderer, 'scroll_to_bottom') - renderer._add_message_to_buffer(assistant_message) + events.on_message_updated(assistant_message) assert.stub(renderer.scroll_to_bottom).was_not_called() @@ -492,7 +494,7 @@ describe('renderer._add_message_to_buffer scrolling', function() stub(renderer, 'scroll_to_bottom') - renderer._add_message_to_buffer(system_message) + events.on_message_updated(system_message) assert.stub(renderer.scroll_to_bottom).was_not_called() diff --git a/tests/unit/hooks_spec.lua b/tests/unit/hooks_spec.lua index 0368ce24..eeaecffb 100644 --- a/tests/unit/hooks_spec.lua +++ b/tests/unit/hooks_spec.lua @@ -2,6 +2,7 @@ local renderer = require('opencode.ui.renderer') local config = require('opencode.config') local state = require('opencode.state') local core = require('opencode.core') +local events = require('opencode.ui.renderer.events') local helpers = require('tests.helpers') local ui = require('opencode.ui.ui') @@ -39,7 +40,7 @@ describe('hooks', function() end local test_event = { file = '/test/file.lua' } - renderer.on_file_edited(test_event) + events.on_file_edited(test_event) assert.is_true(called) assert.are.equal('/test/file.lua', file_path) @@ -50,7 +51,7 @@ describe('hooks', function() local test_event = { file = '/test/file.lua' } assert.has_no.errors(function() - renderer.on_file_edited(test_event) + events.on_file_edited(test_event) end) end) @@ -61,7 +62,7 @@ describe('hooks', function() local test_event = { file = '/test/file.lua' } assert.has_no.errors(function() - renderer.on_file_edited(test_event) + events.on_file_edited(test_event) end) end) end) diff --git a/tests/unit/permission_integration_spec.lua b/tests/unit/permission_integration_spec.lua index 7e060852..f3745852 100644 --- a/tests/unit/permission_integration_spec.lua +++ b/tests/unit/permission_integration_spec.lua @@ -1,6 +1,7 @@ -local renderer = require('opencode.ui.renderer') local state = require('opencode.state') local permission_window = require('opencode.ui.permission_window') +local events = require('opencode.ui.renderer.events') +local ctx = require('opencode.ui.renderer.ctx') describe('permission_integration', function() local mock_update_permission_from_part @@ -15,8 +16,8 @@ describe('permission_integration', function() permission_window._dialog = nil permission_window._processing = false - renderer._render_state:reset() - renderer._prev_line_count = 0 + ctx.render_state:reset() + ctx.prev_line_count = 0 captured_calls = {} mock_update_permission_from_part = permission_window.update_permission_from_part @@ -47,7 +48,7 @@ describe('permission_integration', function() info = { id = 'msg_abc', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -64,7 +65,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(1, #captured_calls) assert.are.equal('per_test_123', captured_calls[1].permission_id) @@ -85,7 +86,7 @@ describe('permission_integration', function() info = { id = 'msg_legacy', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -101,7 +102,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(1, #captured_calls) assert.are.equal('per_legacy_456', captured_calls[1].permission_id) @@ -123,7 +124,7 @@ describe('permission_integration', function() info = { id = 'msg_abc', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -139,7 +140,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(0, #captured_calls) end) @@ -160,7 +161,7 @@ describe('permission_integration', function() info = { id = 'msg_different', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -176,7 +177,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(0, #captured_calls) end) @@ -197,7 +198,7 @@ describe('permission_integration', function() info = { id = 'msg_abc', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -208,7 +209,7 @@ describe('permission_integration', function() content = 'Some text content', } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(0, #captured_calls) end) @@ -220,7 +221,7 @@ describe('permission_integration', function() info = { id = 'msg_abc', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -236,7 +237,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(0, #captured_calls) end) @@ -273,7 +274,7 @@ describe('permission_integration', function() info = { id = 'msg_second', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -289,7 +290,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(1, #captured_calls) assert.are.equal('per_second', captured_calls[1].permission_id) @@ -319,7 +320,7 @@ describe('permission_integration', function() info = { id = 'msg_abc', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -335,7 +336,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(1, #captured_calls) assert.are.equal('per_first', captured_calls[1].permission_id) @@ -359,7 +360,7 @@ describe('permission_integration', function() info = { id = 'tool_msg_id', sessionID = 'session_123' }, parts = {}, } - renderer._render_state:set_message(message, 1, 1) + ctx.render_state:set_message(message, 1, 1) table.insert(state.messages, message) local part = { @@ -375,7 +376,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part }) + events.on_part_updated({ part = part }) assert.are.equal(1, #captured_calls) assert.are.equal('per_test_123', captured_calls[1].permission_id) @@ -395,7 +396,7 @@ describe('permission_integration', function() }, } - renderer.on_part_updated({ part = part_root }) + events.on_part_updated({ part = part_root }) assert.are.equal(0, #captured_calls) end) diff --git a/tests/unit/render_state_spec.lua b/tests/unit/render_state_spec.lua index c9ec06ac..9530e458 100644 --- a/tests/unit/render_state_spec.lua +++ b/tests/unit/render_state_spec.lua @@ -18,20 +18,20 @@ describe('RenderState', function() assert.is_not_nil(render_state) assert.is_table(render_state._messages) assert.is_table(render_state._parts) - assert.is_table(render_state._line_index) - assert.is_false(render_state._line_index_valid) + assert.is_table(render_state._part_ranges) + assert.is_false(render_state._ranges_valid) end) it('resets to empty state', function() render_state._messages = { test = true } render_state._parts = { test = true } - render_state._line_index_valid = true + render_state._ranges_valid = true render_state:reset() assert.is_true(vim.tbl_isempty(render_state._messages)) assert.is_true(vim.tbl_isempty(render_state._parts)) - assert.is_true(vim.tbl_isempty(render_state._line_index.line_to_part)) - assert.is_true(vim.tbl_isempty(render_state._line_index.line_to_message)) - assert.is_false(render_state._line_index_valid) + assert.is_true(vim.tbl_isempty(render_state._part_ranges)) + assert.is_true(vim.tbl_isempty(render_state._message_ranges)) + assert.is_false(render_state._ranges_valid) end) end) @@ -51,7 +51,7 @@ describe('RenderState', function() local msg = { info = { id = 'msg1' } } render_state:set_message(msg, 5, 7) - assert.is_false(render_state._line_index_valid) + assert.is_false(render_state._ranges_valid) local result = render_state:get_message_at_line(6) assert.is_not_nil(result) @@ -88,7 +88,7 @@ describe('RenderState', function() local part = { id = 'part1', messageID = 'msg1' } render_state:set_part(part, 20, 22) - assert.is_false(render_state._line_index_valid) + assert.is_false(render_state._ranges_valid) local result = render_state:get_part_at_line(21) assert.is_not_nil(result) @@ -304,6 +304,17 @@ describe('RenderState', function() local success = render_state:update_part_lines('nonexistent', 10, 20) assert.is_false(success) end) + + it('returns early when lines are unchanged', function() + local part = { id = 'part1', messageID = 'msg1' } + render_state:set_part(part, 10, 15) + render_state._ranges_valid = true + + local success = render_state:update_part_lines('part1', 10, 15) + + assert.is_true(success) + assert.is_true(render_state._ranges_valid) + end) end) describe('remove_part', function() @@ -474,22 +485,22 @@ describe('RenderState', function() local part = { id = 'part1', messageID = 'msg1' } render_state:set_part(part, 10, 15) - render_state._line_index_valid = true + render_state._ranges_valid = true render_state:shift_all(100, 5) - assert.is_true(render_state._line_index_valid) + assert.is_true(render_state._ranges_valid) end) it('invalidates index when content shifted', function() local part = { id = 'part1', messageID = 'msg1' } render_state:set_part(part, 10, 15) - render_state._line_index_valid = true + render_state._ranges_valid = true render_state:shift_all(10, 5) - assert.is_false(render_state._line_index_valid) + assert.is_false(render_state._ranges_valid) end) it('exits early when content found before from_line', function() @@ -506,6 +517,19 @@ describe('RenderState', function() local result2 = render_state:get_part('part2') assert.equals(60, result2.line_start) end) + + it('exits early when from_line is after max rendered line', function() + local part = { id = 'part1', messageID = 'msg1' } + render_state:set_part(part, 10, 15) + + render_state._ranges_valid = true + render_state:shift_all(100, 5) + + local result = render_state:get_part('part1') + assert.equals(10, result.line_start) + assert.equals(15, result.line_end) + assert.is_true(render_state._ranges_valid) + end) end) describe('update_part_data', function()