diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..2f387ed --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,37 @@ +version: "2" + +# Report every finding — the defaults (max-same-issues: 3, +# max-issues-per-linter: 50) silently hide repeats, which masks the true +# count. +issues: + max-same-issues: 0 + max-issues-per-linter: 0 + +# Default linter set (standard: errcheck, govet, ineffassign, staticcheck, +# unused). Only the errcheck exclusions below are customized. +linters: + settings: + errcheck: + # Conventional "safe to ignore" calls: the returned error is not + # actionable at the call site. + exclude-functions: + - (io.Closer).Close # interface-typed Close (e.g. resp.Body) + - (*os.File).Close # defer file/tmpFile Close + - (*database/sql.Rows).Close # defer rows.Close + - fmt.Fprint + - fmt.Fprintf + - fmt.Fprintln # writes to http.ResponseWriter / buffers + - os.Remove # temp-file cleanup on an error path + exclusions: + rules: + # Test setup routinely ignores errors from helpers (SaveHistory, + # RunIteration, Execute, …); checking them adds noise without value. + - path: _test\.go + linters: + - errcheck + # SDK stream readers (anthropic ssestream / go-openai) expose Close via + # an embedded generic type errcheck cannot address by name; the deferred + # Close error is not actionable once the stream has been drained. + - linters: + - errcheck + source: 'defer stream\.Close\(\)' diff --git a/examples/creative_studio/main.go b/examples/creative_studio/main.go index fe4aefa..5a857e3 100644 --- a/examples/creative_studio/main.go +++ b/examples/creative_studio/main.go @@ -2,19 +2,22 @@ // and short videos from natural-language descriptions. // // Required env vars: -// OPENAI_API_KEY — image generation (DALL-E 3) -// GEMINI_API_KEY — video generation (Veo) + optional LLM -// LLM_PROVIDER — "openai" (default), "anthropic", or "gemini" +// +// OPENAI_API_KEY — image generation (DALL-E 3) +// GEMINI_API_KEY — video generation (Veo) + optional LLM +// LLM_PROVIDER — "openai" (default), "anthropic", or "gemini" // // Optional env vars: -// VEO_MODEL — video model ID. Defaults to "veo-2.0-generate-001" -// (silent video). Set to a Veo 3 model ID to get native -// audio; availability and pricing differ per tier. -// MEDIA_DIR — where generated images/videos are saved (default "generated"). +// +// VEO_MODEL — video model ID. Defaults to "veo-2.0-generate-001" +// (silent video). Set to a Veo 3 model ID to get native +// audio; availability and pricing differ per tier. +// MEDIA_DIR — where generated images/videos are saved (default "generated"). // // Run: -// cd examples/creative_studio && go run . -// open http://localhost:8890 +// +// cd examples/creative_studio && go run . +// open http://localhost:8890 package main import ( diff --git a/examples/demo/main.go b/examples/demo/main.go index 70cfbc6..590cb8e 100644 --- a/examples/demo/main.go +++ b/examples/demo/main.go @@ -128,7 +128,7 @@ func buildSessionManager(systemPrompt string) agent.SessionManager { // pendingApprovals holds HITL confirmations that are waiting on the user. // Maps approvalID → channel that receives true (approved) or false (denied). var ( - pendingMu sync.Mutex + pendingMu sync.Mutex pendingApprovals = make(map[string]chan bool) ) @@ -320,7 +320,7 @@ func uniqueID() uint64 { // activeStreams maps sessionKey → the current SSE write channel (or nil). var ( streamsMu sync.RWMutex - activeStreams = make(map[string]chan<- agent.StreamEvent) + activeStreams = make(map[string]chan<- agent.StreamEvent) ) func registerStream(sessionKey string, ch chan<- agent.StreamEvent) { diff --git a/examples/memory_demo/main.go b/examples/memory_demo/main.go index b71f4a7..1f0e7b1 100644 --- a/examples/memory_demo/main.go +++ b/examples/memory_demo/main.go @@ -174,4 +174,3 @@ func dumpNotes(ctx context.Context, store memory.Store, scope string) { fmt.Printf(" - [%s] %s\n", n.Key, n.Content) } } - diff --git a/pkg/agent/auto_cache_system_test.go b/pkg/agent/auto_cache_system_test.go index bbc96b9..b3b7d79 100644 --- a/pkg/agent/auto_cache_system_test.go +++ b/pkg/agent/auto_cache_system_test.go @@ -19,10 +19,7 @@ type cacheHintCapturingProvider struct { func (p *cacheHintCapturingProvider) GenerateStream(_ context.Context, msgs []history.Message, _ *tools.Registry, ch chan<- StreamEvent) (LLMResult, error) { p.mu.Lock() - stamp := false - if len(msgs) > 0 && msgs[0].Role == "system" && msgs[0].CacheHint { - stamp = true - } + stamp := len(msgs) > 0 && msgs[0].Role == "system" && msgs[0].CacheHint p.seenStamps = append(p.seenStamps, stamp) p.mu.Unlock() ch <- Event(ContentEvent{Text: "ok"}) diff --git a/pkg/agent/cache_gate_test.go b/pkg/agent/cache_gate_test.go index 568fd7f..0705031 100644 --- a/pkg/agent/cache_gate_test.go +++ b/pkg/agent/cache_gate_test.go @@ -14,7 +14,6 @@ import ( type gateTool struct { name string cacheable bool - hasFlag bool calls atomic.Int32 } @@ -37,7 +36,7 @@ type cacheableGateTool struct{ gateTool } func (c *cacheableGateTool) Descriptor() tools.ToolDescriptor { d := c.gateTool.Descriptor() - d.Cacheable = c.gateTool.cacheable + d.Cacheable = c.cacheable return d } diff --git a/pkg/agent/constants.go b/pkg/agent/constants.go index 54498fc..9bd8f67 100644 --- a/pkg/agent/constants.go +++ b/pkg/agent/constants.go @@ -3,11 +3,6 @@ package agent // Internal tuning constants for the agent loop. Pulled out of inline // literals so they can be cited and adjusted in one place. const ( - // runIterationBuffer sizes the channel RunIteration uses to receive - // events from the loop goroutine. Larger than the streaming buffers - // because the synchronous reader has no consumer back-pressure. - runIterationBuffer = 100 - // runIterationStreamBuffer sizes the internal channel between the // loop goroutine and the SSE proxy goroutine in RunIterationStream. runIterationStreamBuffer = 50 diff --git a/pkg/agent/event_types.go b/pkg/agent/event_types.go index c681627..d8c41de 100644 --- a/pkg/agent/event_types.go +++ b/pkg/agent/event_types.go @@ -146,8 +146,8 @@ type ContentEvent struct { Text string `json:"text"` } -func (ContentEvent) isEventPayload() {} -func (ContentEvent) eventType() StreamEventType { return EventTypeContent } +func (ContentEvent) isEventPayload() {} +func (ContentEvent) eventType() StreamEventType { return EventTypeContent } // ThoughtEvent is internal reasoning / system narration. Suppressed from the // final answer by RunIteration; surfaced by RunIterationStream when @@ -156,8 +156,8 @@ type ThoughtEvent struct { Message string `json:"message"` } -func (ThoughtEvent) isEventPayload() {} -func (ThoughtEvent) eventType() StreamEventType { return EventTypeThought } +func (ThoughtEvent) isEventPayload() {} +func (ThoughtEvent) eventType() StreamEventType { return EventTypeThought } // ToolCallEvent announces that the agent is about to execute a tool. ID is // the agent-generated correlation ID — it matches the toolCallID parameter on @@ -172,8 +172,8 @@ type ToolCallEvent struct { Reused bool `json:"reused,omitempty"` } -func (ToolCallEvent) isEventPayload() {} -func (ToolCallEvent) eventType() StreamEventType { return EventTypeToolCall } +func (ToolCallEvent) isEventPayload() {} +func (ToolCallEvent) eventType() StreamEventType { return EventTypeToolCall } // ToolProgressEvent is a mid-execution status update emitted by a tool via // tools.ReportProgress. Progress is lossy by design — consumers may drop @@ -204,8 +204,8 @@ type UsageEvent struct { Usage TokenUsage `json:"usage"` } -func (UsageEvent) isEventPayload() {} -func (UsageEvent) eventType() StreamEventType { return EventTypeUsage } +func (UsageEvent) isEventPayload() {} +func (UsageEvent) eventType() StreamEventType { return EventTypeUsage } // ErrorEvent signals a terminal failure for the current iteration. Err holds // the structured error (usable with errors.Is / errors.As); Message is its @@ -215,14 +215,14 @@ type ErrorEvent struct { Message string `json:"message"` } -func (ErrorEvent) isEventPayload() {} -func (ErrorEvent) eventType() StreamEventType { return EventTypeError } +func (ErrorEvent) isEventPayload() {} +func (ErrorEvent) eventType() StreamEventType { return EventTypeError } // DoneEvent marks the end of the stream — no more events will arrive. type DoneEvent struct{} -func (DoneEvent) isEventPayload() {} -func (DoneEvent) eventType() StreamEventType { return EventTypeDone } +func (DoneEvent) isEventPayload() {} +func (DoneEvent) eventType() StreamEventType { return EventTypeDone } // ReflectedEvent delivers a post-critique canonical answer. Round indicates // which self-critique pass produced it (1-indexed); consumers typically keep @@ -265,8 +265,8 @@ type TaskListEvent struct { Tasks []TaskListItem `json:"tasks"` } -func (TaskListEvent) isEventPayload() {} -func (TaskListEvent) eventType() StreamEventType { return EventTypeTaskList } +func (TaskListEvent) isEventPayload() {} +func (TaskListEvent) eventType() StreamEventType { return EventTypeTaskList } // MaxItersReachedEvent signals that the loop exhausted its iteration cap // without a final answer. Limit echoes AgentLoop.MaxIters at the time of diff --git a/pkg/agent/event_types_test.go b/pkg/agent/event_types_test.go index 625beb3..7a4930f 100644 --- a/pkg/agent/event_types_test.go +++ b/pkg/agent/event_types_test.go @@ -147,30 +147,32 @@ type recordingVisitor struct { visited string } -func (r *recordingVisitor) VisitContent(ContentEvent) { r.visited = "content" } -func (r *recordingVisitor) VisitThought(ThoughtEvent) { r.visited = "thought" } -func (r *recordingVisitor) VisitToolCall(ToolCallEvent) { r.visited = "tool_call" } -func (r *recordingVisitor) VisitToolProgress(ToolProgressEvent) { r.visited = "tool_progress" } -func (r *recordingVisitor) VisitActionRequired(ActionRequiredEvent) { r.visited = "action_required" } -func (r *recordingVisitor) VisitUsage(UsageEvent) { r.visited = "usage" } -func (r *recordingVisitor) VisitError(ErrorEvent) { r.visited = "error" } -func (r *recordingVisitor) VisitDone(DoneEvent) { r.visited = "done" } -func (r *recordingVisitor) VisitReflected(ReflectedEvent) { r.visited = "reflected" } -func (r *recordingVisitor) VisitToolCallReady(ToolCallReadyEvent) { r.visited = "tool_call_ready" } -func (r *recordingVisitor) VisitTaskList(TaskListEvent) { r.visited = "task_list" } -func (r *recordingVisitor) VisitMaxItersReached(MaxItersReachedEvent) { r.visited = "max_iters_reached" } -func (r *recordingVisitor) VisitSessionCreated(SessionCreatedEvent) { r.visited = "session_created" } -func (r *recordingVisitor) VisitLimitExhausted(LimitExhaustedEvent) { r.visited = "limit_exhausted" } -func (r *recordingVisitor) VisitHITLDenied(HITLDeniedEvent) { r.visited = "hitl_denied" } -func (r *recordingVisitor) VisitHITLTimedOut(HITLTimedOutEvent) { r.visited = "hitl_timed_out" } -func (r *recordingVisitor) VisitRegenerated(RegeneratedEvent) { r.visited = "regenerated" } -func (r *recordingVisitor) VisitContinued(ContinuedEvent) { r.visited = "continued" } -func (r *recordingVisitor) VisitMemoryLoaded(MemoryLoadedEvent) { r.visited = "memory_loaded" } +func (r *recordingVisitor) VisitContent(ContentEvent) { r.visited = "content" } +func (r *recordingVisitor) VisitThought(ThoughtEvent) { r.visited = "thought" } +func (r *recordingVisitor) VisitToolCall(ToolCallEvent) { r.visited = "tool_call" } +func (r *recordingVisitor) VisitToolProgress(ToolProgressEvent) { r.visited = "tool_progress" } +func (r *recordingVisitor) VisitActionRequired(ActionRequiredEvent) { r.visited = "action_required" } +func (r *recordingVisitor) VisitUsage(UsageEvent) { r.visited = "usage" } +func (r *recordingVisitor) VisitError(ErrorEvent) { r.visited = "error" } +func (r *recordingVisitor) VisitDone(DoneEvent) { r.visited = "done" } +func (r *recordingVisitor) VisitReflected(ReflectedEvent) { r.visited = "reflected" } +func (r *recordingVisitor) VisitToolCallReady(ToolCallReadyEvent) { r.visited = "tool_call_ready" } +func (r *recordingVisitor) VisitTaskList(TaskListEvent) { r.visited = "task_list" } +func (r *recordingVisitor) VisitMaxItersReached(MaxItersReachedEvent) { + r.visited = "max_iters_reached" +} +func (r *recordingVisitor) VisitSessionCreated(SessionCreatedEvent) { r.visited = "session_created" } +func (r *recordingVisitor) VisitLimitExhausted(LimitExhaustedEvent) { r.visited = "limit_exhausted" } +func (r *recordingVisitor) VisitHITLDenied(HITLDeniedEvent) { r.visited = "hitl_denied" } +func (r *recordingVisitor) VisitHITLTimedOut(HITLTimedOutEvent) { r.visited = "hitl_timed_out" } +func (r *recordingVisitor) VisitRegenerated(RegeneratedEvent) { r.visited = "regenerated" } +func (r *recordingVisitor) VisitContinued(ContinuedEvent) { r.visited = "continued" } +func (r *recordingVisitor) VisitMemoryLoaded(MemoryLoadedEvent) { r.visited = "memory_loaded" } func (r *recordingVisitor) VisitMemoryConsolidated(MemoryConsolidatedEvent) { r.visited = "memory_consolidated" } -func (r *recordingVisitor) VisitRunCost(RunCostEvent) { r.visited = "run_cost" } -func (r *recordingVisitor) VisitUnknown(UnknownEvent) { r.visited = "unknown" } +func (r *recordingVisitor) VisitRunCost(RunCostEvent) { r.visited = "run_cost" } +func (r *recordingVisitor) VisitUnknown(UnknownEvent) { r.visited = "unknown" } func TestVisit_DispatchesToMatchingMethod(t *testing.T) { cases := []struct { diff --git a/pkg/agent/loop_state.go b/pkg/agent/loop_state.go index e0fe818..7fbfdd0 100644 --- a/pkg/agent/loop_state.go +++ b/pkg/agent/loop_state.go @@ -16,7 +16,6 @@ type iterationState struct { sessionKey string iteration int streamChan chan<- StreamEvent - msgs *[]history.Message specMap map[string]*speculativeExec specMu *sync.Mutex tracker *loopDetector diff --git a/pkg/agent/loop_stream_test.go b/pkg/agent/loop_stream_test.go index 7a8908a..5211075 100644 --- a/pkg/agent/loop_stream_test.go +++ b/pkg/agent/loop_stream_test.go @@ -483,8 +483,8 @@ func TestRunIteration_EventHandlerMultiple(t *testing.T) { func TestRunIteration_RetryOnLLMError(t *testing.T) { attempts := 0 provider := &countingErrorProvider{ - failN: 2, // fail first 2 attempts, succeed on 3rd - onCall: func() { attempts++ }, + failN: 2, // fail first 2 attempts, succeed on 3rd + onCall: func() { attempts++ }, successResult: LLMResult{Content: "recovered"}, } loop, _ := setup(provider) diff --git a/pkg/agent/memory_test.go b/pkg/agent/memory_test.go index c117265..43d13d1 100644 --- a/pkg/agent/memory_test.go +++ b/pkg/agent/memory_test.go @@ -779,4 +779,3 @@ type consolidatorPanicProvider struct{} func (p *consolidatorPanicProvider) GenerateStream(_ context.Context, _ []history.Message, _ *tools.Registry, _ chan<- StreamEvent) (LLMResult, error) { panic("provider should not be called for short transcripts") } - diff --git a/pkg/agent/mock_provider.go b/pkg/agent/mock_provider.go index a0700ae..0b529a0 100644 --- a/pkg/agent/mock_provider.go +++ b/pkg/agent/mock_provider.go @@ -48,7 +48,7 @@ func (m *MockProvider) GenerateStream(ctx context.Context, memory []history.Mess toolName = "delete_database_records" argsJSON = `{"table": "users", "condition": "all"}` } - + // If running inside SQL Sub-Agent, it should call execute_sql instead of call_sql_agent if _, hasExecuteSQL := registry.Get("execute_sql"); hasExecuteSQL { toolName = "execute_sql" diff --git a/pkg/agent/reflect.go b/pkg/agent/reflect.go index 2df4fbf..be1c1ef 100644 --- a/pkg/agent/reflect.go +++ b/pkg/agent/reflect.go @@ -2,7 +2,6 @@ package agent import ( "context" - "encoding/json" "fmt" "strings" @@ -94,20 +93,3 @@ func (al *AgentLoop) reflectOnce( } return strings.TrimSpace(buf.String()), nil } - -// reflectedEventContent serializes the ReflectedEvent payload into the -// string channel used by StreamEvent. JSON keeps text and round coupled so -// downstream consumers deserialize with a single Unmarshal. -func reflectedEventContent(text string, round int) string { - payload := struct { - Text string `json:"text"` - Round int `json:"round"` - }{Text: text, Round: round} - b, err := json.Marshal(payload) - if err != nil { - // Marshal of a string+int can't realistically fail; fall back to - // raw text so the consumer still sees the answer. - return text - } - return string(b) -} diff --git a/pkg/agent/soft_landing_test.go b/pkg/agent/soft_landing_test.go index 42c10d5..adb545e 100644 --- a/pkg/agent/soft_landing_test.go +++ b/pkg/agent/soft_landing_test.go @@ -73,7 +73,7 @@ func TestSoftLanding_NotPersistedToHistory(t *testing.T) { } type toolEverProvider struct { - name string + name string calls int } diff --git a/pkg/agent/structured_output_test.go b/pkg/agent/structured_output_test.go index 6f825fb..f516419 100644 --- a/pkg/agent/structured_output_test.go +++ b/pkg/agent/structured_output_test.go @@ -56,8 +56,8 @@ func TestStructuredOutput_IsolatedFromOriginal(t *testing.T) { func TestStructuredOutput_EmptyOrNilSchemaClears(t *testing.T) { cases := []StructuredOutput{ - {}, // zero value - {Name: "x"}, // no schema + {}, // zero value + {Name: "x"}, // no schema {Schema: map[string]any{}}, // empty schema } for i, so := range cases { diff --git a/pkg/agent/tool_error_hint_test.go b/pkg/agent/tool_error_hint_test.go index 75608cb..df69380 100644 --- a/pkg/agent/tool_error_hint_test.go +++ b/pkg/agent/tool_error_hint_test.go @@ -92,9 +92,8 @@ func TestFormatToolError_ContextCarriesArgsAndIteration(t *testing.T) { // --- integration: error flows into tool result message --- type failingTool struct { - mu sync.Mutex - nCalls int - failOnce bool + mu sync.Mutex + nCalls int } func (t *failingTool) Descriptor() tools.ToolDescriptor { diff --git a/pkg/agentmetrics/handler_test.go b/pkg/agentmetrics/handler_test.go index 52b6917..5da268c 100644 --- a/pkg/agentmetrics/handler_test.go +++ b/pkg/agentmetrics/handler_test.go @@ -94,7 +94,7 @@ func TestHandler_DeterministicOrdering(t *testing.T) { a := strings.Index(body, `gopheragent_session_prompt_tokens_total{session_key="a"}`) b := strings.Index(body, `gopheragent_session_prompt_tokens_total{session_key="b"}`) c := strings.Index(body, `gopheragent_session_prompt_tokens_total{session_key="c"}`) - if !(a < b && b < c) { + if a >= b || b >= c { t.Fatalf("session keys must be sorted; got indices a=%d b=%d c=%d", a, b, c) } } diff --git a/pkg/builder/knowledge_base_test.go b/pkg/builder/knowledge_base_test.go index 6b0f502..7efc4e1 100644 --- a/pkg/builder/knowledge_base_test.go +++ b/pkg/builder/knowledge_base_test.go @@ -58,7 +58,7 @@ func TestLoadKnowledgeBase_DeterministicOrdering(t *testing.T) { idxA := strings.Index(first, "a.md") idxB := strings.Index(first, "b.md") idxC := strings.Index(first, "c.md") - if !(idxA < idxB && idxB < idxC) { + if idxA >= idxB || idxB >= idxC { t.Fatalf("expected alphabetical ordering a,b,c — got offsets a=%d b=%d c=%d", idxA, idxB, idxC) } } @@ -244,9 +244,9 @@ func TestFormatKnowledgeBase_MatchesLoadOutputForEquivalentInput(t *testing.T) { func TestFormatKnowledgeBase_DropsEmptyEntries(t *testing.T) { out := FormatKnowledgeBase([]KBDocument{ - {Path: "", Content: "orphan"}, // no path — drop - {Path: "blank.md", Content: ""}, // no content — drop - {Path: "real.md", Content: "kept"}, // keep + {Path: "", Content: "orphan"}, // no path — drop + {Path: "blank.md", Content: ""}, // no content — drop + {Path: "real.md", Content: "kept"}, // keep }) if !strings.Contains(out, "kept") { t.Fatal("valid entry dropped") @@ -274,7 +274,7 @@ func TestFormatKnowledgeBase_SortsByPath(t *testing.T) { idxA := strings.Index(out, "alpha.md") idxM := strings.Index(out, "mid.md") idxZ := strings.Index(out, "zeta.md") - if !(idxA < idxM && idxM < idxZ) { + if idxA >= idxM || idxM >= idxZ { t.Fatalf("expected alphabetical ordering; offsets a=%d m=%d z=%d", idxA, idxM, idxZ) } } diff --git a/pkg/history/inmem_session_manager.go b/pkg/history/inmem_session_manager.go index 63a8f44..ceb905a 100644 --- a/pkg/history/inmem_session_manager.go +++ b/pkg/history/inmem_session_manager.go @@ -382,4 +382,3 @@ func (m *InMemSessionManager) PurgeDeletedBefore(_ context.Context, before time. } return purged, nil } - diff --git a/pkg/history/inmem_session_manager_test.go b/pkg/history/inmem_session_manager_test.go index a2a7036..e37f4b2 100644 --- a/pkg/history/inmem_session_manager_test.go +++ b/pkg/history/inmem_session_manager_test.go @@ -35,7 +35,7 @@ func TestInMemSessionManager_TTL_Eviction(t *testing.T) { defer cancel() sm := NewInMemSessionManager("sys"). - WithTTL(50 * time.Millisecond). + WithTTL(50*time.Millisecond). StartCleanup(ctx, 20*time.Millisecond) seedSession(sm, "s1", "hi") @@ -61,7 +61,7 @@ func TestInMemSessionManager_TTL_TouchOnRead_PreventsExpiry(t *testing.T) { defer cancel() sm := NewInMemSessionManager("sys"). - WithTTL(80 * time.Millisecond). + WithTTL(80*time.Millisecond). StartCleanup(ctx, 30*time.Millisecond) seedSession(sm, "s1", "hi") @@ -197,7 +197,7 @@ func TestInMemSessionManager_ConcurrentEvictionAndAccess(t *testing.T) { defer cancel() sm := NewInMemSessionManager("sys"). - WithTTL(30 * time.Millisecond). + WithTTL(30*time.Millisecond). StartCleanup(ctx, 10*time.Millisecond) done := make(chan struct{}) diff --git a/pkg/history/query_types.go b/pkg/history/query_types.go index 5d87af1..b248a87 100644 --- a/pkg/history/query_types.go +++ b/pkg/history/query_types.go @@ -11,9 +11,9 @@ import ( // SessionQueryOpts.IncludeDeleted) and Title (empty until a caller // records one via the SessionTitler capability). type SessionMeta struct { - Key string `json:"key"` - UpdatedAt time.Time `json:"updated_at"` - MessageCount int `json:"message_count"` + Key string `json:"key"` + UpdatedAt time.Time `json:"updated_at"` + MessageCount int `json:"message_count"` // Title is the human-readable label attached to the session via // SessionTitler.SetTitle (typically produced by builtin.GenerateTitle // on the first turn). Empty when the backend has none. Backends without diff --git a/pkg/history/session_manager.go b/pkg/history/session_manager.go index 8239f61..b6212c0 100644 --- a/pkg/history/session_manager.go +++ b/pkg/history/session_manager.go @@ -15,14 +15,14 @@ import ( // FileSessionManager provides file-backed session persistence. // Sessions survive server restarts. Each session is stored as a JSON file under storagePath. type FileSessionManager struct { - sessions map[string]*Session - behaviors map[string]string - lastSumLen map[string]int - updatedAt map[string]time.Time - deletedAt map[string]time.Time - titles map[string]string - storage string - mu sync.RWMutex + sessions map[string]*Session + behaviors map[string]string + lastSumLen map[string]int + updatedAt map[string]time.Time + deletedAt map[string]time.Time + titles map[string]string + storage string + mu sync.RWMutex SystemPrompt string SummaryProvider SummaryProvider // if nil, background summarization is disabled // PromptVersion: see InMemSessionManager.PromptVersion. Same semantics. @@ -61,13 +61,13 @@ func NewFileSessionManager(storagePath string, systemPrompt ...string) (*FileSes sp = systemPrompt[0] } return &FileSessionManager{ - sessions: make(map[string]*Session), - behaviors: make(map[string]string), - lastSumLen: make(map[string]int), - updatedAt: make(map[string]time.Time), - deletedAt: make(map[string]time.Time), - titles: make(map[string]string), - storage: storagePath, + sessions: make(map[string]*Session), + behaviors: make(map[string]string), + lastSumLen: make(map[string]int), + updatedAt: make(map[string]time.Time), + deletedAt: make(map[string]time.Time), + titles: make(map[string]string), + storage: storagePath, SystemPrompt: sp, }, nil } diff --git a/pkg/llm/gemini.go b/pkg/llm/gemini.go index a7c6533..1f3488a 100644 --- a/pkg/llm/gemini.go +++ b/pkg/llm/gemini.go @@ -134,7 +134,7 @@ func (p *GeminiProvider) GenerateStream(ctx context.Context, memory []history.Me } iter := p.client.Models.GenerateContentStream(ctx, p.model, contents, config) - + streamChan <- agent.Event(agent.ThoughtEvent{Message: "Analyzing with Gemini..."}) var finalContent string diff --git a/pkg/llm/multimodal_test.go b/pkg/llm/multimodal_test.go index bd7b035..d8a43ee 100644 --- a/pkg/llm/multimodal_test.go +++ b/pkg/llm/multimodal_test.go @@ -47,8 +47,8 @@ func TestOpenAI_BytesBecomeDataURI(t *testing.T) { func TestOpenAI_SkipsEmpty(t *testing.T) { parts := openAIPartsFromMediaParts("", []history.MediaPart{ - {Type: history.PartText, Text: ""}, // empty text — drop - {Type: history.PartImage, MIME: "image/png"}, // no URL, no Data — drop + {Type: history.PartText, Text: ""}, // empty text — drop + {Type: history.PartImage, MIME: "image/png"}, // no URL, no Data — drop history.NewImagePartURL("image/png", "https://x/a"), // keep }) if len(parts) != 1 { diff --git a/pkg/llm/router_test.go b/pkg/llm/router_test.go index 1185d92..539d656 100644 --- a/pkg/llm/router_test.go +++ b/pkg/llm/router_test.go @@ -2,7 +2,6 @@ package llm import ( "context" - "fmt" "testing" "github.com/hung12ct/gopheragent/pkg/agent" @@ -90,8 +89,8 @@ func TestRouterProvider_MultipleRoutes_FirstWins(t *testing.T) { } func TestIfTokensUnder(t *testing.T) { - short := msgs("hi") // ~1 token - long := msgs(fmt.Sprintf("%s", make([]byte, 400))) // ~100 tokens + short := msgs("hi") // ~1 token + long := msgs(string(make([]byte, 400))) // ~100 tokens if !IfTokensUnder(50)(short) { t.Fatal("short message should be under 50 tokens") @@ -126,9 +125,9 @@ func TestIfLastMessageContains(t *testing.T) { // Only the last user message is checked old := []history.Message{ - {Role: "user", Content: "tldr please"}, // old message + {Role: "user", Content: "tldr please"}, // old message {Role: "assistant", Content: "summary"}, - {Role: "user", Content: "now go deeper"}, // last user → no match + {Role: "user", Content: "now go deeper"}, // last user → no match } if cond(old) { t.Fatal("should only check last user message") diff --git a/pkg/tools/builtin/http_request_test.go b/pkg/tools/builtin/http_request_test.go index 4a0c024..e022920 100644 --- a/pkg/tools/builtin/http_request_test.go +++ b/pkg/tools/builtin/http_request_test.go @@ -166,7 +166,9 @@ func TestHTTPRequestTool_AllowedHostBypassesSSRF(t *testing.T) { if err != nil { t.Fatalf("expected success for allowlisted host, got: %v", err) } - var env struct{ Status int `json:"status"` } + var env struct { + Status int `json:"status"` + } _ = json.Unmarshal([]byte(out.Text), &env) if env.Status != 200 { t.Fatalf("status: %d", env.Status) diff --git a/pkg/tools/builtin/sql_agent.go b/pkg/tools/builtin/sql_agent.go index 6bb038d..ebddf5e 100644 --- a/pkg/tools/builtin/sql_agent.go +++ b/pkg/tools/builtin/sql_agent.go @@ -69,25 +69,25 @@ type SQLResult struct { // WithMaxRows, WithQueryTimeout) are chainable and safe to call in any // order; they mutate the receiver and return it. type CallSQLAgentTool struct { - db *sql.DB - schemaRaw string - schema *Schema - examples []SQLExample - businessRules []string - maxRows int - queryTimeout time.Duration - selfConsistency int - sessionManager agent.SessionManager - provider agent.LLMProvider - onSQL func(context.Context, SQLQueryEvent) - name string - display *tools.ToolDisplay - requiresConfirmation bool - allowMutations bool - allowDDL bool - allowSelectStar bool - execSQLRequiresConfirmation bool - providerHint string + db *sql.DB + schemaRaw string + schema *Schema + examples []SQLExample + businessRules []string + maxRows int + queryTimeout time.Duration + selfConsistency int + sessionManager agent.SessionManager + provider agent.LLMProvider + onSQL func(context.Context, SQLQueryEvent) + name string + display *tools.ToolDisplay + requiresConfirmation bool + allowMutations bool + allowDDL bool + allowSelectStar bool + execSQLRequiresConfirmation bool + providerHint string } // NewCallSQLAgentTool initializes a tool capable of querying databases. The diff --git a/pkg/tools/builtin/sql_agent_builders_test.go b/pkg/tools/builtin/sql_agent_builders_test.go index 9cf0a70..8983e74 100644 --- a/pkg/tools/builtin/sql_agent_builders_test.go +++ b/pkg/tools/builtin/sql_agent_builders_test.go @@ -224,7 +224,7 @@ func TestCallSQLAgentTool_WithProviderHintInjectsAddendum(t *testing.T) { contractIdx := strings.Index(prompt, "Safety contract") hintIdx := strings.Index(prompt, "Provider-specific guidance") schemaIdx := strings.Index(prompt, "Schema (use ONLY") - if !(contractIdx < hintIdx && hintIdx < schemaIdx) { + if contractIdx >= hintIdx || hintIdx >= schemaIdx { t.Fatalf("provider hint out of order: contract=%d hint=%d schema=%d", contractIdx, hintIdx, schemaIdx) } } diff --git a/pkg/tools/builtin/sql_consistency.go b/pkg/tools/builtin/sql_consistency.go index ad2cacf..7e5d344 100644 --- a/pkg/tools/builtin/sql_consistency.go +++ b/pkg/tools/builtin/sql_consistency.go @@ -15,7 +15,6 @@ type sqlCandidate struct { index int finalResp string lastResult *SQLResult - err error } // pickByMajority clusters candidates by the hash of their last successful diff --git a/pkg/tools/builtin/sql_validate.go b/pkg/tools/builtin/sql_validate.go index 5e65d4e..c62b0ca 100644 --- a/pkg/tools/builtin/sql_validate.go +++ b/pkg/tools/builtin/sql_validate.go @@ -132,7 +132,7 @@ func StripSQLComments(s string) string { case c == '/' && i+1 < len(s) && s[i+1] == '*': // Block comment → skip to */. j := i + 2 - for j+1 < len(s) && !(s[j] == '*' && s[j+1] == '/') { + for j+1 < len(s) && (s[j] != '*' || s[j+1] != '/') { j++ } if j+1 < len(s) { @@ -158,12 +158,12 @@ func SplitStatements(s string) []string { i := 0 for i < len(s) { c := s[i] - switch { - case c == '\'' || c == '"' || c == '`': + switch c { + case '\'', '"', '`': end := scanQuoted(s, i) cur.WriteString(s[i:end]) i = end - case c == ';': + case ';': out = append(out, cur.String()) cur.Reset() i++ @@ -381,8 +381,9 @@ func matchesWordAt(s string, i int, word string) bool { } // scanQuoted returns the byte offset immediately after the quoted run that -// starts at s[i]. Supports SQL's doubled-quote escape ('' inside '...', -// "" inside "...", `` inside `...`). +// starts at s[i]. Supports SQL's doubled-quote escape, where the delimiter +// is repeated to embed a literal copy of itself inside a single-quoted, +// double-quoted, or backtick-quoted run. func scanQuoted(s string, i int) int { if i >= len(s) { return i diff --git a/pkg/tools/builtin/ssrf.go b/pkg/tools/builtin/ssrf.go index 55539f1..944eaac 100644 --- a/pkg/tools/builtin/ssrf.go +++ b/pkg/tools/builtin/ssrf.go @@ -17,21 +17,21 @@ var privateRanges []*net.IPNet func init() { blocks := []string{ - "127.0.0.0/8", // loopback (IPv4) - "10.0.0.0/8", // RFC 1918 private - "172.16.0.0/12", // RFC 1918 private - "192.168.0.0/16", // RFC 1918 private - "169.254.0.0/16", // link-local; 169.254.169.254 = AWS/Azure/GCP/DO metadata - "0.0.0.0/8", // "this" network - "100.64.0.0/10", // shared address space (RFC 6598) - "192.0.2.0/24", // TEST-NET-1 (RFC 5737) + "127.0.0.0/8", // loopback (IPv4) + "10.0.0.0/8", // RFC 1918 private + "172.16.0.0/12", // RFC 1918 private + "192.168.0.0/16", // RFC 1918 private + "169.254.0.0/16", // link-local; 169.254.169.254 = AWS/Azure/GCP/DO metadata + "0.0.0.0/8", // "this" network + "100.64.0.0/10", // shared address space (RFC 6598) + "192.0.2.0/24", // TEST-NET-1 (RFC 5737) "198.51.100.0/24", // TEST-NET-2 (RFC 5737) - "203.0.113.0/24", // TEST-NET-3 (RFC 5737) - "224.0.0.0/4", // multicast - "240.0.0.0/4", // reserved (RFC 1112) - "::1/128", // IPv6 loopback - "fc00::/7", // IPv6 unique-local (ULA) - "fe80::/10", // IPv6 link-local + "203.0.113.0/24", // TEST-NET-3 (RFC 5737) + "224.0.0.0/4", // multicast + "240.0.0.0/4", // reserved (RFC 1112) + "::1/128", // IPv6 loopback + "fc00::/7", // IPv6 unique-local (ULA) + "fe80::/10", // IPv6 link-local } for _, cidr := range blocks { _, network, err := net.ParseCIDR(cidr) diff --git a/pkg/tools/errors.go b/pkg/tools/errors.go index f0e5248..15cfe5e 100644 --- a/pkg/tools/errors.go +++ b/pkg/tools/errors.go @@ -7,12 +7,12 @@ import "errors" // category: // // - ErrUser — the caller supplied bad input; show the message to -// the user, do not retry. +// the user, do not retry. // - ErrTransient — transient downstream failure (rate limit, 5xx, -// network timeout); retrying later may succeed. +// network timeout); retrying later may succeed. // - ErrPermanent — non-recoverable downstream failure (bad config, -// permanent 4xx from an external API); retries will -// keep failing. +// permanent 4xx from an external API); retries will +// keep failing. // // Wrap pattern: // diff --git a/pkg/tools/errors_test.go b/pkg/tools/errors_test.go index 4ad42db..6c0daca 100644 --- a/pkg/tools/errors_test.go +++ b/pkg/tools/errors_test.go @@ -8,9 +8,9 @@ import ( func TestClassifyError_MatchesEachSentinel(t *testing.T) { cases := []struct { - name string - err error - want error + name string + err error + want error }{ {"user", fmt.Errorf("bad input: %w", ErrUser), ErrUser}, {"transient", fmt.Errorf("429: %w", ErrTransient), ErrTransient}, diff --git a/pkg/tools/toolsfake/fake.go b/pkg/tools/toolsfake/fake.go index 2f45f47..50b245a 100644 --- a/pkg/tools/toolsfake/fake.go +++ b/pkg/tools/toolsfake/fake.go @@ -47,7 +47,6 @@ type Tool struct { allArgs []string } - // NewTool returns a new fake tool with the given Name. Description defaults // to "fake tool", schema defaults to an empty object. Chain the With* // methods to configure further.