diff --git a/lua/opencode/core.lua b/lua/opencode/core.lua index a0457cdf..aea4b478 100644 --- a/lua/opencode/core.lua +++ b/lua/opencode/core.lua @@ -106,8 +106,6 @@ M.open = Promise.async(function(opts) state.last_sent_context = nil context.unload_attachments() - state.current_model = nil - state.current_mode = nil M.ensure_current_mode():await() state.active_session = M.create_new_session():await() @@ -267,6 +265,12 @@ function M.configure_provider() local model_str = string.format('%s/%s', selection.provider, selection.model) state.current_model = model_str + if state.current_mode then + local mode_map = vim.deepcopy(state.user_mode_model_map) + mode_map[state.current_mode] = model_str + state.user_mode_model_map = mode_map + end + if state.windows then ui.focus_input() else @@ -456,8 +460,13 @@ M.switch_to_mode = Promise.async(function(mode) local agent_config = opencode_config and opencode_config.agent or {} local mode_config = agent_config[mode] or {} - if mode_config.model and mode_config.model ~= '' then + + if state.user_mode_model_map[mode] then + state.current_model = state.user_mode_model_map[mode] + elseif mode_config.model and mode_config.model ~= '' then state.current_model = mode_config.model + elseif opencode_config and opencode_config.model and opencode_config.model ~= '' then + state.current_model = opencode_config.model end return true end) diff --git a/lua/opencode/state.lua b/lua/opencode/state.lua index 0aa66c95..0f43666f 100644 --- a/lua/opencode/state.lua +++ b/lua/opencode/state.lua @@ -26,6 +26,7 @@ ---@field active_session Session|nil ---@field restore_points RestorePoint[] ---@field current_model string|nil +---@field user_mode_model_map table ---@field current_model_info table|nil ---@field current_variant string|nil ---@field messages OpencodeMessage[]|nil @@ -76,6 +77,7 @@ local _state = { active_session = nil, restore_points = {}, current_model = nil, + user_mode_model_map = {}, current_model_info = nil, current_variant = nil, -- messages diff --git a/tests/unit/core_spec.lua b/tests/unit/core_spec.lua index 98819463..94e5f843 100644 --- a/tests/unit/core_spec.lua +++ b/tests/unit/core_spec.lua @@ -481,6 +481,7 @@ describe('opencode.core', function() model = 'anthropic/claude-3-opus', }, }, + model = 'gpt-4', }) stub(config_file, 'get_opencode_agents').returns(agents_promise) @@ -488,6 +489,7 @@ describe('opencode.core', function() state.current_mode = nil state.current_model = nil + state.user_mode_model_map = {} local promise = core.switch_to_mode('custom') local success = promise:wait() @@ -500,57 +502,82 @@ describe('opencode.core', function() config_file.get_opencode_config:revert() end) - it('does not change current model when mode has no model configured', function() + it('returns false when mode is invalid', function() + local Promise = require('opencode.promise') + local agents_promise = Promise.new() + agents_promise:resolve({ 'plan', 'build' }) + + stub(config_file, 'get_opencode_agents').returns(agents_promise) + + local promise = core.switch_to_mode('nonexistent') + local success = promise:wait() + + assert.is_false(success) + + config_file.get_opencode_agents:revert() + end) + + it('returns false when mode is empty', function() + local promise = core.switch_to_mode('') + local success = promise:wait() + assert.is_false(success) + + promise = core.switch_to_mode(nil) + success = promise:wait() + assert.is_false(success) + end) + + it('respects user_mode_model_map priority: uses model stored in mode_model_map for mode', function() local Promise = require('opencode.promise') local agents_promise = Promise.new() agents_promise:resolve({ 'plan', 'build' }) local config_promise = Promise.new() config_promise:resolve({ agent = { - plan = {}, + plan = { model = 'gpt-4' }, }, + model = 'gpt-3', }) - stub(config_file, 'get_opencode_agents').returns(agents_promise) stub(config_file, 'get_opencode_config').returns(config_promise) state.current_mode = nil - state.current_model = 'existing/model' + state.current_model = 'should-be-overridden' + state.user_mode_model_map = { plan = 'anthropic/claude-3-haiku' } local promise = core.switch_to_mode('plan') local success = promise:wait() - assert.is_true(success) assert.equal('plan', state.current_mode) - assert.equal('existing/model', state.current_model) + assert.equal('anthropic/claude-3-haiku', state.current_model) config_file.get_opencode_agents:revert() config_file.get_opencode_config:revert() end) - it('returns false when mode is invalid', function() + it('falls back to config model if nothing else matches', function() local Promise = require('opencode.promise') local agents_promise = Promise.new() agents_promise:resolve({ 'plan', 'build' }) - + local config_promise = Promise.new() + config_promise:resolve({ + agent = { + plan = {}, + }, + model = 'default-model', + }) stub(config_file, 'get_opencode_agents').returns(agents_promise) - - local promise = core.switch_to_mode('nonexistent') + stub(config_file, 'get_opencode_config').returns(config_promise) + state.current_mode = nil + state.current_model = 'old-model' + state.user_mode_model_map = {} + local promise = core.switch_to_mode('plan') local success = promise:wait() - - assert.is_false(success) - + assert.is_true(success) + assert.equal('plan', state.current_mode) + assert.equal('default-model', state.current_model) config_file.get_opencode_agents:revert() - end) - - it('returns false when mode is empty', function() - local promise = core.switch_to_mode('') - local success = promise:wait() - assert.is_false(success) - - promise = core.switch_to_mode(nil) - success = promise:wait() - assert.is_false(success) + config_file.get_opencode_config:revert() end) end) end)