From 7f4b11476c1b17347cd6f8ddc43edf94b0dbd503 Mon Sep 17 00:00:00 2001 From: Cameron Ring Date: Thu, 4 Dec 2025 18:02:36 -0800 Subject: [PATCH] fix(core): invalid model error handling It looks like there was a change in opencode 1.0 that changed how invalid model errors are surfaced so this change fixes that. Also add unit test coverage for send_message changing state.user_message_count (and fix a bug with it not being updated in the catch) Hopefully fixes #131 --- lua/opencode/core.lua | 30 +++++++++------- tests/unit/core_spec.lua | 75 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 13 deletions(-) diff --git a/lua/opencode/core.lua b/lua/opencode/core.lua index 604fc056..a38905b0 100644 --- a/lua/opencode/core.lua +++ b/lua/opencode/core.lua @@ -147,31 +147,35 @@ function M.send_message(prompt, opts) params.parts = context.format_message(prompt, opts.context) M.before_run(opts) - -- Capture the session ID to ensure we track the message count for the correct session local session_id = state.active_session.id - local sent_message_count = vim.deepcopy(state.user_message_count) - sent_message_count[session_id] = (sent_message_count[session_id] or 0) + 1 - state.user_message_count = sent_message_count + + ---Helper to update state.user_message_count. Have to deepcopy since it's a table to make + ---sure notification events fire. Prevents negative values (in case of an untracked code path) + local function update_sent_message_count(num) + local sent_message_count = vim.deepcopy(state.user_message_count) + local new_value = (sent_message_count[session_id] or 0) + num + sent_message_count[session_id] = new_value >= 0 and new_value or 0 + state.user_message_count = sent_message_count + end + + update_sent_message_count(1) state.api_client :create_message(session_id, params) :and_then(function(response) + update_sent_message_count(-1) + if not response or not response.info or not response.parts then - -- fall back to full render. incremental render is handled - -- event manager - ui.render_output() + vim.notify('Invalid response from opencode: ' .. vim.inspect(response), vim.log.levels.ERROR) + M.cancel() + return end - local received_message_count = vim.deepcopy(state.user_message_count) - received_message_count[response.info.sessionID] = (received_message_count[response.info.sessionID] ~= nil) - and (received_message_count[response.info.sessionID] - 1) - or 0 - state.user_message_count = received_message_count - M.after_run(prompt) end) :catch(function(err) vim.notify('Error sending message to session: ' .. vim.inspect(err), vim.log.levels.ERROR) + update_sent_message_count(-1) M.cancel() end) end diff --git a/tests/unit/core_spec.lua b/tests/unit/core_spec.lua index f325cc5f..3267f6a5 100644 --- a/tests/unit/core_spec.lua +++ b/tests/unit/core_spec.lua @@ -271,6 +271,81 @@ describe('opencode.core', function() assert.equal(state.current_model, 'test/model') state.api_client.create_message = orig end) + + it('increments and decrements user_message_count correctly', function() + state.windows = { mock = 'windows' } + state.active_session = { id = 'sess1' } + state.user_message_count = {} + + -- Capture the count at different stages + local count_before = state.user_message_count['sess1'] or 0 + local count_during = nil + local count_after = nil + + local orig = state.api_client.create_message + state.api_client.create_message = function(_, sid, params) + -- Capture count while message is in flight + count_during = state.user_message_count['sess1'] + return Promise.new():resolve({ + id = 'm1', + info = { id = 'm1' }, + parts = {}, + }) + end + + core.send_message('hello world') + + -- Wait for promise to resolve + vim.wait(50, function() + count_after = state.user_message_count['sess1'] or 0 + return count_after == 0 + end) + + -- Verify: starts at 0, increments to 1, then back to 0 + assert.equal(0, count_before) + assert.equal(1, count_during) + assert.equal(0, count_after) + + state.api_client.create_message = orig + end) + + it('decrements user_message_count on error', function() + state.windows = { mock = 'windows' } + state.active_session = { id = 'sess1' } + state.user_message_count = {} + + -- Capture the count at different stages + local count_before = state.user_message_count['sess1'] or 0 + local count_during = nil + local count_after = nil + + local orig = state.api_client.create_message + state.api_client.create_message = function(_, sid, params) + -- Capture count while message is in flight + count_during = state.user_message_count['sess1'] + return Promise.new():reject('Test error') + end + + -- Stub cancel to prevent it from trying to abort the session + local orig_cancel = core.cancel + stub(core, 'cancel') + + core.send_message('hello world') + + -- Wait for promise to reject + vim.wait(50, function() + count_after = state.user_message_count['sess1'] or 0 + return count_after == 0 + end) + + -- Verify: starts at 0, increments to 1, then back to 0 even on error + assert.equal(0, count_before) + assert.equal(1, count_during) + assert.equal(0, count_after) + + state.api_client.create_message = orig + core.cancel = orig_cancel + end) end) describe('opencode_ok (version checks)', function()