From fb737409ac42c1dfe9f614ba45865d605faba551 Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Wed, 24 Jun 2026 16:22:30 +0300 Subject: [PATCH] feat: AI inline autocompletion with Codestral --- cmd/micro/micro.go | 2 + internal/action/actions.go | 18 ++ internal/action/bufpane.go | 139 +++++++++++ internal/action/defaults_darwin.go | 1 + internal/action/defaults_other.go | 1 + internal/ai/codestral.go | 116 ++++++++++ internal/ai/codestral_test.go | 148 ++++++++++++ internal/ai/manager.go | 202 ++++++++++++++++ internal/ai/manager_test.go | 358 +++++++++++++++++++++++++++++ internal/ai/providers.go | 19 ++ internal/buffer/buffer.go | 43 ++++ internal/config/settings.go | 7 +- internal/display/bufwindow.go | 46 ++++ 13 files changed, 1099 insertions(+), 1 deletion(-) create mode 100644 internal/ai/codestral.go create mode 100644 internal/ai/codestral_test.go create mode 100644 internal/ai/manager.go create mode 100644 internal/ai/manager_test.go create mode 100644 internal/ai/providers.go diff --git a/cmd/micro/micro.go b/cmd/micro/micro.go index c8a99b2517..3f7b4d051b 100644 --- a/cmd/micro/micro.go +++ b/cmd/micro/micro.go @@ -516,6 +516,8 @@ func DoEvent() { } case f := <-timerChan: f() + case f := <-action.AIResultChan: + f() case <-sighup: exit(0) case <-util.Sigterm: diff --git a/internal/action/actions.go b/internal/action/actions.go index 491199f76f..5d2db5b70c 100644 --- a/internal/action/actions.go +++ b/internal/action/actions.go @@ -907,6 +907,10 @@ func (h *BufPane) OutdentSelection() bool { func (h *BufPane) Autocomplete() bool { b := h.Buf + if b.InlineCompletion != "" { + return false + } + if h.Cursor.HasSelection() { return false } @@ -944,6 +948,9 @@ func (h *BufPane) CycleAutocompleteBack() bool { // InsertTab inserts a tab or spaces func (h *BufPane) InsertTab() bool { + if h.acceptAICompletion() { + return true + } b := h.Buf indent := b.IndentString(util.IntOpt(b.Settings["tabsize"])) tabBytes := len(indent) @@ -1874,6 +1881,10 @@ func (h *BufPane) ToggleOverwriteMode() bool { // Escape leaves current mode func (h *BufPane) Escape() bool { + if h.Buf.InlineCompletion != "" { + h.dismissAICompletion() + return true + } return true } @@ -2346,3 +2357,10 @@ func (h *BufPane) RemoveAllMultiCursors() bool { func (h *BufPane) None() bool { return true } + +// ManualTrigger immediately triggers AI completion (no debounce). +// Bound to Ctrl-Space by default. +func (h *BufPane) ManualTrigger() bool { + h.triggerAICompletionNow() + return true +} diff --git a/internal/action/bufpane.go b/internal/action/bufpane.go index df6a20f8d8..4febb2726d 100644 --- a/internal/action/bufpane.go +++ b/internal/action/bufpane.go @@ -1,11 +1,15 @@ package action import ( + "context" + "log" "strings" + "sync" "time" luar "layeh.com/gopher-luar" + "github.com/micro-editor/micro/v2/internal/ai" "github.com/micro-editor/micro/v2/internal/buffer" "github.com/micro-editor/micro/v2/internal/config" "github.com/micro-editor/micro/v2/internal/display" @@ -16,6 +20,35 @@ import ( lua "github.com/yuin/gopher-lua" ) +// AIResultChan delivers AI completion results to the main event loop. +// The main loop reads from this channel and applies the result on the main thread. +var AIResultChan = make(chan func(), 16) + +var aiManager *ai.Manager +var aiManagerErr error +var aiManagerOnce sync.Once + +func getAIManager() *ai.Manager { + aiManagerOnce.Do(func() { + if !config.GetGlobalOption("aicomplete").(bool) { + return + } + provider := config.GetGlobalOption("aicompleteprovider").(string) + model := config.GetGlobalOption("aicompletemodel").(string) + baseURL := config.GetGlobalOption("aicompleteurl").(string) + debounce := config.GetGlobalOption("aicompletedebounce").(float64) + aiManager, aiManagerErr = ai.NewManager(provider, model, baseURL, debounce) + if aiManagerErr != nil { + log.Println("AI completion:", aiManagerErr) + InfoBar.Error("AI completion: ", aiManagerErr.Error()) + } + }) + if aiManagerErr != nil { + return nil + } + return aiManager +} + type BufAction any // BufKeyAction represents an action bound to a key. @@ -254,6 +287,9 @@ type BufPane struct { // since we may not know the window geometry yet. In such case we finish // its initialization a bit later, after the initial resize. initialized bool + + // AI inline completion state + aiCancel context.CancelFunc } func newBufPane(buf *buffer.Buffer, win display.BWindow, tab *Tab) *BufPane { @@ -345,6 +381,7 @@ func (h *BufPane) resetMouse() { // OpenBuffer opens the given buffer in this pane. func (h *BufPane) OpenBuffer(b *buffer.Buffer) { + h.dismissAICompletion() h.Buf.Close() h.Buf = b h.BWindow.SetBuffer(b) @@ -357,6 +394,100 @@ func (h *BufPane) OpenBuffer(b *buffer.Buffer) { h.lastClickTime = time.Time{} } +func (h *BufPane) dismissAICompletion() { + h.Buf.InlineCompletion = "" + if h.aiCancel != nil { + h.aiCancel() + h.aiCancel = nil + } +} + +func (h *BufPane) triggerAICompletion() { + mgr := getAIManager() + if mgr == nil { + if aiManagerErr != nil { + InfoBar.Error("AI completion: ", aiManagerErr.Error()) + } + return + } + + if !config.GetGlobalOption("aicomplete").(bool) { + return + } + + h.dismissAICompletion() + InfoBar.Message("") + + before := h.Buf.TextBeforeCursor() + after := h.Buf.TextAfterCursor() + fileType := h.Buf.Settings["filetype"].(string) + + ctx, cancel := context.WithCancel(context.Background()) + h.aiCancel = cancel + + mgr.RequestDelayed(ai.Request{ + BeforeCursor: before, + AfterCursor: after, + FileType: fileType, + FileName: h.Buf.AbsPath, + }, func(resp *ai.Response, err error) { + if ctx.Err() != nil { + return + } + if err != nil { + AIResultChan <- func() { + InfoBar.Message("AI: ", err.Error()) + } + return + } + AIResultChan <- func() { + h.Buf.InlineCompletion = resp.Text + screen.Redraw() + } + }) +} + +func (h *BufPane) triggerAICompletionNow() { + mgr := getAIManager() + if mgr == nil { + if aiManagerErr != nil { + InfoBar.Error("AI completion: ", aiManagerErr.Error()) + } + return + } + if !config.GetGlobalOption("aicomplete").(bool) { + return + } + h.dismissAICompletion() + InfoBar.Message("") + + resp, err := mgr.RequestNow(ai.Request{ + BeforeCursor: h.Buf.TextBeforeCursor(), + AfterCursor: h.Buf.TextAfterCursor(), + FileType: h.Buf.Settings["filetype"].(string), + FileName: h.Buf.AbsPath, + }) + if err != nil { + InfoBar.Message("AI: ", err.Error()) + return + } + if resp != nil { + h.Buf.InlineCompletion = resp.Text + screen.Redraw() + } +} + +func (h *BufPane) acceptAICompletion() bool { + text := h.Buf.InlineCompletion + if text == "" { + return false + } + h.dismissAICompletion() + h.Buf.Insert(h.Cursor.Loc, text) + h.Relocate() + return true +} + // GotoLoc moves the cursor to a new location and adjusts the view accordingly. // Use GotoLoc when the new location may be far away from the current location. func (h *BufPane) GotoLoc(loc buffer.Loc) { @@ -556,6 +687,12 @@ func (h *BufPane) execAction(action BufAction, name string, te *tcell.EventMouse if name != "Autocomplete" && name != "CycleAutocompleteBack" { h.Buf.HasSuggestions = false } + if name != "Autocomplete" && name != "CycleAutocompleteBack" && + name != "Escape" && name != "InsertTab" && name != "ManualTrigger" && + name != "IndentSelection" && name != "OutdentSelection" && + name != "OutdentLine" { + h.dismissAICompletion() + } if !h.PluginCB("pre"+name, te) { return false @@ -645,6 +782,7 @@ func (h *BufPane) DoRuneInsert(r rune) { h.Relocate() h.PluginCB("onRune", string(r)) } + h.triggerAICompletion() } // VSplitIndex opens the given buffer in a vertical split on the given side. @@ -851,6 +989,7 @@ var BufKeyActions = map[string]BufKeyAction{ "Deselect": (*BufPane).Deselect, "ClearInfo": (*BufPane).ClearInfo, "None": (*BufPane).None, + "ManualTrigger": (*BufPane).ManualTrigger, // This was changed to InsertNewline but I don't want to break backwards compatibility "InsertEnter": (*BufPane).InsertNewline, diff --git a/internal/action/defaults_darwin.go b/internal/action/defaults_darwin.go index 44b4fcc1ae..db2f645fb5 100644 --- a/internal/action/defaults_darwin.go +++ b/internal/action/defaults_darwin.go @@ -76,6 +76,7 @@ var bufdefaults = map[string]string{ "Ctrl-u": "ToggleMacro", "Ctrl-j": "PlayMacro", "Insert": "ToggleOverwriteMode", + "Ctrl-Space": "ManualTrigger", // Emacs-style keybindings "Alt-f": "WordRight", diff --git a/internal/action/defaults_other.go b/internal/action/defaults_other.go index 10e5b08358..fe44cf29ef 100644 --- a/internal/action/defaults_other.go +++ b/internal/action/defaults_other.go @@ -79,6 +79,7 @@ var bufdefaults = map[string]string{ "Ctrl-u": "ToggleMacro", "Ctrl-j": "PlayMacro", "Insert": "ToggleOverwriteMode", + "Ctrl-Space": "ManualTrigger", // Emacs-style keybindings "Alt-f": "WordRight", diff --git a/internal/ai/codestral.go b/internal/ai/codestral.go new file mode 100644 index 0000000000..d2afe193a2 --- /dev/null +++ b/internal/ai/codestral.go @@ -0,0 +1,116 @@ +package ai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" +) + +const ( + codestralDefaultBaseURL = "https://codestral.mistral.ai/v1/fim/completions" + codestralDefaultModel = "codestral-latest" + codestralEnvKey = "MISTRAL_API_KEY" +) + +type CodestralProvider struct { + apiKey string + model string + baseURL string + client *http.Client +} + +func NewCodestralProvider(model, baseURL string) (*CodestralProvider, error) { + apiKey := os.Getenv(codestralEnvKey) + if apiKey == "" { + return nil, fmt.Errorf("codestral: %s environment variable not set", codestralEnvKey) + } + if model == "" { + model = codestralDefaultModel + } + if baseURL == "" { + baseURL = codestralDefaultBaseURL + } + return &CodestralProvider{ + apiKey: apiKey, + model: model, + baseURL: baseURL, + client: &http.Client{ + Timeout: 10 * time.Second, + }, + }, nil +} + +func (p *CodestralProvider) Name() string { + return "codestral" +} + +type codestralRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Suffix string `json:"suffix"` + MaxTokens int `json:"max_tokens"` + Temperature float64 `json:"temperature"` + Stop []string `json:"stop,omitempty"` +} + +type codestralResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` +} + +func (p *CodestralProvider) Complete(ctx context.Context, req Request) (*Response, error) { + body := codestralRequest{ + Model: p.model, + Prompt: req.BeforeCursor, + Suffix: req.AfterCursor, + MaxTokens: 256, + Temperature: 0, + Stop: []string{"\n\n"}, + } + + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("codestral: marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL, bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("codestral: create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("codestral: http request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("codestral: read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("codestral: status %d: %s", resp.StatusCode, string(respBody)) + } + + var cresp codestralResponse + if err := json.Unmarshal(respBody, &cresp); err != nil { + return nil, fmt.Errorf("codestral: parse response: %w", err) + } + + if len(cresp.Choices) == 0 { + return nil, fmt.Errorf("codestral: no choices in response") + } + + return &Response{Text: cresp.Choices[0].Message.Content}, nil +} diff --git a/internal/ai/codestral_test.go b/internal/ai/codestral_test.go new file mode 100644 index 0000000000..53ae5978c8 --- /dev/null +++ b/internal/ai/codestral_test.go @@ -0,0 +1,148 @@ +package ai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCodestralComplete(t *testing.T) { + expected := "return x + 1" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + var req codestralRequest + assert.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + assert.Equal(t, "func foo()", req.Prompt) + assert.Equal(t, "}", req.Suffix) + assert.Equal(t, "codestral-latest", req.Model) + + resp := codestralResponse{ + Choices: []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + }{ + {Message: struct { + Content string `json:"content"` + }{Content: expected}}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + t.Setenv("MISTRAL_API_KEY", "test-key") + + p, err := NewCodestralProvider("", server.URL) + assert.NoError(t, err) + + resp, err := p.Complete(context.Background(), Request{ + BeforeCursor: "func foo()", + AfterCursor: "}", + }) + assert.NoError(t, err) + assert.Equal(t, expected, resp.Text) +} + +func TestCodestralEmptyChoices(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := codestralResponse{Choices: []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + }{}} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + t.Setenv("MISTRAL_API_KEY", "test-key") + + p, err := NewCodestralProvider("", server.URL) + assert.NoError(t, err) + + _, err = p.Complete(context.Background(), Request{ + BeforeCursor: "func foo()", + AfterCursor: "}", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no choices") +} + +func TestCodestralNon200(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"invalid token"}`)) + })) + defer server.Close() + + t.Setenv("MISTRAL_API_KEY", "test-key") + + p, err := NewCodestralProvider("", server.URL) + assert.NoError(t, err) + + _, err = p.Complete(context.Background(), Request{ + BeforeCursor: "func foo()", + AfterCursor: "}", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "status 401") +} + +func TestCodestralMissingKey(t *testing.T) { + t.Setenv("MISTRAL_API_KEY", "") + + _, err := NewCodestralProvider("", "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "MISTRAL_API_KEY") +} + +func TestCodestralCustomModel(t *testing.T) { + t.Setenv("MISTRAL_API_KEY", "test-key") + + p, err := NewCodestralProvider("my-custom-model", "") + assert.NoError(t, err) + assert.Equal(t, "my-custom-model", p.model) +} + +func TestCodestralCustomBaseURL(t *testing.T) { + t.Setenv("MISTRAL_API_KEY", "test-key") + + p, err := NewCodestralProvider("", "https://example.com/v1/fim") + assert.NoError(t, err) + assert.Equal(t, "https://example.com/v1/fim", p.baseURL) +} + +func TestCodestralContextCancel(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select {} + })) + defer server.Close() + + t.Setenv("MISTRAL_API_KEY", "test-key") + + p, err := NewCodestralProvider("", server.URL) + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err = p.Complete(ctx, Request{BeforeCursor: "func"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") +} + +func TestCodestralName(t *testing.T) { + t.Setenv("MISTRAL_API_KEY", "test-key") + + p, err := NewCodestralProvider("", "") + assert.NoError(t, err) + assert.Equal(t, "codestral", p.Name()) +} diff --git a/internal/ai/manager.go b/internal/ai/manager.go new file mode 100644 index 0000000000..3a506aba17 --- /dev/null +++ b/internal/ai/manager.go @@ -0,0 +1,202 @@ +package ai + +import ( + "context" + "fmt" + "hash/fnv" + "sync" + "time" +) + +const ( + maxCacheEntries = 16 + cacheTTL = 5 * time.Second +) + +type cacheEntry struct { + resp *Response + err error + expiresAt time.Time +} + +type Manager struct { + mu sync.Mutex + provider Provider + + cancelPrev context.CancelFunc + timer *time.Timer + debounce time.Duration + + cache map[uint64]*cacheEntry + cacheOrder []uint64 + cacheMu sync.Mutex +} + +func NewManager(providerName, model, baseURL string, debounceMs float64) (*Manager, error) { + provider, err := newProvider(providerName, model, baseURL) + if err != nil { + return nil, err + } + if provider == nil { + return nil, fmt.Errorf("unknown AI provider: %q", providerName) + } + + return &Manager{ + provider: provider, + debounce: time.Duration(debounceMs) * time.Millisecond, + cache: make(map[uint64]*cacheEntry), + cacheOrder: make([]uint64, 0, maxCacheEntries), + }, nil +} + +func newProvider(name, model, baseURL string) (Provider, error) { + switch name { + case "codestral": + return NewCodestralProvider(model, baseURL) + default: + return nil, nil + } +} + +func (m *Manager) Provider() Provider { + return m.provider +} + +func cacheKey(req Request) uint64 { + h := fnv.New64a() + h.Write([]byte(req.BeforeCursor)) + h.Write([]byte{0}) + h.Write([]byte(req.AfterCursor)) + h.Write([]byte{0}) + h.Write([]byte(req.FileType)) + return h.Sum64() +} + +func (m *Manager) getCache(key uint64) (*cacheEntry, bool) { + m.cacheMu.Lock() + defer m.cacheMu.Unlock() + if m.cache == nil { + return nil, false + } + entry, ok := m.cache[key] + if !ok { + return nil, false + } + if time.Now().After(entry.expiresAt) { + delete(m.cache, key) + m.removeCacheOrder(key) + return nil, false + } + m.touchCache(key) + return entry, true +} + +func (m *Manager) touchCache(key uint64) { + for i, k := range m.cacheOrder { + if k == key { + m.cacheOrder = append(m.cacheOrder[:i], m.cacheOrder[i+1:]...) + m.cacheOrder = append(m.cacheOrder, key) + return + } + } +} + +func (m *Manager) removeCacheOrder(key uint64) { + for i, k := range m.cacheOrder { + if k == key { + m.cacheOrder = append(m.cacheOrder[:i], m.cacheOrder[i+1:]...) + return + } + } +} + +func (m *Manager) setCache(key uint64, resp *Response, err error) { + // Don't cache errors (transient network failures) + if err != nil { + return + } + if m.cache == nil { + m.cache = make(map[uint64]*cacheEntry) + } + if len(m.cache) >= maxCacheEntries { + delete(m.cache, m.cacheOrder[0]) + m.cacheOrder = m.cacheOrder[1:] + } + m.cache[key] = &cacheEntry{ + resp: resp, + err: err, + expiresAt: time.Now().Add(cacheTTL), + } + m.cacheOrder = append(m.cacheOrder, key) +} + +func (m *Manager) Cancel() { + m.mu.Lock() + defer m.mu.Unlock() + if m.cancelPrev != nil { + m.cancelPrev() + m.cancelPrev = nil + } + if m.timer != nil { + m.timer.Stop() + m.timer = nil + } +} + +func (m *Manager) RequestDelayed(req Request, onResult func(*Response, error)) { + key := cacheKey(req) + + if entry, ok := m.getCache(key); ok { + onResult(entry.resp, entry.err) + return + } + + m.mu.Lock() + + if m.cancelPrev != nil { + m.cancelPrev() + } + if m.timer != nil { + m.timer.Stop() + } + + ctx, cancel := context.WithCancel(context.Background()) + m.cancelPrev = cancel + + m.timer = time.AfterFunc(m.debounce, func() { + if ctx.Err() != nil { + return + } + + if entry, ok := m.getCache(key); ok { + onResult(entry.resp, entry.err) + return + } + + resp, err := m.provider.Complete(ctx, req) + if ctx.Err() != nil { + return + } + + m.cacheMu.Lock() + m.setCache(key, resp, err) + m.cacheMu.Unlock() + + onResult(resp, err) + }) + + m.mu.Unlock() +} + +func (m *Manager) RequestNow(req Request) (*Response, error) { + m.Cancel() + + resp, err := m.provider.Complete(context.Background(), req) + if err == nil && resp != nil { + key := cacheKey(req) + m.cacheMu.Lock() + m.setCache(key, resp, nil) + m.cacheMu.Unlock() + } + return resp, err +} diff --git a/internal/ai/manager_test.go b/internal/ai/manager_test.go new file mode 100644 index 0000000000..2646fbf5f9 --- /dev/null +++ b/internal/ai/manager_test.go @@ -0,0 +1,358 @@ +package ai + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type mockProvider struct { + name string + complete func(ctx context.Context, req Request) (*Response, error) +} + +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) Complete(ctx context.Context, req Request) (*Response, error) { + if m.complete != nil { + return m.complete(ctx, req) + } + return &Response{Text: "mock"}, nil +} + +func TestNewManager(t *testing.T) { + t.Run("unknown provider", func(t *testing.T) { + m, err := NewManager("unknown", "", "", 100) + assert.Nil(t, m) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown AI provider") + }) + + t.Run("codestral missing key", func(t *testing.T) { + t.Setenv("MISTRAL_API_KEY", "") + m, err := NewManager("codestral", "", "", 100) + assert.Nil(t, m) + assert.Error(t, err) + assert.Contains(t, err.Error(), "MISTRAL_API_KEY") + }) +} + +func TestManagerRequestNow(t *testing.T) { + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + return &Response{Text: "hello"}, nil + }, + }, + debounce: 10 * time.Millisecond, + } + + resp, err := m.RequestNow(Request{BeforeCursor: "foo"}) + assert.NoError(t, err) + assert.Equal(t, "hello", resp.Text) +} + +func TestManagerRequestNowError(t *testing.T) { + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + return nil, errors.New("api error") + }, + }, + } + + resp, err := m.RequestNow(Request{BeforeCursor: "foo"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "api error") + assert.Nil(t, resp) +} + +func TestManagerRequestDelayedFires(t *testing.T) { + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + return &Response{Text: "delayed"}, nil + }, + }, + debounce: 20 * time.Millisecond, + } + + var result *Response + var resultErr error + done := make(chan struct{}) + + m.RequestDelayed(Request{BeforeCursor: "foo"}, func(resp *Response, err error) { + result = resp + resultErr = err + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("RequestDelayed did not fire") + } + + assert.NoError(t, resultErr) + assert.Equal(t, "delayed", result.Text) +} + +func TestManagerRequestDelayedDebounce(t *testing.T) { + var callCount int32 + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + atomic.AddInt32(&callCount, 1) + return &Response{Text: "delayed"}, nil + }, + }, + debounce: 50 * time.Millisecond, + } + + // Fire multiple requests quickly — only the last one should execute + m.RequestDelayed(Request{BeforeCursor: "a"}, func(*Response, error) {}) + m.RequestDelayed(Request{BeforeCursor: "b"}, func(*Response, error) {}) + m.RequestDelayed(Request{BeforeCursor: "c"}, func(*Response, error) {}) + + time.Sleep(200 * time.Millisecond) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) +} + +func TestManagerCancelStopsPending(t *testing.T) { + var called bool + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, + debounce: 10 * time.Millisecond, + } + + var resultErr error + m.RequestDelayed(Request{BeforeCursor: "foo"}, func(resp *Response, err error) { + called = true + resultErr = err + }) + + time.Sleep(20 * time.Millisecond) // wait for timer to fire + m.Cancel() + time.Sleep(20 * time.Millisecond) + + // Completion should not have been called (context was cancelled) + assert.False(t, called) + assert.Nil(t, resultErr) +} + +func TestManagerCancelBetweenRequests(t *testing.T) { + var callCount int32 + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + atomic.AddInt32(&callCount, 1) + return &Response{Text: req.BeforeCursor}, nil + }, + }, + debounce: 30 * time.Millisecond, + } + + // Send request A, cancel immediately via RequestNow + m.RequestDelayed(Request{BeforeCursor: "a"}, func(*Response, error) {}) + m.RequestNow(Request{BeforeCursor: "now"}) + + time.Sleep(100 * time.Millisecond) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) +} + +func TestManagerCacheHit(t *testing.T) { + var callCount int32 + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + atomic.AddInt32(&callCount, 1) + return &Response{Text: "cached"}, nil + }, + }, + debounce: 10 * time.Millisecond, + cache: make(map[uint64]*cacheEntry), + cacheOrder: make([]uint64, 0, maxCacheEntries), + } + + // First request populates cache + done := make(chan struct{}) + m.RequestDelayed(Request{BeforeCursor: "foo", FileType: "go"}, func(resp *Response, err error) { + close(done) + }) + <-done + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) + + // Second request with same context should hit cache + m.RequestDelayed(Request{BeforeCursor: "foo", FileType: "go"}, func(resp *Response, err error) { + assert.Equal(t, "cached", resp.Text) + }) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount), "should not call provider again") +} + +func TestManagerCacheMissDifferentContext(t *testing.T) { + var callCount int32 + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + atomic.AddInt32(&callCount, 1) + return &Response{Text: req.BeforeCursor}, nil + }, + }, + debounce: 10 * time.Millisecond, + cache: make(map[uint64]*cacheEntry), + cacheOrder: make([]uint64, 0, maxCacheEntries), + } + + done := make(chan struct{}, 2) + m.RequestDelayed(Request{BeforeCursor: "foo", FileType: "go"}, func(resp *Response, err error) { + done <- struct{}{} + }) + <-done + + m.RequestDelayed(Request{BeforeCursor: "bar", FileType: "go"}, func(resp *Response, err error) { + done <- struct{}{} + }) + <-done + + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount)) +} + +func TestManagerCacheExpiry(t *testing.T) { + var callCount int32 + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + atomic.AddInt32(&callCount, 1) + return &Response{Text: "fresh"}, nil + }, + }, + debounce: 10 * time.Millisecond, + cache: make(map[uint64]*cacheEntry), + cacheOrder: make([]uint64, 0, maxCacheEntries), + } + + // Force very short TTL via direct cache insertion + key := cacheKey(Request{BeforeCursor: "foo", FileType: "go"}) + m.cacheMu.Lock() + m.cache[key] = &cacheEntry{ + resp: &Response{Text: "stale"}, + expiresAt: time.Now().Add(-time.Second), + } + m.cacheOrder = append(m.cacheOrder, key) + m.cacheMu.Unlock() + + done := make(chan struct{}) + m.RequestDelayed(Request{BeforeCursor: "foo", FileType: "go"}, func(resp *Response, err error) { + close(done) + }) + <-done + + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount), "expired entry should not be used") +} + +func TestManagerRequestNowBypassesCache(t *testing.T) { + var callCount int32 + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + atomic.AddInt32(&callCount, 1) + return &Response{Text: "fresh"}, nil + }, + }, + debounce: 10 * time.Millisecond, + cache: make(map[uint64]*cacheEntry), + cacheOrder: make([]uint64, 0, maxCacheEntries), + } + + // Pre-populate cache + key := cacheKey(Request{BeforeCursor: "foo", FileType: "go"}) + m.cacheMu.Lock() + m.cache[key] = &cacheEntry{ + resp: &Response{Text: "stale"}, + expiresAt: time.Now().Add(time.Hour), + } + m.cacheOrder = append(m.cacheOrder, key) + m.cacheMu.Unlock() + + // RequestNow always fetches fresh + resp, err := m.RequestNow(Request{BeforeCursor: "foo", FileType: "go"}) + assert.NoError(t, err) + assert.Equal(t, "fresh", resp.Text) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) +} + +func TestManagerCacheEviction(t *testing.T) { + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + return &Response{Text: req.BeforeCursor}, nil + }, + }, + debounce: 5 * time.Millisecond, + cache: make(map[uint64]*cacheEntry), + cacheOrder: make([]uint64, 0, maxCacheEntries), + } + + // Fill cache to max + var wg sync.WaitGroup + for i := 0; i < maxCacheEntries; i++ { + wg.Add(1) + key := string(rune('a' + i)) + m.RequestDelayed(Request{BeforeCursor: key, FileType: "go"}, func(resp *Response, err error) { + wg.Done() + }) + time.Sleep(20 * time.Millisecond) + } + wg.Wait() + + assert.Equal(t, maxCacheEntries, len(m.cache)) + + // One more request should evict oldest + wg.Add(1) + m.RequestDelayed(Request{BeforeCursor: "zzz", FileType: "go"}, func(resp *Response, err error) { + wg.Done() + }) + time.Sleep(50 * time.Millisecond) + wg.Wait() + + assert.Equal(t, maxCacheEntries, len(m.cache)) + // Oldest entry should be gone + oldestKey := cacheKey(Request{BeforeCursor: "a", FileType: "go"}) + _, ok := m.cache[oldestKey] + assert.False(t, ok, "oldest cache entry should have been evicted") +} + +func TestManagerConcurrentRequests(t *testing.T) { + m := &Manager{ + provider: &mockProvider{ + complete: func(ctx context.Context, req Request) (*Response, error) { + time.Sleep(10 * time.Millisecond) + return &Response{Text: req.BeforeCursor}, nil + }, + }, + debounce: 5 * time.Millisecond, + cache: make(map[uint64]*cacheEntry), + cacheOrder: make([]uint64, 0, maxCacheEntries), + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + m.RequestNow(Request{BeforeCursor: string(rune('a' + n))}) + }(i) + } + wg.Wait() + + assert.Equal(t, 10, len(m.cache)) +} diff --git a/internal/ai/providers.go b/internal/ai/providers.go new file mode 100644 index 0000000000..f408fc242a --- /dev/null +++ b/internal/ai/providers.go @@ -0,0 +1,19 @@ +package ai + +import "context" + +type Request struct { + BeforeCursor string + AfterCursor string + FileType string + FileName string +} + +type Response struct { + Text string +} + +type Provider interface { + Complete(ctx context.Context, req Request) (*Response, error) + Name() string +} diff --git a/internal/buffer/buffer.go b/internal/buffer/buffer.go index d1f8db4b2b..2ebd579739 100644 --- a/internal/buffer/buffer.go +++ b/internal/buffer/buffer.go @@ -111,6 +111,10 @@ type SharedBuffer struct { // it changes based on how the buffer has changed HasSuggestions bool + // InlineCompletion is ghost text suggested by an AI provider. + // It is drawn after the cursor and cleared on any buffer mutation. + InlineCompletion string + // The Highlighter struct actually performs the highlighting Highlighter *highlight.Highlighter // SyntaxDef represents the syntax highlighting definition being used @@ -125,6 +129,7 @@ type SharedBuffer struct { func (b *SharedBuffer) insert(pos Loc, value []byte) { b.HasSuggestions = false + b.InlineCompletion = "" b.LineArray.insert(pos, value) b.setModified() @@ -134,6 +139,7 @@ func (b *SharedBuffer) insert(pos Loc, value []byte) { func (b *SharedBuffer) remove(start, end Loc) []byte { b.HasSuggestions = false + b.InlineCompletion = "" defer b.setModified() defer b.MarkModified(start.Y, end.Y) return b.LineArray.remove(start, end) @@ -268,6 +274,43 @@ type Buffer struct { OverwriteMode bool } +func (b *Buffer) TextBeforeCursor() string { + c := b.GetActiveCursor() + n := b.LinesNum() + var result []byte + for i := 0; i < n && i < c.Y; i++ { + result = append(result, b.LineBytes(i)...) + result = append(result, '\n') + } + if c.Y < n { + l := b.LineBytes(c.Y) + if c.X < len(l) { + result = append(result, l[:c.X]...) + } else { + result = append(result, l...) + } + } + return string(result) +} + +func (b *Buffer) TextAfterCursor() string { + c := b.GetActiveCursor() + n := b.LinesNum() + var result []byte + if c.Y < n { + l := b.LineBytes(c.Y) + if c.X < len(l) { + result = append(result, l[c.X:]...) + } + result = append(result, '\n') + } + for i := c.Y + 1; i < n; i++ { + result = append(result, b.LineBytes(i)...) + result = append(result, '\n') + } + return string(result) +} + // NewBufferFromFileWithCommand opens a new buffer with a given command // If cmd.StartCursor is {-1, -1} the location does not overwrite what the cursor location // would otherwise be (start of file, or saved cursor position if `savecursor` is diff --git a/internal/config/settings.go b/internal/config/settings.go index 26f1cc8d81..9f2f5f8092 100644 --- a/internal/config/settings.go +++ b/internal/config/settings.go @@ -131,7 +131,12 @@ var DefaultGlobalOnlySettings = map[string]any{ "sucmd": "sudo", "tabhighlight": false, "tabreverse": true, - "xterm": false, + "xterm": false, + "aicomplete": true, + "aicompleteprovider": "codestral", + "aicompletemodel": "codestral-latest", + "aicompleteurl": "", + "aicompletedebounce": float64(300), } // a list of settings that should never be globally modified diff --git a/internal/display/bufwindow.go b/internal/display/bufwindow.go index ddbb044c7b..ac49cf6c9a 100644 --- a/internal/display/bufwindow.go +++ b/internal/display/bufwindow.go @@ -29,6 +29,10 @@ type BufWindow struct { hasMessage bool maxLineNumLength int drawDivider bool + + // Visual cursor position for ghost text rendering + cursorX int + cursorY int } // NewBufWindow creates a new window at a location in the screen with a width and height @@ -372,6 +376,8 @@ func (w *BufWindow) getStyle(style tcell.Style, bloc buffer.Loc) (tcell.Style, b func (w *BufWindow) showCursor(x, y int, main bool) { if w.active { if main { + w.cursorX = x + w.cursorY = y screen.ShowCursor(x, y) } else { screen.ShowFakeCursorMulti(x, y) @@ -418,6 +424,9 @@ func (w *BufWindow) displayBuffer() { } } + w.cursorX = -1 + w.cursorY = -1 + lineNumStyle := config.DefStyle if style, ok := config.Colorscheme["line-number"]; ok { lineNumStyle = style @@ -893,6 +902,42 @@ func (w *BufWindow) displayScrollBar() { } } +func (w *BufWindow) displayGhostText() { + text := w.Buf.InlineCompletion + if text == "" || w.cursorX < 0 { + return + } + + style := config.DefStyle.Foreground(tcell.ColorGray) + if s, ok := config.Colorscheme["ai-completion"]; ok { + style = s + } + + maxX := w.X + w.gutterOffset + w.bufWidth + x := w.cursorX + 1 + y := w.cursorY + + if x >= maxX || y < w.Y || y >= w.Y+w.bufHeight { + return + } + + more := false + start := x + + for _, r := range text { + if x >= maxX || r == '\n' || r == '\r' { + more = true + break + } + screen.SetContent(x, y, r, nil, style) + x++ + } + + if more && x < maxX && x > start { + screen.SetContent(x, y, '…', nil, style) + } +} + // Display displays the buffer and the statusline func (w *BufWindow) Display() { w.updateDisplayInfo() @@ -900,4 +945,5 @@ func (w *BufWindow) Display() { w.displayStatusLine() w.displayScrollBar() w.displayBuffer() + w.displayGhostText() }