forked from glebkudr/shotgun_code
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauto_context_service.go
More file actions
322 lines (279 loc) · 8.56 KB
/
auto_context_service.go
File metadata and controls
322 lines (279 loc) · 8.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
package main
import (
"embed"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
"shotgun_code/internal/llm/provider"
)
// embeddedPromptFS holds a build-time copy of the auto-context prompt
// template. The template loader reads it (via autoContextTemplatePath) when
// the on-disk copy cannot be read. The path in the //go:embed directive below
// must stay identical to autoContextTemplatePath.
//
//go:embed design/prompts/contextPreparation.md
var embeddedPromptFS embed.FS

const (
	// autoContextTemplatePath is used both as a path relative to the process
	// working directory (os.ReadFile) and as the key into embeddedPromptFS.
	autoContextTemplatePath = "design/prompts/contextPreparation.md"
	// maxAutoContextTreeChars caps the rendered file tree. NOTE: the tree
	// builder compares it against strings.Builder.Len, which counts bytes,
	// not characters (box-drawing glyphs are multi-byte in UTF-8).
	maxAutoContextTreeChars = 15_000
)

// errAutoContextTreeTooLarge is returned by buildAutoContextTree when the
// rendered tree grows beyond maxAutoContextTreeChars.
var errAutoContextTreeTooLarge = errors.New("auto context file tree exceeds the allowed size")
// AutoContextResult is the decoded model reply: the files selected for the
// context plus an optional free-text explanation.
type AutoContextResult struct {
	Files     []string `json:"files"`
	Reasoning string   `json:"reasoning,omitempty"`
}

// autoContextParser converts raw model text into an AutoContextResult. Its
// method set (Parse, ParseWithPrompt, GetFormatInstructions, Type) presumably
// mirrors langchaingo's output-parser shape — not asserted at compile time here.
type autoContextParser struct{}

// Parse decodes text via parseAutoContextJSON.
func (p autoContextParser) Parse(text string) (AutoContextResult, error) {
	return parseAutoContextJSON(text)
}

// ParseWithPrompt ignores the prompt value and behaves exactly like Parse.
func (p autoContextParser) ParseWithPrompt(text string, _ schema.PromptValue) (AutoContextResult, error) {
	return p.Parse(text)
}

// GetFormatInstructions returns the text appended to the prompt that tells
// the model which JSON shape to answer with.
func (p autoContextParser) GetFormatInstructions() string {
	const instructions = "Respond ONLY with a JSON object that matches this schema:\n" +
		"```\n{\n \"files\": [\"relative/path/from/project/root\"],\n \"reasoning\": \"optional short description\"\n}\n```\n" +
		"No code fences, commentary, or explanations outside the JSON object."
	return instructions
}

// Type returns the parser's identifier string.
func (p autoContextParser) Type() string {
	const parserType = "auto_context_json_parser"
	return parserType
}
// AutoContextService renders the auto-context prompt from a lazily loaded
// template and parses the model's JSON reply.
type AutoContextService struct {
	parser autoContextParser

	// templateMu serializes the one-time template load in ensureTemplate.
	templateMu     sync.Mutex
	template       prompts.PromptTemplate
	templateLoaded bool
}

// NewAutoContextService returns a service whose prompt template has not been
// loaded yet; it is loaded on the first BuildPrompt call.
func NewAutoContextService() *AutoContextService {
	return new(AutoContextService)
}

// ensureTemplate loads and compiles the prompt template once. A failed load is
// not remembered, so a later call tries again.
func (s *AutoContextService) ensureTemplate() error {
	s.templateMu.Lock()
	defer s.templateMu.Unlock()

	if !s.templateLoaded {
		body, err := readAutoContextTemplate()
		if err != nil {
			return err
		}
		inputs := []string{"FILE_TREE", "USER_TASK", "CURRENT_UNDERSTANDING"}
		s.template = prompts.NewPromptTemplate(body, inputs)
		s.templateLoaded = true
	}
	return nil
}

// readAutoContextTemplate returns the template text. A file at
// autoContextTemplatePath (relative to the working directory) wins; if it
// cannot be read for any reason, the embedded copy is used.
func readAutoContextTemplate() (string, error) {
	if onDisk, err := os.ReadFile(autoContextTemplatePath); err == nil {
		return string(onDisk), nil
	}
	embedded, err := embeddedPromptFS.ReadFile(autoContextTemplatePath)
	if err != nil {
		return "", fmt.Errorf("failed to load auto-context prompt template: %w", err)
	}
	return string(embedded), nil
}

// BuildPrompt renders the template with the file tree, user task and current
// understanding, trims surrounding whitespace, and appends the parser's
// format instructions after a blank line.
func (s *AutoContextService) BuildPrompt(fileTree, userTask, understanding string) (string, error) {
	if err := s.ensureTemplate(); err != nil {
		return "", err
	}
	vars := map[string]any{
		"FILE_TREE":             fileTree,
		"USER_TASK":             userTask,
		"CURRENT_UNDERSTANDING": understanding,
	}
	rendered, err := s.template.Format(vars)
	if err != nil {
		return "", fmt.Errorf("failed to render auto-context prompt: %w", err)
	}
	parts := []string{strings.TrimSpace(rendered), s.parser.GetFormatInstructions()}
	return strings.Join(parts, "\n\n"), nil
}

// ParseResponse decodes the model's reply with the service's parser.
func (s *AutoContextService) ParseResponse(text string) (AutoContextResult, error) {
	result, err := s.parser.Parse(text)
	return result, err
}
// parseAutoContextJSON decodes an LLM reply into an AutoContextResult.
//
// The reply is whitespace-trimmed. If it starts with a Markdown code fence, it
// is unwrapped by stripAutoContextFence. The remaining text is decoded
// strictly: unknown fields are rejected, and only the first JSON value is read.
// Each entry in Files is normalized with normalizeRelativePath, and entries
// that normalize to "" are dropped.
//
// An error is returned when the reply is empty, when it does not decode into
// AutoContextResult, or when no usable file path remains.
func parseAutoContextJSON(text string) (AutoContextResult, error) {
	cleaned := strings.TrimSpace(text)
	if cleaned == "" {
		return AutoContextResult{}, errors.New("empty response from LLM")
	}
	cleaned = strings.TrimSpace(stripAutoContextFence(cleaned))

	var result AutoContextResult
	decoder := json.NewDecoder(strings.NewReader(cleaned))
	decoder.DisallowUnknownFields()
	if err := decoder.Decode(&result); err != nil {
		return AutoContextResult{}, fmt.Errorf("failed to decode auto-context response: %w", err)
	}

	normalized := make([]string, 0, len(result.Files))
	for _, f := range result.Files {
		if f = normalizeRelativePath(f); f != "" {
			normalized = append(normalized, f)
		}
	}
	if len(normalized) == 0 {
		return AutoContextResult{}, errors.New("response did not include any valid files")
	}
	result.Files = normalized
	return result, nil
}

// stripAutoContextFence unwraps text that begins with a ``` code fence.
//
// It removes the opening backticks together with any info string that
// follows them (e.g. "json", "JSON", "Json", "jsonc", " json"). It then cuts
// everything from the last ``` onward. Text that does not start with ``` is
// returned unchanged.
func stripAutoContextFence(s string) string {
	if !strings.HasPrefix(s, "```") {
		return s
	}
	body := strings.TrimLeft(strings.TrimPrefix(s, "```"), " \t")

	// An info string is a run of tag characters. A JSON object starts with '{',
	// so the payload is never mistaken for part of the tag.
	isInfoRune := func(r rune) bool {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return true
		}
		return r == '-' || r == '_' || r == '+' || r == '.'
	}
	if end := strings.IndexFunc(body, func(r rune) bool { return !isInfoRune(r) }); end >= 0 {
		body = body[end:]
	} else {
		body = ""
	}

	if closing := strings.LastIndex(body, "```"); closing >= 0 {
		body = body[:closing]
	}
	return body
}
// buildAutoContextTree renders the directory hierarchy under rootDir as a
// box-drawing tree for the auto-context prompt.
//
// The first line is the base name of rootDir followed by the OS path
// separator. Within each directory, subdirectories come before files, and
// names are ordered case-insensitively. An entry is omitted, together with its
// subtree, when its normalized path relative to rootDir maps to true in
// excludedMap.
//
// Rendering aborts with errAutoContextTreeTooLarge as soon as the output
// exceeds maxAutoContextTreeChars, as measured by strings.Builder.Len (bytes).
// A directory that cannot be read aborts the whole render with a wrapped
// error.
func buildAutoContextTree(rootDir string, excludedMap map[string]bool) (string, error) {
	var out strings.Builder
	out.WriteString(filepath.Base(rootDir) + string(os.PathSeparator) + "\n")

	var render func(dir, indent string) error
	render = func(dir, indent string) error {
		all, err := os.ReadDir(dir)
		if err != nil {
			return fmt.Errorf("failed to read directory %s: %w", dir, err)
		}

		// Filter first. A stable sort of the survivors yields the same relative
		// order as sorting everything and filtering afterwards.
		shown := make([]os.DirEntry, 0, len(all))
		for _, e := range all {
			rel, _ := filepath.Rel(rootDir, filepath.Join(dir, e.Name()))
			if !excludedMap[normalizeRelativePath(rel)] {
				shown = append(shown, e)
			}
		}
		sort.SliceStable(shown, func(a, b int) bool {
			aDir, bDir := shown[a].IsDir(), shown[b].IsDir()
			if aDir != bDir {
				return aDir
			}
			return strings.ToLower(shown[a].Name()) < strings.ToLower(shown[b].Name())
		})

		last := len(shown) - 1
		for i, e := range shown {
			connector, childIndent := "├── ", indent+"│ "
			if i == last {
				connector, childIndent = "└── ", indent+" "
			}
			out.WriteString(indent + connector + e.Name() + "\n")
			if out.Len() > maxAutoContextTreeChars {
				return errAutoContextTreeTooLarge
			}
			if !e.IsDir() {
				continue
			}
			if err := render(filepath.Join(dir, e.Name()), childIndent); err != nil {
				return err
			}
		}
		return nil
	}

	if err := render(rootDir, ""); err != nil {
		return "", err
	}
	if out.Len() > maxAutoContextTreeChars {
		return "", errAutoContextTreeTooLarge
	}
	return out.String(), nil
}
// normalizeRelativePath converts a path into the canonical form used for
// selection and exclusion lookups. It trims surrounding whitespace, converts
// separators to '/', and removes every leading "./" and "/" segment. It
// returns "" for empty input and for paths that denote the current directory
// ("." or "./").
//
// Separators are converted before prefixes are stripped, so Windows-style
// inputs such as `.\src\main.go` are normalized too (filepath.ToSlash only
// rewrites `\` on Windows). Prefixes are stripped repeatedly so inputs like
// "/./foo", "././a" or "//x" do not keep a stray leading "./" or "/".
func normalizeRelativePath(rel string) string {
	rel = filepath.ToSlash(strings.TrimSpace(rel))
	for {
		switch {
		case strings.HasPrefix(rel, "./"):
			rel = rel[len("./"):]
		case strings.HasPrefix(rel, "/"):
			rel = rel[len("/"):]
		default:
			if rel == "." {
				return ""
			}
			return rel
		}
	}
}
// normalizeCandidateForRoot brings an LLM-returned path into the canonical
// "relative to rootDir" form. It accepts either strictly relative paths like
// "frontend/src/..." or paths prefixed with the project root name, e.g.
// "shotgun_code/frontend/src/..." when rootDir == ".../shotgun_code".
//
// A leading root-name segment is ambiguous. Projects often contain a
// subdirectory named like the project itself (for example a package
// directory "mypkg/" inside a repository "mypkg"). The candidate is therefore
// kept verbatim when it exists under rootDir, and the root-name prefix is
// stripped only otherwise. A bare root name that does not exist as an entry
// under rootDir is not a file path and yields "".
//
// The input is first passed through normalizeRelativePath. The result is ""
// when nothing usable remains.
func normalizeCandidateForRoot(rootDir, candidate string) string {
	candidate = normalizeRelativePath(candidate)
	// Defensively drop any remaining "./" segments so the root-name check
	// below sees the real first path segment.
	for strings.HasPrefix(candidate, "./") {
		candidate = strings.TrimPrefix(candidate, "./")
	}
	if candidate == "" {
		return ""
	}

	// filepath.Base yields "." for an empty path and a lone separator for a
	// filesystem root; neither is a usable name prefix.
	rootBase := filepath.ToSlash(filepath.Base(rootDir))
	if rootBase == "." || rootBase == "/" {
		return candidate
	}

	prefix := rootBase + "/"
	if candidate != rootBase && !strings.HasPrefix(candidate, prefix) {
		return candidate
	}

	// The path starts with the root's own name. If it exists as written, it
	// refers to a real same-named subdirectory, so keep it.
	if _, err := os.Stat(filepath.Join(rootDir, filepath.FromSlash(candidate))); err == nil {
		return candidate
	}
	if candidate == rootBase {
		return ""
	}
	// Common case: "shotgun_code/frontend/src/..." → "frontend/src/..."
	return strings.TrimPrefix(candidate, prefix)
}
// resolveLLMSelection maps the paths chosen by the LLM onto existing files
// under rootDir.
//
// Each candidate is first normalized with normalizeCandidateForRoot. A
// candidate is skipped when it is empty, resolves to rootDir itself, resolves
// outside rootDir (e.g. "../../etc/passwd"), or does not exist. A directory
// candidate expands to every non-directory entry filepath.WalkDir visits
// beneath it; entries that cannot be read are skipped.
//
// It returns the de-duplicated selection as slash-separated paths relative to
// rootDir, sorted lexicographically. It returns an error when candidates is
// empty or when nothing matched.
func resolveLLMSelection(rootDir string, candidates []string) ([]string, error) {
	if len(candidates) == 0 {
		return nil, errors.New("no candidate paths provided")
	}

	selected := make(map[string]struct{})
	for _, candidate := range candidates {
		candidate = normalizeCandidateForRoot(rootDir, candidate)
		if candidate == "" {
			continue
		}
		absPath := filepath.Join(rootDir, filepath.FromSlash(candidate))

		// Candidates come from model output. Refuse anything that escapes the
		// project root or collapses back onto the root itself (e.g. "foo/..").
		rel, err := filepath.Rel(rootDir, absPath)
		if err != nil || rel == "." || rel == ".." ||
			strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			continue
		}

		info, err := os.Stat(absPath)
		if err != nil {
			continue
		}
		if !info.IsDir() {
			// Key by the cleaned relative path so "a//b" and "a/b" coincide.
			if key := normalizeRelativePath(rel); key != "" {
				selected[key] = struct{}{}
			}
			continue
		}

		// Best effort: per-entry errors are ignored and the callback never
		// returns an error, so WalkDir's own result carries no information.
		_ = filepath.WalkDir(absPath, func(path string, d os.DirEntry, walkErr error) error {
			if walkErr != nil || d.IsDir() {
				return nil
			}
			fileRel, relErr := filepath.Rel(rootDir, path)
			if relErr != nil {
				return nil
			}
			if key := normalizeRelativePath(fileRel); key != "" {
				selected[key] = struct{}{}
			}
			return nil
		})
	}

	if len(selected) == 0 {
		return nil, errors.New("no existing files matched the LLM selection")
	}
	sorted := make([]string, 0, len(selected))
	for rel := range selected {
		sorted = append(sorted, rel)
	}
	sort.Strings(sorted)
	return sorted, nil
}
// buildProviderConfig translates the user's LLM settings into a provider
// configuration for the currently active provider. The model name comes from
// fallbackModel, the API key from settings.keyForProvider for that provider,
// and the base URL is whitespace-trimmed.
func buildProviderConfig(settings LLMSettings) provider.Config {
	active := settings.ActiveProvider
	cfg := provider.Config{Provider: active}
	cfg.Model = fallbackModel(settings)
	cfg.APIKey = settings.keyForProvider(active)
	cfg.BaseURL = strings.TrimSpace(settings.BaseURL)
	return cfg
}
// fallbackModel returns the configured model name with surrounding whitespace
// removed. When that is blank, it returns defaultModelForProvider for the
// active provider instead.
func fallbackModel(settings LLMSettings) string {
	if model := strings.TrimSpace(settings.Model); model != "" {
		return model
	}
	return defaultModelForProvider(settings.ActiveProvider)
}