From 33d6c44bf02e0cdd4ff43c3885fda3410d998547 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 28 Mar 2026 21:17:21 +0000 Subject: [PATCH 1/2] feat: add sub-agents support Signed-off-by: Ettore Di Giacinto --- README.md | 153 +++++++++++++++ agent.go | 382 ++++++++++++++++++++++++++++++++++++ agent_test.go | 521 ++++++++++++++++++++++++++++++++++++++++++++++++++ options.go | 36 ++++ stream.go | 2 + tools.go | 96 +++++++++- 6 files changed, 1184 insertions(+), 6 deletions(-) create mode 100644 agent.go create mode 100644 agent_test.go diff --git a/README.md b/README.md index c219c3e..ba48ed4 100644 --- a/README.md +++ b/README.md @@ -630,6 +630,159 @@ result, err := cogito.ExecuteTools(llm, fragment, - The sink state tool receives a `reasoning` parameter containing the LLM's reasoning about why no tool is needed - Custom sink state tools must accept a `reasoning` parameter in their arguments +#### Sub-Agent Spawning + +Cogito supports spawning sub-agents via tools, allowing the LLM to delegate tasks to independent child agents. Sub-agents can run in the **foreground** (blocking — waits for result) or **background** (non-blocking — returns an ID immediately so the parent can continue working). + +**Enabling Sub-Agents:** + +```go +result, err := cogito.ExecuteTools(llm, fragment, + cogito.WithTools(searchTool, weatherTool), + cogito.EnableAgentSpawning, // Adds spawn_agent, check_agent, get_agent_result tools + cogito.WithIterations(10), +) +``` + +When enabled, three built-in tools are injected: +- **`spawn_agent`** — Spawns a sub-agent with a task. Set `background: true` for non-blocking execution. +- **`check_agent`** — Checks the status of a background agent by ID. +- **`get_agent_result`** — Retrieves the result of a background agent. Set `wait: true` to block until done. + +**Foreground Agents (Blocking):** + +The LLM calls `spawn_agent` with `background: false`. The sub-agent runs synchronously inside the tool call, and its result is returned as the tool output: + +```go +// The LLM might call: spawn_agent(task="Research quantum computing", background=false) +// → Sub-agent runs ExecuteTools with the same tools (minus agent tools) +// → Sub-agent completes and returns result +// → Parent receives result as tool output and continues +result, err := cogito.ExecuteTools(llm, fragment, + cogito.WithTools(searchTool), + cogito.EnableAgentSpawning, + cogito.WithIterations(5), +) +``` + +**Background Agents (Non-Blocking):** + +The LLM calls `spawn_agent` with `background: true`. The sub-agent launches in a goroutine, and the parent immediately gets back an agent ID. `ExecuteTools` automatically stays alive while background agents are running — when the LLM has no more tools to call, the loop blocks until a background agent finishes. The completion notification is injected into the conversation, and the LLM can react to it: + +```go +// The LLM might call: spawn_agent(task="Analyze data", background=true) +// → Returns "Agent spawned in background with ID: abc-123" +// → Parent continues working on other tasks +// → When sub-agent finishes, parent sees: +// "Background agent abc-123 has completed. Task: Analyze data. Result: ..." +result, err := cogito.ExecuteTools(llm, fragment, + cogito.WithTools(searchTool, analysisTool), + cogito.EnableAgentSpawning, + cogito.WithIterations(10), +) +``` + +**Sharing Agent State Across Calls:** + +Use `WithAgentManager` to share the agent registry across multiple `ExecuteTools` calls, allowing you to track background agents across conversation turns: + +```go +manager := cogito.NewAgentManager() + +// First turn: spawns a background agent +result1, _ := cogito.ExecuteTools(llm, fragment1, + cogito.WithTools(searchTool), + cogito.EnableAgentSpawning, + cogito.WithAgentManager(manager), + cogito.WithIterations(5), +) + +// Second turn: can check on or retrieve results from previously spawned agents +result2, _ := cogito.ExecuteTools(llm, fragment2, + cogito.WithTools(searchTool), + cogito.EnableAgentSpawning, + cogito.WithAgentManager(manager), + cogito.WithIterations(5), +) + +// Programmatic access to all agents +for _, agent := range manager.List() { + fmt.Printf("Agent %s: %s\n", agent.ID, agent.Status) +} +``` + +**Using a Separate LLM for Sub-Agents:** + +```go +mainLLM := clients.NewOpenAILLM("gpt-4", apiKey, baseURL) +subAgentLLM := clients.NewOpenAILLM("gpt-3.5-turbo", apiKey, baseURL) + +result, err := cogito.ExecuteTools(mainLLM, fragment, + cogito.WithTools(searchTool), + cogito.EnableAgentSpawning, + cogito.WithAgentLLM(subAgentLLM), // Sub-agents use a different model + cogito.WithIterations(5), +) +``` + +**Completion Callbacks:** + +Use `WithAgentCompletionCallback` to get programmatic notification when any background sub-agent finishes — useful for UI updates or external monitoring: + +```go +result, err := cogito.ExecuteTools(llm, fragment, + cogito.WithTools(searchTool), + cogito.EnableAgentSpawning, + cogito.WithAgentCompletionCallback(func(agent *cogito.AgentState) { + fmt.Printf("Agent %s finished with status: %s\n", agent.ID, agent.Status) + if agent.Status == cogito.AgentStatusCompleted { + fmt.Println("Result:", agent.Result) + } + }), + cogito.WithIterations(10), +) +``` + +**Tool Filtering:** + +By default, sub-agents receive all parent tools except the agent management tools themselves (preventing unbounded recursion). The LLM can also specify a subset of tools: + +```go +// The LLM can call: spawn_agent(task="Search only", tools=["search"]) +// → Sub-agent only has access to the "search" tool +``` + +To allow sub-agents to also spawn their own sub-agents, the LLM can explicitly include agent tools: +```go +// spawn_agent(task="Complex task", tools=["search", "spawn_agent", "check_agent", "get_agent_result"]) +``` + +**Streaming Sub-Agent Events:** + +When streaming is enabled, sub-agent events are tagged with a `StreamEventSubAgent` type and include the agent's ID: + +```go +result, err := cogito.ExecuteTools(llm, fragment, + cogito.EnableAgentSpawning, + cogito.WithTools(searchTool), + cogito.WithStreamCallback(func(ev cogito.StreamEvent) { + if ev.Type == cogito.StreamEventSubAgent { + fmt.Printf("[Agent %s] %s\n", ev.AgentID, ev.Content) + } else { + fmt.Print(ev.Content) + } + }), + cogito.WithIterations(5), +) +``` + +**When to Use Sub-Agents:** + +- **Use foreground agents** when the parent needs the result before continuing (e.g., research a topic, then summarize) +- **Use background agents** when tasks are independent and the parent can continue working (e.g., start multiple research tasks in parallel) +- **Use `WithAgentManager`** when you need to track agents across multiple conversation turns +- **Use `WithAgentLLM`** when sub-agents should use a cheaper/faster model + #### Field Annotations for Tool Arguments Cogito supports several struct field annotations to control how tool arguments are defined in the generated JSON schema: diff --git a/agent.go b/agent.go new file mode 100644 index 0000000..b9b9ed4 --- /dev/null +++ b/agent.go @@ -0,0 +1,382 @@ +package cogito + +import ( + "context" + "fmt" + "sync" + + "github.com/google/uuid" + "github.com/sashabaranov/go-openai" +) + +// AgentStatusType represents the lifecycle state of a sub-agent. +type AgentStatusType string + +const ( + AgentStatusRunning AgentStatusType = "running" + AgentStatusCompleted AgentStatusType = "completed" + AgentStatusFailed AgentStatusType = "failed" +) + +// agentToolNames are the names of the built-in agent management tools. +var agentToolNames = []string{"spawn_agent", "check_agent", "get_agent_result"} + +// SpawnAgentArgs are the arguments the LLM provides when spawning a sub-agent. +type SpawnAgentArgs struct { + Task string `json:"task" description:"The task or prompt for the sub-agent to execute"` + Background bool `json:"background" description:"If true, the agent runs in the background and returns an ID immediately. If false, blocks until the agent completes."` + Tools []string `json:"tools" description:"Optional subset of tool names available to the sub-agent. If empty, all parent tools (except agent tools) are given."` +} + +// CheckAgentArgs are the arguments for checking a background agent's status. +type CheckAgentArgs struct { + AgentID string `json:"agent_id" description:"The ID of the background agent to check"` +} + +// GetAgentResultArgs are the arguments for retrieving a background agent's result. +type GetAgentResultArgs struct { + AgentID string `json:"agent_id" description:"The ID of the background agent"` + Wait bool `json:"wait" description:"If true, blocks until the agent finishes. If false, returns immediately with current status."` +} + +// AgentState tracks the lifecycle of a single sub-agent. +type AgentState struct { + ID string + Task string + Status AgentStatusType + Result string + Fragment *Fragment + Error error + Cancel context.CancelFunc + done chan struct{} +} + +// AgentManager is a thread-safe registry of background sub-agents. +type AgentManager struct { + mu sync.RWMutex + agents map[string]*AgentState +} + +// NewAgentManager creates a new AgentManager. +func NewAgentManager() *AgentManager { + return &AgentManager{agents: make(map[string]*AgentState)} +} + +// Register adds an agent to the manager. +func (m *AgentManager) Register(agent *AgentState) { + m.mu.Lock() + defer m.mu.Unlock() + m.agents[agent.ID] = agent +} + +// Get retrieves an agent by ID. +func (m *AgentManager) Get(id string) (*AgentState, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + a, ok := m.agents[id] + return a, ok +} + +// List returns all registered agents. +func (m *AgentManager) List() []*AgentState { + m.mu.RLock() + defer m.mu.RUnlock() + result := make([]*AgentState, 0, len(m.agents)) + for _, a := range m.agents { + result = append(result, a) + } + return result +} + +// HasRunning returns true if any registered agent is still running. +func (m *AgentManager) HasRunning() bool { + m.mu.RLock() + defer m.mu.RUnlock() + for _, a := range m.agents { + if a.Status == AgentStatusRunning { + return true + } + } + return false +} + +// Wait blocks until the agent with the given ID completes, then returns it. +func (m *AgentManager) Wait(id string) (*AgentState, error) { + agent, ok := m.Get(id) + if !ok { + return nil, fmt.Errorf("agent %s not found", id) + } + <-agent.done + return agent, nil +} + +// isAgentTool returns true if the tool name is one of the built-in agent tools. +func isAgentTool(name string) bool { + for _, n := range agentToolNames { + if n == name { + return true + } + } + return false +} + +// FilterToolsForSubAgent returns a subset of parent tools suitable for a sub-agent. +// If requestedTools is non-empty, only those named tools are included. +// Agent management tools are excluded by default. +func FilterToolsForSubAgent(parentTools Tools, requestedTools []string) Tools { + if len(requestedTools) > 0 { + var filtered Tools + for _, name := range requestedTools { + if t := parentTools.Find(name); t != nil { + filtered = append(filtered, t) + } + } + return filtered + } + + // All parent tools minus agent tools + var filtered Tools + for _, t := range parentTools { + if !isAgentTool(t.Tool().Function.Name) { + filtered = append(filtered, t) + } + } + return filtered +} + +// SetAgentDone sets the done channel on an AgentState. Used for testing. +func SetAgentDone(a *AgentState, ch chan struct{}) { + a.done = ch +} + +// CheckAgentRunnerForTest exposes the checkAgentRunner for testing. +type CheckAgentRunnerForTest struct { + Manager *AgentManager +} + +func (r *CheckAgentRunnerForTest) Run(args CheckAgentArgs) (string, any, error) { + inner := &checkAgentRunner{manager: r.Manager} + return inner.Run(args) +} + +// GetAgentResultRunnerForTest exposes the getAgentResultRunner for testing. +type GetAgentResultRunnerForTest struct { + Manager *AgentManager + Ctx context.Context +} + +func (r *GetAgentResultRunnerForTest) Run(args GetAgentResultArgs) (string, any, error) { + inner := &getAgentResultRunner{manager: r.Manager, ctx: r.Ctx} + return inner.Run(args) +} + +// spawnAgentRunner implements Tool[SpawnAgentArgs]. +type spawnAgentRunner struct { + llm LLM + parentTools Tools + parentOpts []Option + manager *AgentManager + ctx context.Context + streamCB StreamCallback + messageInjectionChan chan openai.ChatCompletionMessage + agentCompletionCallback func(*AgentState) +} + +func (r *spawnAgentRunner) Run(args SpawnAgentArgs) (string, any, error) { + subTools := FilterToolsForSubAgent(r.parentTools, args.Tools) + + subOpts := append([]Option{}, + WithTools(subTools...), + WithContext(r.ctx), + ) + subOpts = append(subOpts, r.parentOpts...) + + subFragment := NewFragment( + openai.ChatCompletionMessage{Role: "user", Content: args.Task}, + ) + + if !args.Background { + // Foreground: execute synchronously + if r.streamCB != nil { + subOpts = append(subOpts, WithStreamCallback(r.streamCB)) + } + result, err := ExecuteTools(r.llm, subFragment, subOpts...) + if err != nil { + return fmt.Sprintf("Sub-agent failed: %v", err), nil, nil + } + msg := result.LastMessage().Content + return msg, result, nil + } + + // Background: launch goroutine, return ID immediately + agentID := uuid.New().String() + agent := &AgentState{ + ID: agentID, + Task: args.Task, + Status: AgentStatusRunning, + done: make(chan struct{}), + } + r.manager.Register(agent) + + subCtx, cancel := context.WithCancel(r.ctx) + agent.Cancel = cancel + + // Wrap stream callback to tag events with agent ID + if r.streamCB != nil { + parentCB := r.streamCB + subOpts = append(subOpts, WithStreamCallback(func(ev StreamEvent) { + ev.AgentID = agentID + ev.Type = StreamEventSubAgent + parentCB(ev) + })) + } + + // Override context for sub-agent + subOpts = append(subOpts, WithContext(subCtx)) + + go func() { + defer close(agent.done) + defer cancel() + + result, err := ExecuteTools(r.llm, subFragment, subOpts...) + + r.manager.mu.Lock() + if err != nil { + agent.Status = AgentStatusFailed + agent.Error = err + agent.Result = fmt.Sprintf("Failed: %v", err) + } else { + agent.Status = AgentStatusCompleted + agent.Result = result.LastMessage().Content + agent.Fragment = &result + } + r.manager.mu.Unlock() + + // Fire completion callback + if r.agentCompletionCallback != nil { + r.agentCompletionCallback(agent) + } + + // Inject completion notification into parent's loop + if r.messageInjectionChan != nil { + var content string + if agent.Status == AgentStatusCompleted { + content = fmt.Sprintf("Background agent %s has completed.\nTask: %s\nResult: %s", agentID, args.Task, agent.Result) + } else { + content = fmt.Sprintf("Background agent %s has failed.\nTask: %s\nError: %v", agentID, args.Task, agent.Error) + } + select { + case r.messageInjectionChan <- openai.ChatCompletionMessage{ + Role: "user", + Content: content, + }: + default: + // Non-blocking: if the channel is full or closed, skip notification + } + } + }() + + return fmt.Sprintf("Agent spawned in background with ID: %s", agentID), agentID, nil +} + +// checkAgentRunner implements Tool[CheckAgentArgs]. +type checkAgentRunner struct { + manager *AgentManager +} + +func (r *checkAgentRunner) Run(args CheckAgentArgs) (string, any, error) { + agent, ok := r.manager.Get(args.AgentID) + if !ok { + return fmt.Sprintf("Agent %s not found", args.AgentID), nil, nil + } + + switch agent.Status { + case AgentStatusRunning: + return fmt.Sprintf("Agent %s is still running. Task: %s", args.AgentID, agent.Task), agent.Status, nil + case AgentStatusCompleted: + return fmt.Sprintf("Agent %s completed. Task: %s\nResult: %s", args.AgentID, agent.Task, agent.Result), agent.Status, nil + case AgentStatusFailed: + return fmt.Sprintf("Agent %s failed. Task: %s\nError: %v", args.AgentID, agent.Task, agent.Error), agent.Status, nil + default: + return fmt.Sprintf("Agent %s has unknown status: %s", args.AgentID, agent.Status), agent.Status, nil + } +} + +// getAgentResultRunner implements Tool[GetAgentResultArgs]. +type getAgentResultRunner struct { + manager *AgentManager + ctx context.Context +} + +func (r *getAgentResultRunner) Run(args GetAgentResultArgs) (string, any, error) { + agent, ok := r.manager.Get(args.AgentID) + if !ok { + return fmt.Sprintf("Agent %s not found", args.AgentID), nil, nil + } + + if agent.Status == AgentStatusRunning { + if !args.Wait { + return fmt.Sprintf("Agent %s is still running. Use wait=true to block until completion.", args.AgentID), nil, nil + } + // Block until done or context cancelled + select { + case <-agent.done: + case <-r.ctx.Done(): + return fmt.Sprintf("Timed out waiting for agent %s", args.AgentID), nil, r.ctx.Err() + } + } + + if agent.Status == AgentStatusFailed { + return fmt.Sprintf("Agent %s failed: %v", args.AgentID, agent.Error), nil, nil + } + + return agent.Result, agent.Fragment, nil +} + +// newSpawnAgentTool creates the spawn_agent tool definition. +func newSpawnAgentTool( + llm LLM, + parentTools Tools, + manager *AgentManager, + ctx context.Context, + parentOpts []Option, + streamCB StreamCallback, + injectionChan chan openai.ChatCompletionMessage, + completionCB func(*AgentState), +) ToolDefinitionInterface { + return NewToolDefinition( + &spawnAgentRunner{ + llm: llm, + parentTools: parentTools, + parentOpts: parentOpts, + manager: manager, + ctx: ctx, + streamCB: streamCB, + messageInjectionChan: injectionChan, + agentCompletionCallback: completionCB, + }, + SpawnAgentArgs{}, + "spawn_agent", + "Spawn a sub-agent to handle a task. Use background=true for non-blocking execution, or background=false to wait for the result.", + ) +} + +// newCheckAgentTool creates the check_agent tool definition. +func newCheckAgentTool(manager *AgentManager) ToolDefinitionInterface { + return NewToolDefinition( + &checkAgentRunner{manager: manager}, + CheckAgentArgs{}, + "check_agent", + "Check the status of a background sub-agent by its ID.", + ) +} + +// newGetAgentResultTool creates the get_agent_result tool definition. +func newGetAgentResultTool(manager *AgentManager, ctx context.Context) ToolDefinitionInterface { + return NewToolDefinition( + &getAgentResultRunner{manager: manager, ctx: ctx}, + GetAgentResultArgs{}, + "get_agent_result", + "Get the result of a background sub-agent. Set wait=true to block until the agent finishes.", + ) +} diff --git a/agent_test.go b/agent_test.go new file mode 100644 index 0000000..2d357c0 --- /dev/null +++ b/agent_test.go @@ -0,0 +1,521 @@ +package cogito_test + +import ( + "context" + "sync" + "time" + + . "github.com/mudler/cogito" + "github.com/mudler/cogito/tests/mock" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/sashabaranov/go-openai" +) + +// slowToolRunner blocks until the ready channel is closed, simulating a slow tool. +type SlowToolArgs struct { + Query string `json:"query"` +} + +type slowToolRunner struct { + ready chan struct{} +} + +func (s *slowToolRunner) Run(args SlowToolArgs) (string, any, error) { + <-s.ready // Block until released + return "Slow search result for: " + args.Query, nil, nil +} + +var _ = Describe("Sub-Agent Spawning", func() { + var mockLLM *mock.MockOpenAIClient + + BeforeEach(func() { + mockLLM = mock.NewMockOpenAIClient() + }) + + Context("AgentManager", func() { + It("should register and retrieve agents", func() { + m := NewAgentManager() + agent := &AgentState{ + ID: "test-1", + Task: "test task", + Status: AgentStatusRunning, + } + m.Register(agent) + + got, ok := m.Get("test-1") + Expect(ok).To(BeTrue()) + Expect(got.Task).To(Equal("test task")) + + _, ok = m.Get("nonexistent") + Expect(ok).To(BeFalse()) + }) + + It("should list all agents", func() { + m := NewAgentManager() + m.Register(&AgentState{ID: "a1", Task: "task1", Status: AgentStatusRunning}) + m.Register(&AgentState{ID: "a2", Task: "task2", Status: AgentStatusCompleted}) + + agents := m.List() + Expect(agents).To(HaveLen(2)) + }) + + It("should wait for agent completion", func() { + m := NewAgentManager() + done := make(chan struct{}) + agent := &AgentState{ + ID: "wait-test", + Task: "waiting task", + Status: AgentStatusRunning, + } + // Use exported done channel pattern: set it manually for test + SetAgentDone(agent, done) + m.Register(agent) + + go func() { + time.Sleep(50 * time.Millisecond) + agent.Status = AgentStatusCompleted + agent.Result = "done" + close(done) + }() + + result, err := m.Wait("wait-test") + Expect(err).ToNot(HaveOccurred()) + Expect(result.Status).To(Equal(AgentStatusCompleted)) + }) + + It("should return error when waiting for nonexistent agent", func() { + m := NewAgentManager() + _, err := m.Wait("nonexistent") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("Foreground agent spawning", func() { + It("should execute sub-agent synchronously and return result", func() { + mockTool := mock.NewMockTool("search", "Search for information") + + // 1. Parent iteration 1: LLM selects spawn_agent tool + mockLLM.AddCreateChatCompletionFunction("spawn_agent", + `{"task": "Search for photosynthesis", "background": false}`) + + // --- Sub-agent starts (synchronous, consumes from same mock) --- + // 2. Sub-agent iteration 1: LLM selects search tool + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "photosynthesis"}`) + mock.SetRunResult(mockTool, "Photosynthesis converts sunlight to energy.") + + // 3. Sub-agent iteration 2: no more tools (sink state) + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Sub-agent done.", + }, + }}, + }) + + // 4. Sub-agent: Ask for final response after sink state + mockLLM.SetAskResponse("Photosynthesis is how plants convert sunlight into chemical energy.") + // --- Sub-agent ends, result returned to parent as tool output --- + + // 5. Parent iteration 2: no more tools (sink state) + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Parent done.", + }, + }}, + }) + + // 6. Parent: Ask for final response after sink state + mockLLM.SetAskResponse("The sub-agent found that photosynthesis converts sunlight to energy.") + + fragment := NewEmptyFragment().AddMessage(UserMessageRole, "Find info about photosynthesis") + + result, err := ExecuteTools(mockLLM, fragment, + WithTools(mockTool), + EnableAgentSpawning, + WithIterations(3), + ) + + Expect(err).ToNot(HaveOccurred()) + // The result should contain the parent's final response + Expect(result.LastMessage().Content).ToNot(BeEmpty()) + // Verify a spawn_agent tool was called + hasSpawnAgent := false + for _, t := range result.Status.ToolsCalled { + if t.Tool().Function.Name == "spawn_agent" { + hasSpawnAgent = true + break + } + } + Expect(hasSpawnAgent).To(BeTrue()) + }) + }) + + Context("Background agent spawning", func() { + It("should spawn agent in background and return ID", func() { + mockTool := mock.NewMockTool("search", "Search for information") + + // Parent: LLM selects spawn_agent with background=true + mockLLM.AddCreateChatCompletionFunction("spawn_agent", + `{"task": "Background task", "background": true}`) + + // Sub-agent (in goroutine): LLM selects search tool + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "background"}`) + mock.SetRunResult(mockTool, "Background result.") + + // Sub-agent: no more tools + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Done.", + }, + }}, + }) + + // Sub-agent: final ask response + mockLLM.SetAskResponse("Background task completed.") + + // Parent: after spawn returns ID, next iteration sees completion notification + // Then LLM responds with no more tools + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Agent started.", + }, + }}, + }) + + // Parent: final ask + mockLLM.SetAskResponse("Started a background agent to handle the task.") + + fragment := NewEmptyFragment().AddMessage(UserMessageRole, "Run a background task") + + manager := NewAgentManager() + result, err := ExecuteTools(mockLLM, fragment, + WithTools(mockTool), + EnableAgentSpawning, + WithAgentManager(manager), + WithIterations(5), + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(result.LastMessage().Content).ToNot(BeEmpty()) + + // Wait for background agent to complete + Eventually(func() int { + return len(manager.List()) + }, 2*time.Second, 50*time.Millisecond).Should(BeNumerically(">=", 1)) + + agents := manager.List() + if len(agents) > 0 { + // Wait for it to finish + Eventually(func() AgentStatusType { + a, _ := manager.Get(agents[0].ID) + return a.Status + }, 5*time.Second, 50*time.Millisecond).Should(Or(Equal(AgentStatusCompleted), Equal(AgentStatusFailed))) + } + }) + }) + + Context("check_agent tool", func() { + It("should return status for a known agent", func() { + manager := NewAgentManager() + manager.Register(&AgentState{ + ID: "test-check", + Task: "some task", + Status: AgentStatusCompleted, + Result: "task done", + }) + + runner := &CheckAgentRunnerForTest{Manager: manager} + result, _, err := runner.Run(CheckAgentArgs{AgentID: "test-check"}) + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(ContainSubstring("completed")) + Expect(result).To(ContainSubstring("task done")) + }) + + It("should return not found for unknown agent", func() { + manager := NewAgentManager() + runner := &CheckAgentRunnerForTest{Manager: manager} + result, _, err := runner.Run(CheckAgentArgs{AgentID: "unknown"}) + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(ContainSubstring("not found")) + }) + }) + + Context("get_agent_result tool", func() { + It("should return result for completed agent", func() { + manager := NewAgentManager() + done := make(chan struct{}) + close(done) // already done + agent := &AgentState{ + ID: "result-test", + Task: "result task", + Status: AgentStatusCompleted, + Result: "the final result", + } + SetAgentDone(agent, done) + manager.Register(agent) + + runner := &GetAgentResultRunnerForTest{Manager: manager, Ctx: context.Background()} + result, _, err := runner.Run(GetAgentResultArgs{AgentID: "result-test", Wait: false}) + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(Equal("the final result")) + }) + + It("should block with wait=true until agent completes", func() { + manager := NewAgentManager() + done := make(chan struct{}) + agent := &AgentState{ + ID: "wait-result", + Task: "waiting", + Status: AgentStatusRunning, + } + SetAgentDone(agent, done) + manager.Register(agent) + + var result string + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + runner := &GetAgentResultRunnerForTest{Manager: manager, Ctx: context.Background()} + result, _, _ = runner.Run(GetAgentResultArgs{AgentID: "wait-result", Wait: true}) + }() + + time.Sleep(50 * time.Millisecond) + agent.Status = AgentStatusCompleted + agent.Result = "waited result" + close(done) + + wg.Wait() + Expect(result).To(Equal("waited result")) + }) + + It("should return status when not waiting for running agent", func() { + manager := NewAgentManager() + done := make(chan struct{}) + agent := &AgentState{ + ID: "no-wait", + Task: "running", + Status: AgentStatusRunning, + } + SetAgentDone(agent, done) + manager.Register(agent) + + runner := &GetAgentResultRunnerForTest{Manager: manager, Ctx: context.Background()} + result, _, err := runner.Run(GetAgentResultArgs{AgentID: "no-wait", Wait: false}) + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(ContainSubstring("still running")) + }) + }) + + Context("Completion callback", func() { + It("should fire callback when background agent finishes", func() { + mockTool := mock.NewMockTool("search", "Search for information") + + // Parent: LLM selects spawn_agent with background=true + mockLLM.AddCreateChatCompletionFunction("spawn_agent", + `{"task": "Callback test", "background": true}`) + + // Sub-agent: LLM selects search + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "test"}`) + mock.SetRunResult(mockTool, "Callback result.") + + // Sub-agent: no more tools + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Done.", + }, + }}, + }) + + // Sub-agent: final ask + mockLLM.SetAskResponse("Callback task completed.") + + // Parent: after spawn + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Spawned.", + }, + }}, + }) + + // Parent: final ask + mockLLM.SetAskResponse("Done spawning.") + + var callbackAgent *AgentState + var callbackMu sync.Mutex + + fragment := NewEmptyFragment().AddMessage(UserMessageRole, "Run callback test") + + manager := NewAgentManager() + _, _ = ExecuteTools(mockLLM, fragment, + WithTools(mockTool), + EnableAgentSpawning, + WithAgentManager(manager), + WithAgentCompletionCallback(func(a *AgentState) { + callbackMu.Lock() + callbackAgent = a + callbackMu.Unlock() + }), + WithIterations(5), + ) + + // Wait for background agent to finish and callback to fire + Eventually(func() bool { + callbackMu.Lock() + defer callbackMu.Unlock() + return callbackAgent != nil + }, 5*time.Second, 50*time.Millisecond).Should(BeTrue()) + + callbackMu.Lock() + Expect(callbackAgent.Status).To(Or(Equal(AgentStatusCompleted), Equal(AgentStatusFailed))) + callbackMu.Unlock() + }) + }) + + Context("Tool filtering", func() { + It("should exclude agent tools from sub-agents by default", func() { + parentTools := Tools{} + searchTool := mock.NewMockTool("search", "Search") + spawnTool := mock.NewMockTool("spawn_agent", "Spawn agent") + checkTool := mock.NewMockTool("check_agent", "Check agent") + parentTools = append(parentTools, searchTool, spawnTool, checkTool) + + filtered := FilterToolsForSubAgent(parentTools, nil) + Expect(filtered).To(HaveLen(1)) + Expect(filtered[0].Tool().Function.Name).To(Equal("search")) + }) + + It("should filter to requested tools only", func() { + parentTools := Tools{} + searchTool := mock.NewMockTool("search", "Search") + weatherTool := mock.NewMockTool("weather", "Weather") + parentTools = append(parentTools, searchTool, weatherTool) + + filtered := FilterToolsForSubAgent(parentTools, []string{"weather"}) + Expect(filtered).To(HaveLen(1)) + Expect(filtered[0].Tool().Function.Name).To(Equal("weather")) + }) + }) + + Context("Loop stays alive for background agents", func() { + It("should keep ExecuteTools alive until background agents complete", func() { + // Use a separate mock for the sub-agent to avoid response ordering issues + subAgentMockLLM := mock.NewMockOpenAIClient() + + // A slow tool that blocks until we release it — simulates a long-running sub-agent + slowToolReady := make(chan struct{}) + slowTool := NewToolDefinition( + &slowToolRunner{ready: slowToolReady}, + SlowToolArgs{}, + "slow_search", + "A slow search tool", + ) + + // === Parent mock responses === + // 1. Parent: LLM selects spawn_agent with background=true + mockLLM.AddCreateChatCompletionFunction("spawn_agent", + `{"task": "Background research", "background": true, "tools": ["slow_search"]}`) + + // 2. Parent iteration 2: LLM replies with text (noTool). + // Background agent still running → blocks on injection channel. + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Waiting for background agent.", + }, + }}, + }) + + // 3. Parent iteration 3: after completion message injected from blocking wait, + // LLM sees result and replies (sink state / no tool) + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Got the background result, all done.", + }, + }}, + }) + + // 4. Parent: final ask after sink state (noTool with reasoning) + // Not needed since noTool with reasoning returns f directly + + // === Sub-agent mock responses (separate LLM) === + subAgentMockLLM.AddCreateChatCompletionFunction("slow_search", `{"query": "research"}`) + + subAgentMockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Sub-agent done.", + }, + }}, + }) + + subAgentMockLLM.SetAskResponse("Quantum computing is advancing rapidly.") + + fragment := NewEmptyFragment().AddMessage(UserMessageRole, "Research quantum computing in background") + + manager := NewAgentManager() + + // Release the slow tool after a short delay to ensure the parent loop + // has time to cycle through "waiting" iterations + go func() { + time.Sleep(100 * time.Millisecond) + close(slowToolReady) + }() + + result, err := ExecuteTools(mockLLM, fragment, + WithTools(slowTool), + EnableAgentSpawning, + WithAgentManager(manager), + WithAgentLLM(subAgentMockLLM), + WithIterations(20), + ) + + Expect(err).ToNot(HaveOccurred()) + + // Verify the background agent completed + agents := manager.List() + Expect(len(agents)).To(BeNumerically(">=", 1)) + for _, a := range agents { + Expect(a.Status).To(Equal(AgentStatusCompleted)) + } + + // Verify the parent processed the completion (has injected messages) + Expect(len(result.Status.InjectedMessages)).To(BeNumerically(">=", 1)) + }) + }) + + Context("Context cancellation", func() { + It("should cancel sub-agents when parent context is cancelled", func() { + ctx, cancel := context.WithCancel(context.Background()) + + // Cancel immediately + cancel() + + fragment := NewEmptyFragment().AddMessage(UserMessageRole, "test") + + _, err := ExecuteTools(mockLLM, fragment, + EnableAgentSpawning, + WithContext(ctx), + WithIterations(1), + ) + + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/options.go b/options.go index 1ec0562..f332f0d 100644 --- a/options.go +++ b/options.go @@ -74,6 +74,12 @@ type Options struct { // AutoImprove options autoImproveState *AutoImproveState autoImproveReviewerLLM LLM + + // Sub-agent spawning options + enableAgentSpawning bool + agentManager *AgentManager + agentLLM LLM + agentCompletionCallback func(*AgentState) } type Option func(*Options) @@ -423,6 +429,36 @@ func WithAutoImproveReviewerLLM(llm LLM) Option { } } +// EnableAgentSpawning enables sub-agent spawning tools (spawn_agent, check_agent, get_agent_result). +// When enabled, the LLM can delegate tasks to sub-agents that run in foreground (blocking) or background (non-blocking). +var EnableAgentSpawning Option = func(o *Options) { + o.enableAgentSpawning = true +} + +// WithAgentManager provides an existing AgentManager for sharing across multiple ExecuteTools calls. +// If not provided and EnableAgentSpawning is set, a new AgentManager is created automatically. +func WithAgentManager(m *AgentManager) Option { + return func(o *Options) { + o.agentManager = m + } +} + +// WithAgentLLM sets a separate LLM for sub-agents to use. +// If not set, sub-agents share the parent's LLM. +func WithAgentLLM(llm LLM) Option { + return func(o *Options) { + o.agentLLM = llm + } +} + +// WithAgentCompletionCallback sets a callback that fires when any background sub-agent finishes. +// Useful for external monitoring or UI updates outside the LLM loop. +func WithAgentCompletionCallback(fn func(*AgentState)) Option { + return func(o *Options) { + o.agentCompletionCallback = fn + } +} + type defaultSinkStateTool struct{} func (d *defaultSinkStateTool) Execute(args map[string]any) (string, any, error) { diff --git a/stream.go b/stream.go index 813f03b..4fe9ea9 100644 --- a/stream.go +++ b/stream.go @@ -11,6 +11,7 @@ const ( StreamEventStatus StreamEventType = "status" // status message StreamEventDone StreamEventType = "done" // stream complete StreamEventError StreamEventType = "error" // error + StreamEventSubAgent StreamEventType = "sub_agent" // sub-agent event ) // StreamEvent represents a single streaming event from the LLM or tool pipeline. @@ -25,6 +26,7 @@ type StreamEvent struct { FinishReason string // "stop", "tool_calls", etc. (populated on done) Error error // populated on error Usage LLMUsage // populated on done + AgentID string // populated for sub-agent events } // StreamCallback is a function that receives streaming events. diff --git a/tools.go b/tools.go index e137365..bd26ca0 100644 --- a/tools.go +++ b/tools.go @@ -1077,6 +1077,44 @@ func ExecuteTools(llm LLM, f Fragment, opts ...Option) (Fragment, error) { return f, fmt.Errorf("force reasoning is enabled but sink state is not enabled") } + // Inject sub-agent tools if agent spawning is enabled + if o.enableAgentSpawning { + if o.agentManager == nil { + o.agentManager = NewAgentManager() + } + agentLLM := llm + if o.agentLLM != nil { + agentLLM = o.agentLLM + } + + // Auto-create injection channel for background completion notifications + if o.messageInjectionChan == nil { + o.messageInjectionChan = make(chan openai.ChatCompletionMessage, 16) + } + + // Collect parent options that should propagate to sub-agents (exclude agent-specific ones) + var subAgentOpts []Option + if o.maxIterations > 0 { + subAgentOpts = append(subAgentOpts, WithIterations(o.maxIterations)) + } + if o.maxAttempts > 0 { + subAgentOpts = append(subAgentOpts, WithMaxAttempts(o.maxAttempts)) + } + if o.maxRetries > 0 { + subAgentOpts = append(subAgentOpts, WithMaxRetries(o.maxRetries)) + } + + agentTools := []ToolDefinitionInterface{ + newSpawnAgentTool(agentLLM, o.tools, o.agentManager, o.context, subAgentOpts, o.streamCallback, o.messageInjectionChan, o.agentCompletionCallback), + newCheckAgentTool(o.agentManager), + newGetAgentResultTool(o.agentManager, o.context), + } + + // Append agent tools to both o.tools (for this call) and opts (so usableTools sees them) + o.tools = append(o.tools, agentTools...) + opts = append(opts, WithTools(agentTools...)) + } + // should I plan? if o.autoPlan { xlog.Debug("Checking if planning is needed") @@ -1293,15 +1331,35 @@ TOOL_LOOP: // The LLM replied with text instead of calling a tool - this is // equivalent to selecting the sink state (reply). f = f.AddMessage(AssistantMessageRole, reasoning) - // AutoImprove: run review step before returning - if o.autoImproveState != nil { - executeAutoImproveReview(llm, f, o.autoImproveState, o) - } - return f, nil } - if o.statusCallback != nil { + if o.statusCallback != nil && reasoning == "" { o.statusCallback("No tool was selected") } + // If background agents are still running, block until a completion message arrives + if o.agentManager != nil && o.agentManager.HasRunning() { + xlog.Debug("No tool selected but background agents still running, blocking for completions") + select { + case <-o.context.Done(): + return f, o.context.Err() + case msg, ok := <-o.messageInjectionChan: + if ok { + position := len(f.Messages) + f = f.AddMessage(MessageRole(msg.Role), msg.Content) + xlog.Debug("Injected background completion message", "position", position) + if o.messageInjectionResultChan != nil { + select { + case o.messageInjectionResultChan <- MessageInjectionResult{Count: 1, Position: position}: + default: + } + } + f.Status.InjectedMessages = append(f.Status.InjectedMessages, InjectedMessage{ + Message: msg, + Iteration: totalIterations, + }) + } + } + continue TOOL_LOOP + } // AutoImprove: run review step before returning if o.autoImproveState != nil { executeAutoImproveReview(llm, f, o.autoImproveState, o) @@ -1385,6 +1443,32 @@ TOOL_LOOP: // If no tools to execute and sink state was found, stop here if len(toolsToExecute) == 0 && hasSinkState { + // If background agents are still running, block until a completion message arrives + if o.agentManager != nil && o.agentManager.HasRunning() { + xlog.Debug("Sink state selected but background agents still running, blocking for completions") + hasSinkState = false // Reset so we re-enter the loop + select { + case <-o.context.Done(): + return f, o.context.Err() + case msg, ok := <-o.messageInjectionChan: + if ok { + position := len(f.Messages) + f = f.AddMessage(MessageRole(msg.Role), msg.Content) + xlog.Debug("Injected background completion message", "position", position) + if o.messageInjectionResultChan != nil { + select { + case o.messageInjectionResultChan <- MessageInjectionResult{Count: 1, Position: position}: + default: + } + } + f.Status.InjectedMessages = append(f.Status.InjectedMessages, InjectedMessage{ + Message: msg, + Iteration: totalIterations, + }) + } + } + continue TOOL_LOOP + } xlog.Debug("Only sink state selected, stopping execution") break } From 43a247d6a1099387ba887f5f46bfeed8d229f3a7 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 31 Mar 2026 16:18:30 +0000 Subject: [PATCH 2/2] add example Signed-off-by: Ettore Di Giacinto --- examples/sub-agents/main.go | 72 +++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 examples/sub-agents/main.go diff --git a/examples/sub-agents/main.go b/examples/sub-agents/main.go new file mode 100644 index 0000000..ad026e6 --- /dev/null +++ b/examples/sub-agents/main.go @@ -0,0 +1,72 @@ +package main + +import ( + "bufio" + "errors" + "fmt" + "os" + "strings" + + "github.com/mudler/cogito" + "github.com/mudler/cogito/clients" + "github.com/mudler/cogito/examples/internal/search" +) + +func main() { + model := os.Getenv("MODEL") + apiKey := os.Getenv("API_KEY") + baseURL := os.Getenv("BASE_URL") + + llm := clients.NewLocalAILLM(model, apiKey, baseURL) + + searchTool := cogito.NewToolDefinition( + &search.SearchTool{}, + search.SearchArgs{}, + "search", + "A search engine to find information about a topic", + ) + + // Share the agent manager across conversation turns so background agents + // spawned in one turn can be checked or retrieved in the next. + manager := cogito.NewAgentManager() + + f := cogito.NewEmptyFragment() + for { + reader := bufio.NewReader(os.Stdin) + fmt.Print("> ") + text, _ := reader.ReadString('\n') + text = strings.TrimSpace(text) + if text == "" { + continue + } + fmt.Println(text) + + f = f.AddMessage("user", text) + var err error + f, err = cogito.ExecuteTools( + llm, f, + cogito.WithTools(searchTool), + cogito.EnableAgentSpawning, + cogito.WithAgentManager(manager), + cogito.WithIterations(10), + cogito.WithMaxRetries(5), + cogito.DisableSinkState, + cogito.WithAgentCompletionCallback(func(a *cogito.AgentState) { + fmt.Printf("\n[agent %s] finished (%s)\n", a.ID[:8], a.Status) + }), + cogito.WithStreamCallback(func(ev cogito.StreamEvent) { + if ev.Type == cogito.StreamEventSubAgent { + fmt.Printf("[sub-agent %s] %s", ev.AgentID[:8], ev.Content) + } else { + fmt.Print(ev.Content) + } + }), + ) + if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + continue + } + + fmt.Println(f.LastMessage().Content) + } +}