From ff26466f336a888860da3ae4650f504414fbfc43 Mon Sep 17 00:00:00 2001 From: Kevin Castro Date: Sun, 29 Mar 2026 13:46:58 -0500 Subject: [PATCH] Bumped test coverage to 97.6% --- agent_internal_test.go | 73 ++++++ handoff_internal_test.go | 13 + mcp/mcp_transport_test.go | 43 +++- multimodal_internal_test.go | 18 ++ output_internal_test.go | 36 +++ provider/anthropic/anthropic_test.go | 54 ++++ provider/bedrock/bedrock_internal_test.go | 14 + provider/ollama/ollama_test.go | 12 + provider/openai/openai_internal_test.go | 16 ++ provider/openai/openai_transport_test.go | 130 ++++++++++ provider/openai/responses_internal_test.go | 139 ++++++++++ stream_internal_test.go | 284 +++++++++++++++++++++ tool/auto_test.go | 17 ++ tool/builder_test.go | 30 +++ tool/deferred_test.go | 33 +++ typed_agent_internal_test.go | 153 +++++++++++ validator_internal_test.go | 53 ++++ 17 files changed, 1117 insertions(+), 1 deletion(-) create mode 100644 agent_internal_test.go create mode 100644 multimodal_internal_test.go diff --git a/agent_internal_test.go b/agent_internal_test.go new file mode 100644 index 0000000..e74bf55 --- /dev/null +++ b/agent_internal_test.go @@ -0,0 +1,73 @@ +package agentic + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/regularkevvv/agentic-go/internal/testutil" + testprovider "github.com/regularkevvv/agentic-go/provider/test" +) + +func TestNewAgentDynamicRegistersHandoffs(t *testing.T) { + child := NewAgent[handoffDeps]("child", testprovider.NewTestModel(testprovider.ModelResponse{Text: "ok"})) + h := NewHandoff("delegate", "delegate work", child) + + agent := NewAgentDynamic[handoffDeps]( + func(ctx RunContext[handoffDeps]) (string, error) { return "dynamic", nil }, + &testutil.StubModel{NameValue: "dynamic-model"}, + WithHandoffs(h), + ) + + if agent.registry == nil || !agent.registry.Has("delegate") { + t.Fatalf("expected dynamic agent to register handoff, got %#v", agent.registry) + } +} + +func TestAgentRunAdditionalErrorPaths(t *testing.T) { + t.Run("no choices in response", func(t *testing.T) { + agent := NewAgent[any]("system", &testutil.StubModel{ + NameValue: "empty-model", + Response: &ChatResponse{}, + }) + + _, err := agent.Run(context.Background(), "prompt", nil) + if err == nil || err.Error() != "no choices in response" { + t.Fatalf("expected no choices error, got %v", err) + } + }) + + t.Run("tool call without registered tools", func(t *testing.T) { + agent := NewAgent[any]("system", &testutil.StubModel{ + NameValue: "tool-model", + Response: &ChatResponse{ + Choices: []Choice{{ + Message: NewToolUseMessage(ToolUse{ + ID: "call_1", + Name: "lookup", + Input: map[string]interface{}{"city": "Lima"}, + }), + FinishReason: FinishReasonToolCalls, + }}, + }, + }) + + _, err := agent.Run(context.Background(), "prompt", nil) + if err == nil || !strings.Contains(err.Error(), "no tools are registered") { + t.Fatalf("expected missing registry error, got %v", err) + } + }) + + t.Run("model error is wrapped", func(t *testing.T) { + agent := NewAgent[any]("system", &testutil.StubModel{ + NameValue: "error-model", + Err: errors.New("boom"), + }) + + _, err := agent.Run(context.Background(), "prompt", nil) + if err == nil || err.Error() != "model request: boom" { + t.Fatalf("expected wrapped model error, got %v", err) + } + }) +} diff --git a/handoff_internal_test.go b/handoff_internal_test.go index 98e6faf..b30d48d 100644 --- a/handoff_internal_test.go +++ b/handoff_internal_test.go @@ -167,3 +167,16 @@ func TestAddHandoffAndWithHandoffs(t *testing.T) { t.Fatalf("expected handoff to be registered via option") } } + +func TestAddHandoffPanicsOnInvalidToolDefinition(t *testing.T) { + parent := NewAgent[handoffDeps]("parent", testprovider.NewTestModel(testprovider.ModelResponse{Text: "ok"})) + child := NewAgent[handoffDeps]("child", testprovider.NewTestModel(testprovider.ModelResponse{Text: "ok"})) + + defer func() { + if r := recover(); r == nil { + t.Fatal("expected AddHandoff to panic for an invalid handoff tool") + } + }() + + parent.AddHandoff(NewHandoff("", "delegate work", child)) +} diff --git a/mcp/mcp_transport_test.go b/mcp/mcp_transport_test.go index 5c4c534..f16ddb0 100644 --- a/mcp/mcp_transport_test.go +++ b/mcp/mcp_transport_test.go @@ -39,7 +39,7 @@ func TestClientConnectListToolsCallToolAndToolsetOverSSE(t *testing.T) { sseServer := server.NewTestServer(mcpServer) defer sseServer.Close() - client := NewSSEClient("remote-tools", sseServer.URL+"/sse") + client := NewSSEClient("remote-tools", sseServer.URL+"/sse", WithHeaders(map[string]string{"X-Test": "1"})) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -147,3 +147,44 @@ func TestClientConnectOverHTTPAndStdioFailure(t *testing.T) { t.Fatal("expected stdio connect to fail for a missing command") } } + +func TestClientListToolsPaginatesOverHTTP(t *testing.T) { + mcpServer := server.NewMCPServer( + "http-server", + "1.0.0", + server.WithToolCapabilities(true), + server.WithPaginationLimit(1), + ) + for _, name := range []string{"alpha", "beta", "gamma"} { + mcpServer.AddTool( + mcpgo.NewTool(name), + func(ctx context.Context, request mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { + return mcpgo.NewToolResultText("ok"), nil + }, + ) + } + + httpHandler := server.NewStreamableHTTPServer(mcpServer) + httpServer := httptest.NewServer(httpHandler) + defer httpServer.Close() + + client := NewHTTPClient("remote-http", httpServer.URL+"/mcp", WithHeaders(map[string]string{"X-Test": "1"})) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect: %v", err) + } + + tools, err := client.listTools(ctx) + if err != nil { + t.Fatalf("listTools: %v", err) + } + if len(tools) != 3 { + t.Fatalf("expected all paginated tools, got %#v", tools) + } + + if err := client.Close(); err != nil { + t.Fatalf("Close: %v", err) + } +} diff --git a/multimodal_internal_test.go b/multimodal_internal_test.go new file mode 100644 index 0000000..814973e --- /dev/null +++ b/multimodal_internal_test.go @@ -0,0 +1,18 @@ +package agentic + +import "testing" + +func TestInferMediaTypeAdditionalExtensions(t *testing.T) { + tests := map[string]string{ + "poster.gif": "image/gif", + "texture.webp": "image/webp", + "photo.jpg": "image/jpeg", + "photo.jpeg": "image/jpeg", + } + + for path, want := range tests { + if got := inferMediaType(path); got != want { + t.Fatalf("inferMediaType(%q) = %q, want %q", path, got, want) + } + } +} diff --git a/output_internal_test.go b/output_internal_test.go index 8342171..33073d9 100644 --- a/output_internal_test.go +++ b/output_internal_test.go @@ -150,4 +150,40 @@ func TestToolOutputSpecErrors(t *testing.T) { t.Fatalf("expected invalid JSON error, got %v", err) } }) + + t.Run("returns marshal error for non-json tool input", func(t *testing.T) { + spec := NewToolOutput[outputCoverageValue]("desc") + msg := NewToolUseMessage(ToolUse{ + ID: "call_1", + Name: "__output__", + Input: map[string]interface{}{ + "value": func() {}, + }, + }) + + _, err := spec.Parse(msg) + if err == nil || !strings.Contains(err.Error(), "marshal output tool input") { + t.Fatalf("expected marshal error, got %v", err) + } + }) + + t.Run("returns unmarshal error for wrong tool input shape", func(t *testing.T) { + type numericOutput struct { + Value int `json:"value"` + } + + spec := NewToolOutput[numericOutput]("desc") + msg := NewToolUseMessage(ToolUse{ + ID: "call_1", + Name: "__output__", + Input: map[string]interface{}{ + "value": "wrong", + }, + }) + + _, err := spec.Parse(msg) + if err == nil || !strings.Contains(err.Error(), "unmarshal to") { + t.Fatalf("expected unmarshal error, got %v", err) + } + }) } diff --git a/provider/anthropic/anthropic_test.go b/provider/anthropic/anthropic_test.go index 1165582..bed21aa 100644 --- a/provider/anthropic/anthropic_test.go +++ b/provider/anthropic/anthropic_test.go @@ -2,6 +2,8 @@ package anthropic import ( "context" + "net/http" + "net/http/httptest" "testing" "github.com/anthropics/anthropic-sdk-go" @@ -300,6 +302,36 @@ func TestBuildParamsDefaults(t *testing.T) { } } +func TestBuildParamsThinkingAndResponseFormat(t *testing.T) { + model, _ := New("claude-sonnet-4-20250514", WithAPIKey("test-key")) + temp := 0.2 + + params := model.buildParams(&core.ChatRequest{ + Model: "claude-sonnet-4-20250514", + Messages: []core.Message{ + core.NewTextMessage(core.RoleUser, "hello"), + }, + Temperature: &temp, + ResponseFormat: &core.ResponseFormat{ + Type: "json_schema", + JSONSchema: &core.JSONSchemaFormat{ + Schema: map[string]interface{}{"type": "object"}, + }, + }, + Thinking: &core.ThinkingConfig{Enabled: true}, + }) + + if params.OutputConfig.Format.Schema["type"] != "object" { + t.Fatalf("expected output schema to be preserved, got %#v", params.OutputConfig) + } + if params.Thinking.OfEnabled == nil || params.Thinking.OfEnabled.BudgetTokens != 10000 { + t.Fatalf("expected thinking budget default, got %#v", params.Thinking) + } + if got := params.Temperature.Value; got != 1 { + t.Fatalf("expected thinking to force temperature=1, got %#v", got) + } +} + func TestConvertResponseMessageEmpty(t *testing.T) { // Test with empty content content := []anthropic.ContentBlockUnion{} @@ -388,3 +420,25 @@ func TestRequestValidationError(t *testing.T) { t.Error("expected validation error") } } + +func TestRequestServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer server.Close() + + model, err := New("claude-sonnet-4-20250514", WithAPIKey("test-key"), WithBaseURL(server.URL)) + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = model.Request(context.Background(), &core.ChatRequest{ + Model: "claude-sonnet-4-20250514", + Messages: []core.Message{ + core.NewTextMessage(core.RoleUser, "hello"), + }, + }) + if err == nil { + t.Fatal("expected request error") + } +} diff --git a/provider/bedrock/bedrock_internal_test.go b/provider/bedrock/bedrock_internal_test.go index d6cf359..0361154 100644 --- a/provider/bedrock/bedrock_internal_test.go +++ b/provider/bedrock/bedrock_internal_test.go @@ -78,6 +78,13 @@ func TestWithProfileOption(t *testing.T) { } } +func TestNewWithInvalidProfileReturnsConfigError(t *testing.T) { + _, err := New("anthropic.test", WithRegion("us-east-1"), WithProfile("definitely-missing-bedrock-profile")) + if err == nil { + t.Fatal("expected config loading error for missing AWS profile") + } +} + func TestBedrockRequestValidationErrors(t *testing.T) { model := &Model{modelID: "anthropic.test"} @@ -226,6 +233,13 @@ func TestBuildParamsAndInputs(t *testing.T) { } } +func TestConvertOutputMessageIgnoresNonMessageOutput(t *testing.T) { + msg := convertOutputMessage(nil) + if msg.Role != core.RoleAssistant || len(msg.Content) != 0 { + t.Fatalf("expected empty assistant message for non-message output, got %#v", msg) + } +} + func TestConvertSystemBlocksAndMessage(t *testing.T) { if got := convertSystemBlocks(core.Message{}); got != nil { t.Fatalf("expected nil system blocks for empty message, got %#v", got) diff --git a/provider/ollama/ollama_test.go b/provider/ollama/ollama_test.go index 1b2e373..462fe6f 100644 --- a/provider/ollama/ollama_test.go +++ b/provider/ollama/ollama_test.go @@ -48,6 +48,18 @@ func TestNewFromEnvHost(t *testing.T) { } } +func TestNewFromEnvAPIKey(t *testing.T) { + t.Setenv("OLLAMA_API_KEY", "env-secret") + + model, err := New("llama3.2") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if model.Name() != "llama3.2" { + t.Errorf("expected name %q, got %q", "llama3.2", model.Name()) + } +} + func TestMustNew(t *testing.T) { model := MustNew("llama3.2") if model.Name() != "llama3.2" { diff --git a/provider/openai/openai_internal_test.go b/provider/openai/openai_internal_test.go index c62b5c0..b119c47 100644 --- a/provider/openai/openai_internal_test.go +++ b/provider/openai/openai_internal_test.go @@ -109,6 +109,22 @@ func TestBuildParamsAppliesOptionalFields(t *testing.T) { } } +func TestBuildParamsUsesMediumReasoningEffortForDefaultThinkingBudget(t *testing.T) { + model := &Model{model: "gpt-4o"} + + params := model.buildParams(&core.ChatRequest{ + Model: "gpt-4o", + Messages: []core.Message{ + core.NewTextMessage(core.RoleUser, "hello"), + }, + Thinking: &core.ThinkingConfig{Enabled: true, BudgetTokens: 10000}, + }) + + if params.ReasoningEffort != shared.ReasoningEffortMedium { + t.Fatalf("expected medium reasoning effort, got %q", params.ReasoningEffort) + } +} + func TestConvertResponseFormat(t *testing.T) { t.Run("json_object", func(t *testing.T) { got := convertResponseFormat(&core.ResponseFormat{Type: "json_object"}) diff --git a/provider/openai/openai_transport_test.go b/provider/openai/openai_transport_test.go index ae630ec..a46c18e 100644 --- a/provider/openai/openai_transport_test.go +++ b/provider/openai/openai_transport_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/openai/openai-go/option" sdkopenairesponses "github.com/openai/openai-go/responses" @@ -271,3 +272,132 @@ func TestResponsesHelpersCoverRemainingBranches(t *testing.T) { t.Fatalf("expected empty text config when schema details are missing, got %#v", got) } } + +func TestOpenAIChatRequestPropagatesTransportErrors(t *testing.T) { + server := httptest.NewServer(http.NotFoundHandler()) + baseURL := server.URL + "/v1" + server.Close() + + model, err := New( + "gpt-4o", + WithAPIKey("test-key"), + WithBaseURL(baseURL), + WithRequestOptions(option.WithHTTPClient(&http.Client{Timeout: time.Second})), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = model.Request(context.Background(), &core.ChatRequest{ + Model: "gpt-4o", + Messages: []core.Message{core.NewTextMessage(core.RoleUser, "hello")}, + }) + if err == nil { + t.Fatal("expected Request to propagate the transport error") + } +} + +func TestOpenAIChatStreamEmitsErrorEventOnTransportFailure(t *testing.T) { + server := httptest.NewServer(http.NotFoundHandler()) + baseURL := server.URL + "/v1" + server.Close() + + model, err := New( + "gpt-4o", + WithAPIKey("test-key"), + WithBaseURL(baseURL), + WithRequestOptions(option.WithHTTPClient(&http.Client{Timeout: time.Second})), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + + stream, err := model.RequestStream(context.Background(), &core.ChatRequest{ + Model: "gpt-4o", + Messages: []core.Message{core.NewTextMessage(core.RoleUser, "hello")}, + }) + if err != nil { + t.Fatalf("RequestStream: %v", err) + } + + var events []core.StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(events) != 1 { + t.Fatalf("expected 1 stream event, got %d: %#v", len(events), events) + } + if events[0].Type != core.StreamEventError || events[0].Error == nil { + t.Fatalf("expected a stream error event, got %#v", events[0]) + } + if !strings.Contains(events[0].Error.Error(), "openai stream:") { + t.Fatalf("expected openai stream prefix, got %v", events[0].Error) + } +} + +func TestOpenAIResponsesRequestPropagatesTransportErrors(t *testing.T) { + server := httptest.NewServer(http.NotFoundHandler()) + baseURL := server.URL + "/v1" + server.Close() + + model, err := NewResponses( + "gpt-4.1", + WithAPIKey("test-key"), + WithBaseURL(baseURL), + WithRequestOptions(option.WithHTTPClient(&http.Client{Timeout: time.Second})), + ) + if err != nil { + t.Fatalf("NewResponses: %v", err) + } + + _, err = model.Request(context.Background(), &core.ChatRequest{ + Model: "gpt-4.1", + Messages: []core.Message{core.NewTextMessage(core.RoleUser, "hello")}, + }) + if err == nil { + t.Fatal("expected Request to propagate the transport error") + } + if !strings.Contains(err.Error(), "openai responses:") { + t.Fatalf("expected wrapped responses error, got %v", err) + } +} + +func TestOpenAIResponsesStreamEmitsErrorEventOnTransportFailure(t *testing.T) { + server := httptest.NewServer(http.NotFoundHandler()) + baseURL := server.URL + "/v1" + server.Close() + + model, err := NewResponses( + "gpt-4.1", + WithAPIKey("test-key"), + WithBaseURL(baseURL), + WithRequestOptions(option.WithHTTPClient(&http.Client{Timeout: time.Second})), + ) + if err != nil { + t.Fatalf("NewResponses: %v", err) + } + + stream, err := model.RequestStream(context.Background(), &core.ChatRequest{ + Model: "gpt-4.1", + Messages: []core.Message{core.NewTextMessage(core.RoleUser, "hello")}, + }) + if err != nil { + t.Fatalf("RequestStream: %v", err) + } + + var events []core.StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(events) != 1 { + t.Fatalf("expected 1 stream event, got %d: %#v", len(events), events) + } + if events[0].Type != core.StreamEventError || events[0].Error == nil { + t.Fatalf("expected a stream error event, got %#v", events[0]) + } + if !strings.Contains(events[0].Error.Error(), "openai responses stream:") { + t.Fatalf("expected openai responses stream prefix, got %v", events[0].Error) + } +} diff --git a/provider/openai/responses_internal_test.go b/provider/openai/responses_internal_test.go index a2427ce..518009a 100644 --- a/provider/openai/responses_internal_test.go +++ b/provider/openai/responses_internal_test.go @@ -21,6 +21,21 @@ func TestResponsesRequestValidationErrors(t *testing.T) { } } +func TestNewResponsesWithOptions(t *testing.T) { + model, err := NewResponses( + "gpt-4.1", + WithAPIKey("test-key"), + WithBaseURL("https://example.com/v1"), + WithOrganization("org-123"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if model.Name() != "gpt-4.1" { + t.Fatalf("expected model name to be preserved, got %q", model.Name()) + } +} + func TestResponsesBuildParams(t *testing.T) { model := &ResponsesModel{model: "gpt-4.1"} temperature := 0.6 @@ -205,4 +220,128 @@ func TestConvertResponsesFinishReason(t *testing.T) { if got := convertResponsesFinishReason(&responses.Response{Status: "completed"}); got != core.FinishReasonStop { t.Fatalf("expected stop finish reason, got %q", got) } + + if got := convertResponsesFinishReason(&responses.Response{Status: "failed"}); got != core.FinishReasonStop { + t.Fatalf("expected failed status to map to stop, got %q", got) + } + + if got := convertResponsesFinishReason(&responses.Response{Status: "canceled"}); got != core.FinishReasonStop { + t.Fatalf("expected canceled status to map to stop, got %q", got) + } + + if got := convertResponsesFinishReason(&responses.Response{Status: "unknown"}); got != core.FinishReasonStop { + t.Fatalf("expected unknown status to map to stop, got %q", got) + } +} + +func TestResponsesSchemaHelpers(t *testing.T) { + t.Run("ensureAdditionalPropertiesFalse recurses through nested objects and arrays", func(t *testing.T) { + schema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{"type": "string"}, + "literal": "value", + "meta": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"id"}, + }, + "tags": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "label": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + "required": []string{"name"}, + } + + normalized := ensureAdditionalPropertiesFalse(schema) + if normalized["additionalProperties"] != false { + t.Fatalf("expected root additionalProperties=false, got %#v", normalized) + } + + required := normalized["required"].([]string) + if len(required) != 4 { + t.Fatalf("expected all root properties to become required, got %#v", required) + } + + meta := normalized["properties"].(map[string]interface{})["meta"].(map[string]interface{}) + if meta["additionalProperties"] != false { + t.Fatalf("expected nested object additionalProperties=false, got %#v", meta) + } + + tags := normalized["properties"].(map[string]interface{})["tags"].(map[string]interface{}) + items := tags["items"].(map[string]interface{}) + if items["additionalProperties"] != false { + t.Fatalf("expected array items additionalProperties=false, got %#v", items) + } + if normalized["properties"].(map[string]interface{})["literal"] != "value" { + t.Fatalf("expected non-map property values to be preserved, got %#v", normalized["properties"]) + } + + // Original input should not be mutated. + if _, ok := schema["additionalProperties"]; ok { + t.Fatalf("expected original schema to remain unchanged, got %#v", schema) + } + }) + + t.Run("convertResponsesTools handles nil params and empty description", func(t *testing.T) { + tools := convertResponsesTools([]core.Tool{ + { + Type: core.ToolTypeFunction, + Function: core.Function{ + Name: "noop", + Description: "", + Parameters: nil, + }, + }, + { + Type: core.ToolTypeFunction, + Function: core.Function{ + Name: "strict_no_desc", + Description: "", + Parameters: map[string]interface{}{"type": "object"}, + }, + }, + }) + + if len(tools) != 2 || tools[0].OfFunction == nil || tools[1].OfFunction == nil { + t.Fatalf("expected two function tools, got %#v", tools) + } + if tools[0].OfFunction.Parameters == nil { + t.Fatalf("expected nil params to become an empty map, got %#v", tools[0].OfFunction.Parameters) + } + if tools[1].OfFunction.Parameters["additionalProperties"] != false { + t.Fatalf("expected strict schema normalization, got %#v", tools[1].OfFunction.Parameters) + } + }) + + t.Run("convertTextConfig includes optional schema metadata", func(t *testing.T) { + strict := true + tc := convertTextConfig(&core.ResponseFormat{ + Type: "json_schema", + JSONSchema: &core.JSONSchemaFormat{ + Name: "output", + Description: "Structured output", + Schema: map[string]interface{}{"type": "object"}, + Strict: &strict, + }, + }) + + if tc.Format.OfJSONSchema == nil { + t.Fatal("expected JSON schema text config") + } + if tc.Format.OfJSONSchema.Description.Value != "Structured output" { + t.Fatalf("expected description to be preserved, got %#v", tc.Format.OfJSONSchema.Description) + } + if tc.Format.OfJSONSchema.Strict.Value != true { + t.Fatalf("expected strict flag to be preserved, got %#v", tc.Format.OfJSONSchema.Strict) + } + }) } diff --git a/stream_internal_test.go b/stream_internal_test.go index 8e98480..9fc7c6c 100644 --- a/stream_internal_test.go +++ b/stream_internal_test.go @@ -10,6 +10,29 @@ import ( "github.com/regularkevvv/agentic-go/internal/testutil" ) +type streamRegistryStub struct { + executeResult ToolExecutionResult + executeErr error +} + +func (s *streamRegistryStub) Register(tool Tool, handler ToolHandler) error { return nil } + +func (s *streamRegistryStub) Get(name string) (ToolHandler, bool) { return nil, false } + +func (s *streamRegistryStub) Execute(ctx context.Context, toolCall ToolUse, deps any) (ToolExecutionResult, error) { + return s.executeResult, s.executeErr +} + +func (s *streamRegistryStub) ExecuteBatch(ctx context.Context, toolCalls []ToolUse, deps any) ([]ToolExecutionResult, error) { + return nil, nil +} + +func (s *streamRegistryStub) Tools() []Tool { return nil } + +func (s *streamRegistryStub) Has(name string) bool { return false } + +func (s *streamRegistryStub) Count() int { return 0 } + func TestRunStreamTrueText(t *testing.T) { model := &testutil.ScriptedStreamModel{ Streams: [][]StreamEvent{{ @@ -239,6 +262,267 @@ func TestRunStreamTrueValidationRetryAndOutputTool(t *testing.T) { }) } +func TestRunStreamTrueAdditionalErrorPaths(t *testing.T) { + t.Run("prepare loop error is returned directly", func(t *testing.T) { + agent := NewAgentDynamic[any]( + func(ctx RunContext[any]) (string, error) { + return "", errors.New("prompt failed") + }, + &testutil.ScriptedStreamModel{NameValue: "stream-model"}, + ) + + _, err := agent.RunStream(context.Background(), "prompt", nil) + if err == nil || err.Error() != "system prompt: prompt failed" { + t.Fatalf("expected prepare loop error, got %v", err) + } + }) + + t.Run("pre request usage limit stops immediately", func(t *testing.T) { + model := &testutil.ScriptedStreamModel{NameValue: "stream-model"} + agent := NewAgent[any]( + "system", + model, + WithUsageLimits[any](UsageLimits{MaxRequests: IntPtr(0)}), + ) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if err == nil || !strings.Contains(err.Error(), "usage limit exceeded: requests") { + t.Fatalf("expected pre-request usage limit error, got %v", err) + } + if len(model.Requests) != 0 { + t.Fatalf("expected no streaming requests, got %d", len(model.Requests)) + } + }) + + t.Run("build request error is forwarded", func(t *testing.T) { + model := &testutil.ScriptedStreamModel{NameValue: "stream-model"} + agent := NewAgent[any]("system", model) + + stream, err := agent.RunStream( + context.Background(), + "prompt", + nil, + WithRunHistoryProcessor(HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + return nil, errors.New("boom") + })), + ) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if err == nil || err.Error() != "history processor: boom" { + t.Fatalf("expected build request error, got %v", err) + } + }) + + t.Run("request stream error is wrapped", func(t *testing.T) { + model := &testutil.ScriptedStreamModel{NameValue: "stream-model"} + agent := NewAgent[any]("system", model) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if err == nil || err.Error() != "model request: no scripted stream available" { + t.Fatalf("expected wrapped request stream error, got %v", err) + } + }) + + t.Run("stream consumption errors are forwarded", func(t *testing.T) { + expected := errors.New("stream failed") + model := &testutil.ScriptedStreamModel{ + NameValue: "stream-model", + Streams: [][]StreamEvent{{ + {Type: StreamEventError, Error: expected}, + }}, + } + agent := NewAgent[any]("system", model) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if !errors.Is(err, expected) { + t.Fatalf("expected %v, got %v", expected, err) + } + }) + + t.Run("post response usage limit is enforced", func(t *testing.T) { + model := &testutil.ScriptedStreamModel{ + NameValue: "stream-model", + Streams: [][]StreamEvent{{ + {Type: StreamEventTextDelta, Delta: "hello"}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }}, + } + agent := NewAgent[any]( + "system", + model, + WithUsageLimits[any](UsageLimits{MaxTotalTokens: IntPtr(0)}), + ) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if err == nil || !strings.Contains(err.Error(), "usage limit exceeded: total_tokens") { + t.Fatalf("expected post-response usage limit error, got %v", err) + } + }) + + t.Run("tool call usage limit is enforced", func(t *testing.T) { + type noopInput struct{} + type noopOutput struct{} + + tool, handler := MustToolPlain("noop", "noop", func(input noopInput) (noopOutput, error) { + return noopOutput{}, nil + }) + + model := &testutil.ScriptedStreamModel{ + NameValue: "stream-model", + Streams: [][]StreamEvent{{ + {Type: StreamEventToolCallStart, ToolUse: &ToolUse{ID: "call_1", Name: "noop"}}, + {Type: StreamEventToolCallDelta, ToolCallID: "call_1", Delta: `{}`}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }}, + } + agent := NewAgent[any]( + "system", + model, + WithUsageLimits[any](UsageLimits{MaxToolCalls: IntPtr(0)}), + ).AddTool(tool, handler) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if err == nil || !strings.Contains(err.Error(), "usage limit exceeded: tool_calls") { + t.Fatalf("expected tool call limit error, got %v", err) + } + }) + + t.Run("registry execution errors are forwarded", func(t *testing.T) { + model := &testutil.ScriptedStreamModel{ + NameValue: "stream-model", + Streams: [][]StreamEvent{{ + {Type: StreamEventToolCallStart, ToolUse: &ToolUse{ID: "call_1", Name: "noop"}}, + {Type: StreamEventToolCallDelta, ToolCallID: "call_1", Delta: `{}`}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }}, + } + agent := NewAgent[any]("system", model) + agent.registry = &streamRegistryStub{executeErr: errors.New("registry failed")} + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if err == nil || err.Error() != `execute tool "noop": registry failed` { + t.Fatalf("expected registry execute error, got %v", err) + } + }) + + t.Run("model retry continues to the next streamed request", func(t *testing.T) { + type retryInput struct{} + type retryOutput struct{} + + tool, handler := MustToolPlain("retry_tool", "retry tool", func(input retryInput) (retryOutput, error) { + return retryOutput{}, Retry("try again") + }) + + model := &testutil.ScriptedStreamModel{ + NameValue: "stream-model", + Streams: [][]StreamEvent{ + { + {Type: StreamEventToolCallStart, ToolUse: &ToolUse{ID: "call_1", Name: "retry_tool"}}, + {Type: StreamEventToolCallDelta, ToolCallID: "call_1", Delta: `{}`}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }, + { + {Type: StreamEventTextDelta, Delta: "done"}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }, + }, + } + agent := NewAgent[any]("system", model, WithRetries[any](RetryConfig{MaxRetries: 1})).AddTool(tool, handler) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + var events []StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(model.Requests) != 2 { + t.Fatalf("expected retry to trigger a second streamed request, got %d", len(model.Requests)) + } + if events[len(events)-1].Type != StreamEventDone { + t.Fatalf("expected done event after retry, got %#v", events) + } + + foundRetryMessage := false + for _, msg := range model.Requests[1].Messages { + results := msg.GetToolResults() + if msg.Role == RoleTool && len(results) > 0 && strings.Contains(results[0].Content, "try again") { + foundRetryMessage = true + break + } + } + if !foundRetryMessage { + t.Fatalf("expected retry tool result in follow-up request, got %#v", model.Requests[1].Messages) + } + }) + + t.Run("max iterations error is emitted", func(t *testing.T) { + type noopInput struct{} + type noopOutput struct{} + + tool, handler := MustToolPlain("noop", "noop", func(input noopInput) (noopOutput, error) { + return noopOutput{}, nil + }) + + model := &testutil.ScriptedStreamModel{ + NameValue: "stream-model", + Streams: [][]StreamEvent{{ + {Type: StreamEventToolCallStart, ToolUse: &ToolUse{ID: "call_1", Name: "noop"}}, + {Type: StreamEventToolCallDelta, ToolCallID: "call_1", Delta: `{}`}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }}, + } + agent := NewAgent[any]("system", model, WithMaxIterations[any](1)).AddTool(tool, handler) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + var maxIterErr *MaxIterationsError + err = stream.Wait() + if !errors.As(err, &maxIterErr) || maxIterErr.MaxIterations != 1 { + t.Fatalf("expected max iterations error, got %v", err) + } + }) +} + func TestConsumeAndForward(t *testing.T) { t.Run("reconstructs text thinking tool calls and forwards events", func(t *testing.T) { out := make(chan StreamEvent, 8) diff --git a/tool/auto_test.go b/tool/auto_test.go index d7fd2b9..c9e6794 100644 --- a/tool/auto_test.go +++ b/tool/auto_test.go @@ -110,6 +110,11 @@ type MultiFieldDescInput struct { Value int `json:"value"` } +type PtrDescriptionInput struct { + _ struct{} `tool:"Pointer description"` + Name string `json:"name"` +} + func TestInferDescription(t *testing.T) { tests := []struct { name string @@ -131,6 +136,18 @@ func TestInferDescription(t *testing.T) { } } +func TestInferToolNameAndDescriptionAdditionalBranches(t *testing.T) { + if got := inferToolName[struct{ Value int }](); got != "" { + t.Fatalf("expected unnamed struct to infer empty name, got %q", got) + } + if got := inferDescription[*PtrDescriptionInput](); got != "Pointer description" { + t.Fatalf("expected pointer description, got %q", got) + } + if got := inferDescription[string](); got != "" { + t.Fatalf("expected non-struct description to be empty, got %q", got) + } +} + // ============================================================================ // Auto tests // ============================================================================ diff --git a/tool/builder_test.go b/tool/builder_test.go index dc8d086..cac26b4 100644 --- a/tool/builder_test.go +++ b/tool/builder_test.go @@ -2,6 +2,7 @@ package tool import ( "context" + "strings" "testing" "github.com/regularkevvv/agentic-go/internal/core" @@ -85,6 +86,11 @@ func TestPlainToolHandlerMarshalError(t *testing.T) { if err == nil { t.Error("expected error for invalid input type") } + + _, err = handler.Execute(context.TODO(), map[string]interface{}{"x": func() {}}, nil) + if err == nil || !strings.Contains(err.Error(), "marshal input") { + t.Fatalf("expected marshal input error, got %v", err) + } } func TestDepsToolHandlerMarshalError(t *testing.T) { @@ -101,6 +107,11 @@ func TestDepsToolHandlerMarshalError(t *testing.T) { if err == nil { t.Error("expected error for invalid input type") } + + _, err = handler.Execute(context.TODO(), map[string]interface{}{"x": func() {}}, nil) + if err == nil || !strings.Contains(err.Error(), "marshal input") { + t.Fatalf("expected marshal input error, got %v", err) + } } func TestRegisterToolsetError(t *testing.T) { @@ -136,3 +147,22 @@ func TestFormatToolResultJSON(t *testing.T) { t.Errorf("unexpected result: %q", got) } } + +func TestNewToolFromStructAdditionalCases(t *testing.T) { + t.Run("empty description", func(t *testing.T) { + _, err := NewToolFromStruct("search", "", struct{}{}) + if err == nil || err.Error() != "tool description cannot be empty" { + t.Fatalf("expected empty description error, got %v", err) + } + }) + + t.Run("primitive input still generates schema", func(t *testing.T) { + tool, err := NewToolFromStruct("count", "Count a value", 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tool.Function.Parameters["type"] == nil { + t.Fatalf("expected schema type to be present, got %#v", tool.Function.Parameters) + } + }) +} diff --git a/tool/deferred_test.go b/tool/deferred_test.go index 4e63e7b..489ddfb 100644 --- a/tool/deferred_test.go +++ b/tool/deferred_test.go @@ -43,6 +43,23 @@ func TestDeferredToolExecute(t *testing.T) { if out.(string) != "value" { t.Fatalf("unexpected output %#v", out) } + + t.Run("timeout-configured handler still returns immediate result", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + ch := make(chan string, 1) + ch <- "fast" + close(ch) + return ch, nil + }, WithDeferredTimeout(time.Second)) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + out, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil) + if execErr != nil || out.(string) != "fast" { + t.Fatalf("expected immediate result with timeout configured, got out=%#v err=%v", out, execErr) + } + }) } func TestDeferredToolExecuteErrors(t *testing.T) { @@ -276,6 +293,22 @@ func TestDeferredToolWithApproval(t *testing.T) { t.Fatalf("expected unmarshal error, got %v", execErr) } }) + + t.Run("approval handler marshal error", func(t *testing.T) { + _, handler, err := DeferredToolWithApproval("sync_approval", "sync approval", func(ctx context.Context, input deferredInput) (string, error) { + return "approved", nil + }, func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return true, nil + }) + if err != nil { + t.Fatalf("DeferredToolWithApproval: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": func() {}}, nil) + if execErr == nil || !strings.Contains(execErr.Error(), "marshal input") { + t.Fatalf("expected marshal error, got %v", execErr) + } + }) } func TestMustDeferredToolWithApproval(t *testing.T) { diff --git a/typed_agent_internal_test.go b/typed_agent_internal_test.go index e163433..1998958 100644 --- a/typed_agent_internal_test.go +++ b/typed_agent_internal_test.go @@ -23,6 +23,46 @@ func (wrongTypeOutputSpec) Parse(Message) (any, error) { return 123, nil } +type invalidModeOutputSpec struct { + *ToolOutputSpec[typedAgentCoverageValue] +} + +func (s invalidModeOutputSpec) Mode() OutputMode { + return OutputMode("invalid") +} + +type failingOutputSpec struct { + err error +} + +func (s failingOutputSpec) Tools() []Tool { + return nil +} + +func (s failingOutputSpec) Parse(Message) (any, error) { + return nil, s.err +} + +type failingResponseFormatSpec struct { + err error +} + +func (s failingResponseFormatSpec) Tools() []Tool { + return nil +} + +func (s failingResponseFormatSpec) Parse(Message) (any, error) { + return nil, s.err +} + +func (s failingResponseFormatSpec) ResponseFormat() *ResponseFormat { + return &ResponseFormat{Type: "json_object"} +} + +func (s failingResponseFormatSpec) Mode() OutputMode { + return OutputModeNative +} + func TestTypedAgentHelpersAndTextProcessorError(t *testing.T) { type input struct { Value int `json:"value"` @@ -103,3 +143,116 @@ func TestTypedAgentHelpersAndTextProcessorError(t *testing.T) { t.Fatalf("expected text processor error, got %v", err) } } + +func TestTypedAgentRunFallsBackToToolModeForUnknownOutputMode(t *testing.T) { + spec := invalidModeOutputSpec{ + ToolOutputSpec: NewToolOutput[typedAgentCoverageValue]("desc"), + } + + model := &testutil.StubModel{ + NameValue: "typed-model", + Response: &ChatResponse{ + Choices: []Choice{{ + Message: NewToolUseMessage(ToolUse{ + ID: "call_1", + Name: "__output__", + Input: map[string]interface{}{ + "value": "fallback", + }, + }), + FinishReason: FinishReasonToolCalls, + }}, + }, + } + + agent := NewAgent[any]("system", model) + registerOutputTools(agent, spec.ToolOutputSpec) + + ta := &TypedAgent[any, typedAgentCoverageValue]{ + agent: agent, + outputSpec: spec, + } + + result, err := ta.Run(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("Run: %v", err) + } + if result.Output.Value != "fallback" { + t.Fatalf("expected fallback output, got %#v", result.Output) + } +} + +func TestTypedAgentRunAdditionalErrorPaths(t *testing.T) { + t.Run("tool output parse error is wrapped", func(t *testing.T) { + ta := &TypedAgent[any, typedAgentCoverageValue]{ + agent: NewAgent[any]("system", &testutil.StubModel{ + NameValue: "typed-model", + Response: &ChatResponse{ + Choices: []Choice{{ + Message: NewTextMessage(RoleAssistant, "plain text"), + FinishReason: FinishReasonStop, + }}, + }, + }), + outputSpec: failingOutputSpec{err: errors.New("bad parse")}, + } + + _, err := ta.Run(context.Background(), "prompt", nil) + if err == nil || err.Error() != "parse structured output: bad parse" { + t.Fatalf("expected wrapped parse error, got %v", err) + } + }) + + t.Run("response format parse error is wrapped", func(t *testing.T) { + ta := &TypedAgent[any, typedAgentCoverageValue]{ + agent: NewAgent[any]("system", &testutil.StubModel{ + NameValue: "typed-model", + Response: &ChatResponse{ + Choices: []Choice{{ + Message: NewTextMessage(RoleAssistant, `{}`), + FinishReason: FinishReasonStop, + }}, + }, + }), + outputSpec: failingResponseFormatSpec{err: errors.New("bad parse")}, + } + + _, err := ta.Run(context.Background(), "prompt", nil) + if err == nil || err.Error() != "parse structured output: bad parse" { + t.Fatalf("expected wrapped parse error, got %v", err) + } + }) + + t.Run("response format model error is returned", func(t *testing.T) { + ta := &TypedAgent[any, typedAgentCoverageValue]{ + agent: NewAgent[any]("system", &testutil.StubModel{ + NameValue: "typed-model", + Err: errors.New("boom"), + }), + outputSpec: failingResponseFormatSpec{err: errors.New("unused")}, + } + + _, err := ta.Run(context.Background(), "prompt", nil) + if err == nil || err.Error() != "model request: boom" { + t.Fatalf("expected model error, got %v", err) + } + }) + + t.Run("text processor model error is returned", func(t *testing.T) { + ta := NewTypedAgentWithMode[any, int]( + "system", + &testutil.StubModel{ + NameValue: "typed-model", + Err: errors.New("boom"), + }, + NewTextProcessorOutput(func(text string) (int, error) { + return 0, nil + }), + ) + + _, err := ta.Run(context.Background(), "prompt", nil) + if err == nil || err.Error() != "model request: boom" { + t.Fatalf("expected model error, got %v", err) + } + }) +} diff --git a/validator_internal_test.go b/validator_internal_test.go index cd89cc8..593d681 100644 --- a/validator_internal_test.go +++ b/validator_internal_test.go @@ -101,3 +101,56 @@ func TestValidateStructAdditionalTagsAndTypedValidator(t *testing.T) { t.Fatalf("expected typed validator to pass, got %v", err) } } + +func TestValidateStructPointerAndFieldNameFallbacks(t *testing.T) { + tests := []struct { + name string + value any + substr string + }{ + { + name: "pointer input uses json field name", + value: &struct { + DisplayName string `json:"display_name" validate:"required"` + }{}, + substr: "display_name is required", + }, + { + name: "dash json tag falls back to Go name", + value: struct { + Hidden string `json:"-" validate:"required"` + }{}, + substr: "Hidden is required", + }, + { + name: "missing json tag falls back to Go name", + value: struct { + PlainField string `validate:"required"` + }{}, + substr: "PlainField is required", + }, + { + name: "string min message", + value: struct { + Body string `json:"body" validate:"min=4"` + }{Body: "hey"}, + substr: "body must be at least 4 characters long", + }, + { + name: "string max message", + value: struct { + Body string `json:"body" validate:"max=2"` + }{Body: "long"}, + substr: "body must be at most 2 characters long", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateStruct(tt.value) + if err == nil || !strings.Contains(err.Error(), tt.substr) { + t.Fatalf("expected error containing %q, got %v", tt.substr, err) + } + }) + } +}