diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 47a35d0..4011417 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,6 +42,18 @@ jobs: - name: Check build run: go build ./... + coverage: + name: Coverage + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - name: Enforce coverage threshold + run: make coverage-check + vet: name: Vet runs-on: ubuntu-latest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3c8174f..abc53fb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -26,6 +26,7 @@ Thank you for your interest in contributing! ```bash make test # run tests +make coverage # run coverage check inputs (excludes e2e/examples) make lint # run linter make vet # run go vet make fmt # format code @@ -55,6 +56,12 @@ make test-e2e - Update documentation if you change public APIs - Write a clear description explaining **what** and **why** +### Test Organization + +- Keep tests close to the subject they verify; avoid catch-all files like `coverage_*_test.go` +- Prefer `*_test.go` for behavior tests, `*_internal_test.go` for same-package white-box tests, and `*_transport_test.go` for local protocol/server tests +- Reserve `e2e/` for smoke tests that hit real external APIs and run them via `make test-e2e` + ## Reporting Bugs Open an issue with: diff --git a/Makefile b/Makefile index a338f35..7cda6bb 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,8 @@ -.PHONY: test lint vet build fmt check clean test-e2e +.PHONY: test lint vet build fmt check clean test-e2e coverage coverage-check GOLANGCI_LINT_VERSION := v2.1.6 +COVERAGE_PACKAGES := $(shell go list ./... | grep -vE '/(e2e|examples|internal/testutil)($$|/)') +COVERAGE_THRESHOLD := 95.0 # Run all tests (excluding e2e) test: @@ -8,7 +10,23 @@ test: # Run end-to-end tests (requires API keys) test-e2e: - go test -race -count=1 -timeout 300s ./e2e/... + go test -tags=e2e -race -count=1 -timeout 300s ./e2e/... + +# Run coverage excluding e2e, examples, and test-only helper packages +coverage: + go test -count=1 -coverprofile=coverage.out $(COVERAGE_PACKAGES) + @go tool cover -func=coverage.out | tail -n 1 + +# Enforce minimum coverage threshold for CI +coverage-check: coverage + @pct=$$(go tool cover -func=coverage.out | awk '/^total:/ { gsub(/%/, "", $$3); print $$3 }'); \ + awk -v pct="$$pct" -v threshold="$(COVERAGE_THRESHOLD)" 'BEGIN { \ + if ((pct + 0) < (threshold + 0)) { \ + printf("coverage %.1f%% is below required %.1f%%\n", pct + 0, threshold + 0); \ + exit 1; \ + } \ + printf("coverage %.1f%% meets required %.1f%%\n", pct + 0, threshold + 0); \ + }' # Run linter (same version as CI) lint: diff --git a/agent_options_test.go b/agent_options_test.go index 74ddaed..83191d4 100644 --- a/agent_options_test.go +++ b/agent_options_test.go @@ -1,6 +1,9 @@ package agentic -import "testing" +import ( + "context" + "testing" +) func TestWithTemperature(t *testing.T) { cfg := defaultAgentConfig[any]() @@ -79,3 +82,28 @@ func TestDefaultAgentConfig(t *testing.T) { t.Errorf("expected default maxRetries 1, got %d", cfg.retryConfig.MaxRetries) } } + +func TestHistoryProcessorOptions(t *testing.T) { + proc := HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + return append(messages, NewTextMessage(RoleAssistant, "processed")), nil + }) + + cfg := defaultAgentConfig[any]() + WithHistoryProcessor[any](proc)(&cfg) + if cfg.historyProcessor == nil { + t.Fatal("expected agent history processor to be set") + } + + opts := applyRunOptions([]RunOption{WithRunHistoryProcessor(proc)}) + if opts.historyProcessor == nil { + t.Fatal("expected run history processor to be set") + } + + got, err := opts.historyProcessor.Process(context.Background(), []Message{NewTextMessage(RoleUser, "hello")}) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != 2 || got[1].GetTextContent() != "processed" { + t.Fatalf("unexpected processed messages %#v", got) + } +} diff --git a/agent_extra_test.go b/agent_runtime_test.go similarity index 74% rename from agent_extra_test.go rename to agent_runtime_test.go index a129aff..e388a4f 100644 --- a/agent_extra_test.go +++ b/agent_runtime_test.go @@ -2,11 +2,12 @@ package agentic_test import ( "context" - "errors" "fmt" + "strings" "testing" agentic "github.com/regularkevvv/agentic-go" + "github.com/regularkevvv/agentic-go/internal/testutil" "github.com/regularkevvv/agentic-go/provider/test" ) @@ -40,7 +41,7 @@ func TestAgentDynamicPromptError(t *testing.T) { if err == nil { t.Fatal("expected error from dynamic prompt") } - if !containsStr(err.Error(), "system prompt") { + if !strings.Contains(err.Error(), "system prompt") { t.Errorf("expected 'system prompt' in error, got %q", err.Error()) } } @@ -113,27 +114,30 @@ func TestAgentWithMessagesContainingSystemPrompt(t *testing.T) { } func TestAgentModelRequestError(t *testing.T) { - model := &errorModel{err: fmt.Errorf("api error")} + model := &testutil.StubModel{NameValue: "error-model", Err: fmt.Errorf("api error")} agent := agentic.NewAgent[any]("test", model) _, err := agent.Run(context.Background(), "hello", nil) if err == nil { t.Fatal("expected error from model request") } - if !containsStr(err.Error(), "model request") { + if !strings.Contains(err.Error(), "model request") { t.Errorf("expected 'model request' in error, got %q", err.Error()) } } func TestAgentNoChoicesInResponse(t *testing.T) { - model := &emptyChoicesModel{} + model := &testutil.StubModel{ + NameValue: "empty-model", + Response: &agentic.ChatResponse{Choices: nil}, + } agent := agentic.NewAgent[any]("test", model) _, err := agent.Run(context.Background(), "hello", nil) if err == nil { t.Fatal("expected error for no choices") } - if !containsStr(err.Error(), "no choices") { + if !strings.Contains(err.Error(), "no choices") { t.Errorf("expected 'no choices' in error, got %q", err.Error()) } } @@ -173,45 +177,6 @@ func TestAgentAddToolPanicsOnDuplicate(t *testing.T) { agent.AddTool(tool, handler) // should panic } -func TestRunOutputStructuredFull(t *testing.T) { - type Result struct { - Name string `json:"name" description:"Name"` - Score int `json:"score" description:"Score"` - } - - outputSpec := agentic.NewToolOutput[Result]("Provide result") - - model := test.NewTestModel( - test.ModelResponse{ - ToolCalls: []agentic.ToolUse{ - { - ID: "c1", - Name: "__output__", - Input: map[string]interface{}{ - "name": "test", - "score": float64(100), - }, - }, - }, - }, - ) - - agent := agentic.NewAgent[any]("test", model) - - result, err := agentic.RunOutput[any, Result]( - context.Background(), agent, "go", nil, outputSpec, - ) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.Output.Name != "test" { - t.Errorf("expected name %q, got %q", "test", result.Output.Name) - } - if result.Output.Score != 100 { - t.Errorf("expected score 100, got %d", result.Output.Score) - } -} - func TestFirstNonNil(t *testing.T) { // Tested indirectly through agent options, but let's verify behavior // by using run options that override agent options @@ -260,40 +225,3 @@ func TestAgentNewAgentDynamicWithOptions(t *testing.T) { t.Errorf("expected %q, got %q", "ok", result.Output) } } - -// Helper models for testing error paths - -type errorModel struct { - err error -} - -func (m *errorModel) Request(ctx context.Context, req *agentic.ChatRequest) (*agentic.ChatResponse, error) { - return nil, m.err -} - -func (m *errorModel) Name() string { return "error-model" } - -type emptyChoicesModel struct{} - -func (m *emptyChoicesModel) Request(ctx context.Context, req *agentic.ChatRequest) (*agentic.ChatResponse, error) { - return &agentic.ChatResponse{Choices: nil}, nil -} - -func (m *emptyChoicesModel) Name() string { return "empty-model" } - -func containsStr(s, sub string) bool { - return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsSubstring(s, sub)) -} - -func containsSubstring(s, sub string) bool { - return errors.New(s).Error() != "" && len(sub) > 0 && findSubstring(s, sub) -} - -func findSubstring(s, sub string) bool { - for i := 0; i <= len(s)-len(sub); i++ { - if s[i:i+len(sub)] == sub { - return true - } - } - return false -} diff --git a/agent_toolset_test.go b/agent_toolset_test.go deleted file mode 100644 index e5d9c8f..0000000 --- a/agent_toolset_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package agentic - -import ( - "context" - "testing" -) - -func TestAddToolset(t *testing.T) { - type In struct { - X int `json:"x"` - } - type Out struct { - Y int `json:"y"` - } - - t1, h1 := MustToolPlain("tool_a", "tool a", func(in In) (Out, error) { return Out{Y: 1}, nil }) - t2, h2 := MustToolPlain("tool_b", "tool b", func(in In) (Out, error) { return Out{Y: 2}, nil }) - - ts := NewToolset().Add(t1, h1).Add(t2, h2) - - model := &mockModelSimple{name: "test", response: &ChatResponse{ - Choices: []Choice{{Message: NewTextMessage(RoleAssistant, "done"), FinishReason: FinishReasonStop}}, - }} - - agent := NewAgent[any]("test", model).AddToolset(ts) - - if agent.registry == nil { - t.Fatal("expected registry to be created") - } - if !agent.registry.Has("tool_a") { - t.Error("expected registry to have tool_a") - } - if !agent.registry.Has("tool_b") { - t.Error("expected registry to have tool_b") - } -} - -// mockModelSimple is a minimal mock for internal tests -type mockModelSimple struct { - name string - response *ChatResponse -} - -func (m *mockModelSimple) Request(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { - return m.response, nil -} - -func (m *mockModelSimple) Name() string { - return m.name -} diff --git a/aliases_test.go b/aliases_test.go new file mode 100644 index 0000000..b273464 --- /dev/null +++ b/aliases_test.go @@ -0,0 +1,235 @@ +package agentic + +import ( + "context" + "testing" + "time" + + "github.com/regularkevvv/agentic-go/internal/testutil" + testprovider "github.com/regularkevvv/agentic-go/provider/test" +) + +type aliasInput struct { + _ struct{} `tool:"Alias greeting tool"` + Name string `json:"name" description:"Name to greet"` +} + +type aliasOutput struct { + Greeting string `json:"greeting"` +} + +type aliasDeps struct { + Prefix string +} + +type aliasTypedOutput struct { + Result string `json:"result"` +} + +func TestAddToolset(t *testing.T) { + type In struct { + X int `json:"x"` + } + type Out struct { + Y int `json:"y"` + } + + t1, h1 := MustToolPlain("tool_a", "tool a", func(in In) (Out, error) { return Out{Y: 1}, nil }) + t2, h2 := MustToolPlain("tool_b", "tool b", func(in In) (Out, error) { return Out{Y: 2}, nil }) + + ts := NewToolset().Add(t1, h1).Add(t2, h2) + model := &testutil.StubModel{ + NameValue: "test", + Response: &ChatResponse{ + Choices: []Choice{{Message: NewTextMessage(RoleAssistant, "done"), FinishReason: FinishReasonStop}}, + }, + } + + agent := NewAgent[any]("test", model).AddToolset(ts) + + if agent.registry == nil { + t.Fatal("expected registry to be created") + } + if !agent.registry.Has("tool_a") { + t.Error("expected registry to have tool_a") + } + if !agent.registry.Has("tool_b") { + t.Error("expected registry to have tool_b") + } +} + +func TestAliasToolWrappers(t *testing.T) { + if tool := MustNewToolFromStruct("alias_struct", "alias struct", aliasInput{}); tool.Function.Name != "alias_struct" { + t.Fatalf("unexpected tool name %q", tool.Function.Name) + } + + plainTool, plainHandler, err := ToolPlain("alias_plain", "alias plain", func(input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: "hello " + input.Name}, nil + }) + if err != nil { + t.Fatalf("ToolPlain: %v", err) + } + if plainTool.Function.Name != "alias_plain" { + t.Fatalf("unexpected plain tool name %q", plainTool.Function.Name) + } + if out, err := plainHandler.Execute(context.Background(), map[string]interface{}{"name": "world"}, nil); err != nil || out.(aliasOutput).Greeting != "hello world" { + t.Fatalf("unexpected plain handler result: out=%#v err=%v", out, err) + } + + depsTool, depsHandler, err := ToolWithDeps[aliasInput, aliasOutput, aliasDeps]( + "alias_with_deps", + "alias deps", + func(ctx RunContext[aliasDeps], input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: ctx.Deps.Prefix + input.Name}, nil + }, + ) + if err != nil { + t.Fatalf("ToolWithDeps: %v", err) + } + if depsTool.Function.Name != "alias_with_deps" { + t.Fatalf("unexpected deps tool name %q", depsTool.Function.Name) + } + if out, err := depsHandler.Execute(context.Background(), map[string]interface{}{"name": "world"}, &aliasDeps{Prefix: "hi "}); err != nil || out.(aliasOutput).Greeting != "hi world" { + t.Fatalf("unexpected deps handler result: out=%#v err=%v", out, err) + } +} + +func TestAliasToolsetAndAutoWrappers(t *testing.T) { + t1, h1, err := AutoTool(func(input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: "hello " + input.Name}, nil + }, AutoToolName("alias_auto"), AutoToolDescription("alias auto tool")) + if err != nil { + t.Fatalf("AutoTool: %v", err) + } + if t1.Function.Name != "alias_auto" || t1.Function.Description != "alias auto tool" { + t.Fatalf("unexpected auto tool metadata: %#v", t1.Function) + } + + t2, h2, err := AutoToolWithDeps[aliasInput, aliasOutput, aliasDeps]( + func(ctx RunContext[aliasDeps], input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: ctx.Deps.Prefix + input.Name}, nil + }, + AutoToolName("alias_auto_deps"), + ) + if err != nil { + t.Fatalf("AutoToolWithDeps: %v", err) + } + + if _, handler := MustAutoToolWithDeps[aliasInput, aliasOutput, aliasDeps]( + func(ctx RunContext[aliasDeps], input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: ctx.Deps.Prefix + input.Name}, nil + }, + AutoToolName("alias_auto_must"), + ); handler == nil { + t.Fatal("expected MustAutoToolWithDeps to return a handler") + } + + set1 := NewToolset().Add(t1, h1) + set2 := NewToolset().Add(t2, h2) + combined := CombineToolsets(set1, set2) + filtered := FilterToolset(combined, func(name string) bool { return name == "alias_auto" }) + prefixed := PrefixToolset(filtered, "pref") + + registry := NewRegistry() + if err := RegisterToolset(registry, prefixed); err != nil { + t.Fatalf("RegisterToolset: %v", err) + } + if !registry.Has("pref__alias_auto") { + t.Fatalf("expected prefixed tool to be registered") + } +} + +func TestAliasAgentAndDeferredWrappers(t *testing.T) { + plainModel := testprovider.NewTestModel( + testprovider.ModelResponse{ + ToolCalls: []ToolUse{{ + ID: "call_1", + Name: "alias_input", + Input: map[string]interface{}{"name": "world"}, + }}, + }, + testprovider.ModelResponse{Text: "done"}, + ) + plainAgent := AddTool(NewAgent[any]("system", plainModel), func(input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: "hello " + input.Name}, nil + }) + if _, err := plainAgent.Run(context.Background(), "run tool", nil); err != nil { + t.Fatalf("AddTool run: %v", err) + } + + depsModel := testprovider.NewTestModel( + testprovider.ModelResponse{ + ToolCalls: []ToolUse{{ + ID: "call_1", + Name: "alias_input", + Input: map[string]interface{}{"name": "world"}, + }}, + }, + testprovider.ModelResponse{Text: "done"}, + ) + depsAgent := AddToolWithDeps(NewAgent[aliasDeps]("system", depsModel), func(ctx RunContext[aliasDeps], input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: ctx.Deps.Prefix + input.Name}, nil + }) + if _, err := depsAgent.Run(context.Background(), "run tool", &aliasDeps{Prefix: "hi "}); err != nil { + t.Fatalf("AddToolWithDeps run: %v", err) + } + + typed := NewTypedAgent[any, aliasTypedOutput]( + "system", + testprovider.NewTestModel(testprovider.ModelResponse{Text: `{"result":"ok"}`}), + "Return typed output", + ) + before := typed.agent.registry.Count() + AddToolToTyped(typed, func(input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: "hello " + input.Name}, nil + }) + if typed.agent.registry.Count() <= before { + t.Fatalf("expected typed agent registry count to increase") + } + + approval := WithApproval(func(ctx context.Context, toolCall ToolUse) (bool, error) { + return true, nil + }) + timeout := WithDeferredTimeout(time.Second) + + asyncTool, asyncHandler, err := DeferredTool("alias_deferred", "alias deferred", func(ctx context.Context, input aliasInput) (<-chan aliasOutput, error) { + ch := make(chan aliasOutput, 1) + ch <- aliasOutput{Greeting: "hello " + input.Name} + close(ch) + return ch, nil + }, approval, timeout) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + if asyncTool.Function.Name != "alias_deferred" { + t.Fatalf("unexpected deferred tool name %q", asyncTool.Function.Name) + } + if out, err := asyncHandler.Execute(context.Background(), map[string]interface{}{"name": "world"}, nil); err != nil || out.(aliasOutput).Greeting != "hello world" { + t.Fatalf("unexpected deferred output: out=%#v err=%v", out, err) + } + + if _, handler := MustDeferredTool("alias_deferred_must", "alias deferred", func(ctx context.Context, input aliasInput) (<-chan aliasOutput, error) { + ch := make(chan aliasOutput, 1) + ch <- aliasOutput{Greeting: "must " + input.Name} + close(ch) + return ch, nil + }); handler == nil { + t.Fatal("expected MustDeferredTool to return a handler") + } + + if _, handler, err := DeferredToolWithApproval("alias_deferred_approval", "alias deferred approval", func(ctx context.Context, input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: "approved " + input.Name}, nil + }, func(ctx context.Context, toolCall ToolUse) (bool, error) { + return true, nil + }); err != nil || handler == nil { + t.Fatalf("DeferredToolWithApproval: handler=%#v err=%v", handler, err) + } + + if _, handler := MustDeferredToolWithApproval("alias_deferred_approval_must", "alias deferred approval", func(ctx context.Context, input aliasInput) (aliasOutput, error) { + return aliasOutput{Greeting: "approved " + input.Name}, nil + }, func(ctx context.Context, toolCall ToolUse) (bool, error) { + return true, nil + }); handler == nil { + t.Fatal("expected MustDeferredToolWithApproval to return a handler") + } +} diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 7f5f3b0..14809aa 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -1,3 +1,6 @@ +//go:build e2e +// +build e2e + // Package e2e contains end-to-end tests that make real LLM API calls. // // These tests require API keys to be set as environment variables: diff --git a/handoff_internal_test.go b/handoff_internal_test.go new file mode 100644 index 0000000..98e6faf --- /dev/null +++ b/handoff_internal_test.go @@ -0,0 +1,169 @@ +package agentic + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/regularkevvv/agentic-go/internal/testutil" + testprovider "github.com/regularkevvv/agentic-go/provider/test" +) + +type handoffDeps struct { + User string +} + +func TestNewHandoffOptions(t *testing.T) { + child := NewAgent[handoffDeps]("child", testprovider.NewTestModel(testprovider.ModelResponse{Text: "ok"})) + h := NewHandoff("delegate", "delegate work", child, + WithHandoffFilter(HandoffFullHistory), + WithHandoffSystemPrompt("sub prompt"), + WithHandoffMaxTokens(77), + ) + + if h.name != "delegate" || h.description != "delegate work" { + t.Fatalf("unexpected handoff metadata: %#v", h) + } + if h.config.inputFilter != HandoffFullHistory { + t.Fatalf("expected filter %v, got %v", HandoffFullHistory, h.config.inputFilter) + } + if h.config.systemPrompt != "sub prompt" { + t.Fatalf("expected system prompt override, got %q", h.config.systemPrompt) + } + if h.config.maxTokens == nil || *h.config.maxTokens != 77 { + t.Fatalf("expected max tokens override, got %#v", h.config.maxTokens) + } +} + +func TestHandoffHandlerExecuteSuccess(t *testing.T) { + childModel := testprovider.NewTestModel(testprovider.ModelResponse{Text: "delegated result"}) + child := NewAgent[handoffDeps]("child system", childModel) + h := NewHandoff("delegate", "delegate work", child, + WithHandoffSystemPrompt("Delegate carefully"), + WithHandoffMaxTokens(33), + ) + handler := &handoffHandler[handoffDeps]{ + name: "delegate", + handoff: h, + } + + if handler.Name() != "delegate" { + t.Fatalf("expected handler name %q, got %q", "delegate", handler.Name()) + } + + out, err := handler.Execute(context.Background(), map[string]interface{}{"task": "review this"}, &handoffDeps{User: "kevin"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if out != "delegated result" { + t.Fatalf("expected delegated output, got %#v", out) + } + + calls := childModel.Calls() + if len(calls) != 1 { + t.Fatalf("expected 1 child model call, got %d", len(calls)) + } + if calls[0].MaxTokens == nil || *calls[0].MaxTokens != 33 { + t.Fatalf("expected max tokens override on child request, got %#v", calls[0].MaxTokens) + } + if len(calls[0].Messages) < 2 { + t.Fatalf("expected system + user message, got %#v", calls[0].Messages) + } + if calls[0].Messages[1].GetTextContent() != "Delegate carefully\n\nreview this" { + t.Fatalf("unexpected delegated prompt: %q", calls[0].Messages[1].GetTextContent()) + } +} + +func TestHandoffHandlerExecuteErrors(t *testing.T) { + t.Run("marshal input", func(t *testing.T) { + child := NewAgent[handoffDeps]("child", testprovider.NewTestModel(testprovider.ModelResponse{Text: "ok"})) + handler := &handoffHandler[handoffDeps]{ + name: "delegate", + handoff: NewHandoff("delegate", "delegate work", child), + } + + _, err := handler.Execute(context.Background(), map[string]interface{}{"task": func() {}}, nil) + if err == nil || !strings.Contains(err.Error(), "marshal input") { + t.Fatalf("expected marshal error, got %v", err) + } + }) + + t.Run("unmarshal typed input", func(t *testing.T) { + child := NewAgent[handoffDeps]("child", testprovider.NewTestModel(testprovider.ModelResponse{Text: "ok"})) + handler := &handoffHandler[handoffDeps]{ + name: "delegate", + handoff: NewHandoff("delegate", "delegate work", child), + } + + _, err := handler.Execute(context.Background(), map[string]interface{}{"task": map[string]interface{}{"nested": true}}, nil) + if err == nil || !strings.Contains(err.Error(), "unmarshal handoff input") { + t.Fatalf("expected unmarshal error, got %v", err) + } + }) + + t.Run("invalid deps type", func(t *testing.T) { + child := NewAgent[handoffDeps]("child", testprovider.NewTestModel(testprovider.ModelResponse{Text: "ok"})) + handler := &handoffHandler[handoffDeps]{ + name: "delegate", + handoff: NewHandoff("delegate", "delegate work", child), + } + + _, err := handler.Execute(context.Background(), map[string]interface{}{"task": "review"}, "wrong") + if err == nil || !strings.Contains(err.Error(), "invalid deps type") { + t.Fatalf("expected deps type error, got %v", err) + } + }) + + t.Run("child run error", func(t *testing.T) { + expected := errors.New("child failed") + child := NewAgent[handoffDeps]("child", &testutil.StubModel{NameValue: "handoff-error", Err: expected}) + handler := &handoffHandler[handoffDeps]{ + name: "delegate", + handoff: NewHandoff("delegate", "delegate work", child), + } + + _, err := handler.Execute(context.Background(), map[string]interface{}{"task": "review"}, &handoffDeps{}) + if err == nil || !strings.Contains(err.Error(), `handoff to "delegate": model request: child failed`) { + t.Fatalf("expected wrapped child error, got %v", err) + } + }) +} + +func TestAddHandoffAndWithHandoffs(t *testing.T) { + childModel := testprovider.NewTestModel(testprovider.ModelResponse{Text: "delegated"}) + child := NewAgent[handoffDeps]("child", childModel) + h := NewHandoff("delegate", "delegate work", child) + + parentModel := testprovider.NewTestModel( + testprovider.ModelResponse{ + ToolCalls: []ToolUse{{ + ID: "call_1", + Name: "delegate", + Input: map[string]interface{}{"task": "child task"}, + }}, + }, + testprovider.ModelResponse{Text: "finished"}, + ) + + parent := NewAgent[handoffDeps]("parent", parentModel).AddHandoff(h) + result, err := parent.Run(context.Background(), "delegate this", &handoffDeps{User: "kevin"}) + if err != nil { + t.Fatalf("Run: %v", err) + } + if result.Output != "finished" { + t.Fatalf("expected final parent output, got %q", result.Output) + } + if len(childModel.Calls()) != 1 { + t.Fatalf("expected handoff child to be invoked once, got %d", len(childModel.Calls())) + } + + withOption := NewAgent[handoffDeps]( + "parent", + testprovider.NewTestModel(testprovider.ModelResponse{Text: "ok"}), + WithHandoffs(h), + ) + if withOption.registry == nil || !withOption.registry.Has("delegate") { + t.Fatalf("expected handoff to be registered via option") + } +} diff --git a/history_processor_internal_test.go b/history_processor_internal_test.go new file mode 100644 index 0000000..b22efcb --- /dev/null +++ b/history_processor_internal_test.go @@ -0,0 +1,300 @@ +package agentic + +import ( + "context" + "errors" + "testing" + "time" +) + +type historyStubModel struct { + requests []*ChatRequest + resp *ChatResponse + err error +} + +func (m *historyStubModel) Request(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { + m.requests = append(m.requests, req) + if m.err != nil { + return nil, m.err + } + if m.resp != nil { + return m.resp, nil + } + return &ChatResponse{ + Model: "history-stub", + Created: time.Unix(0, 0), + Choices: []Choice{{ + Index: 0, + Message: NewTextMessage(RoleAssistant, "summary"), + }}, + }, nil +} + +func (m *historyStubModel) Name() string { + return "history-stub" +} + +func TestHistoryProcessorFunc(t *testing.T) { + processor := HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + return append(messages, NewTextMessage(RoleAssistant, "processed")), nil + }) + + got, err := processor.Process(context.Background(), []Message{NewTextMessage(RoleUser, "hello")}) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != 2 { + t.Fatalf("expected 2 messages, got %d", len(got)) + } + if got[1].GetTextContent() != "processed" { + t.Fatalf("expected appended message, got %q", got[1].GetTextContent()) + } +} + +func TestTruncateHistory(t *testing.T) { + t.Run("returns original when total length fits", func(t *testing.T) { + messages := []Message{ + NewTextMessage(RoleUser, "one"), + NewTextMessage(RoleAssistant, "two"), + } + + got, err := TruncateHistory(2).Process(context.Background(), messages) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != len(messages) { + t.Fatalf("expected %d messages, got %d", len(messages), len(got)) + } + }) + + t.Run("returns original when only system plus max tail remains", func(t *testing.T) { + messages := []Message{ + NewTextMessage(RoleSystem, "system"), + NewTextMessage(RoleUser, "one"), + NewTextMessage(RoleAssistant, "two"), + } + + got, err := TruncateHistory(2).Process(context.Background(), messages) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != len(messages) { + t.Fatalf("expected %d messages, got %d", len(messages), len(got)) + } + }) + + t.Run("preserves system and last messages", func(t *testing.T) { + messages := []Message{ + NewTextMessage(RoleSystem, "system"), + NewTextMessage(RoleUser, "one"), + NewTextMessage(RoleAssistant, "two"), + NewTextMessage(RoleUser, "three"), + NewTextMessage(RoleAssistant, "four"), + } + + got, err := TruncateHistory(2).Process(context.Background(), messages) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != 3 { + t.Fatalf("expected 3 messages, got %d", len(got)) + } + if got[0].Role != RoleSystem { + t.Fatalf("expected first message to remain system, got %q", got[0].Role) + } + if got[1].GetTextContent() != "three" || got[2].GetTextContent() != "four" { + t.Fatalf("unexpected tail: %#v", got) + } + }) +} + +func TestSlidingWindowHistory(t *testing.T) { + t.Run("returns empty history unchanged", func(t *testing.T) { + got, err := SlidingWindowHistory(10, func(Message) int { return 1 }).Process(context.Background(), nil) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != 0 { + t.Fatalf("expected empty history, got %d messages", len(got)) + } + }) + + t.Run("returns only system when budget is exhausted by system prompt", func(t *testing.T) { + messages := []Message{ + NewTextMessage(RoleSystem, "system"), + NewTextMessage(RoleUser, "one"), + } + + got, err := SlidingWindowHistory(1, func(Message) int { return 2 }).Process(context.Background(), messages) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != 1 || got[0].Role != RoleSystem { + t.Fatalf("expected only system message, got %#v", got) + } + }) + + t.Run("keeps most recent messages that fit", func(t *testing.T) { + messages := []Message{ + NewTextMessage(RoleSystem, "system"), + NewTextMessage(RoleUser, "one"), + NewTextMessage(RoleAssistant, "two"), + NewTextMessage(RoleUser, "three"), + NewTextMessage(RoleAssistant, "four"), + } + + costs := map[string]int{ + "system": 1, + "one": 2, + "two": 2, + "three": 3, + "four": 2, + } + + got, err := SlidingWindowHistory(6, func(msg Message) int { + return costs[msg.GetTextContent()] + }).Process(context.Background(), messages) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != 3 { + t.Fatalf("expected 3 messages, got %d", len(got)) + } + if got[0].Role != RoleSystem || got[1].GetTextContent() != "three" || got[2].GetTextContent() != "four" { + t.Fatalf("unexpected window: %#v", got) + } + }) +} + +func TestSummarizeHistory(t *testing.T) { + t.Run("returns original when history fits", func(t *testing.T) { + model := &historyStubModel{} + messages := []Message{ + NewTextMessage(RoleUser, "one"), + NewTextMessage(RoleAssistant, "two"), + } + + got, err := SummarizeHistory(model, 2).Process(context.Background(), messages) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != len(messages) { + t.Fatalf("expected %d messages, got %d", len(messages), len(got)) + } + if len(model.requests) != 0 { + t.Fatalf("expected summarizer not to be called, got %d requests", len(model.requests)) + } + }) + + t.Run("summarizes older messages and keeps recent tail", func(t *testing.T) { + model := &historyStubModel{ + resp: &ChatResponse{ + Model: "history-stub", + Created: time.Unix(0, 0), + Choices: []Choice{{ + Index: 0, + Message: NewTextMessage(RoleAssistant, "compact summary"), + }}, + }, + } + messages := []Message{ + NewTextMessage(RoleSystem, "system"), + NewTextMessage(RoleUser, "first"), + NewTextMessage(RoleAssistant, "second"), + NewTextMessage(RoleUser, "third"), + NewTextMessage(RoleAssistant, "fourth"), + } + + got, err := SummarizeHistory(model, 2).Process(context.Background(), messages) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(model.requests) != 1 { + t.Fatalf("expected 1 summarizer request, got %d", len(model.requests)) + } + if model.requests[0].Model != model.Name() { + t.Fatalf("expected model name %q, got %q", model.Name(), model.requests[0].Model) + } + if got[0].Role != RoleSystem { + t.Fatalf("expected system prompt to be preserved, got %q", got[0].Role) + } + if got[1].GetTextContent() != "[Conversation summary]: compact summary" { + t.Fatalf("unexpected summary message: %#v", got[1]) + } + if got[2].GetTextContent() != "third" || got[3].GetTextContent() != "fourth" { + t.Fatalf("unexpected recent messages: %#v", got) + } + }) + + t.Run("wraps model error", func(t *testing.T) { + model := &historyStubModel{err: errors.New("boom")} + messages := []Message{ + NewTextMessage(RoleUser, "first"), + NewTextMessage(RoleAssistant, "second"), + NewTextMessage(RoleUser, "third"), + } + + _, err := SummarizeHistory(model, 1).Process(context.Background(), messages) + if err == nil || err.Error() != "summarize history: boom" { + t.Fatalf("expected wrapped error, got %v", err) + } + }) + + t.Run("errors when summary response has no choices", func(t *testing.T) { + model := &historyStubModel{ + resp: &ChatResponse{ + Model: "history-stub", + Created: time.Unix(0, 0), + }, + } + messages := []Message{ + NewTextMessage(RoleUser, "first"), + NewTextMessage(RoleAssistant, "second"), + NewTextMessage(RoleUser, "third"), + } + + _, err := SummarizeHistory(model, 1).Process(context.Background(), messages) + if err == nil || err.Error() != "summarize history: no response choices" { + t.Fatalf("expected no choices error, got %v", err) + } + }) +} + +func TestChainProcessors(t *testing.T) { + t.Run("applies processors in sequence", func(t *testing.T) { + p1 := HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + return append(messages, NewTextMessage(RoleAssistant, "one")), nil + }) + p2 := HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + return append(messages, NewTextMessage(RoleAssistant, "two")), nil + }) + + got, err := ChainProcessors(p1, p2).Process(context.Background(), []Message{NewTextMessage(RoleUser, "start")}) + if err != nil { + t.Fatalf("Process: %v", err) + } + if len(got) != 3 { + t.Fatalf("expected 3 messages, got %d", len(got)) + } + if got[1].GetTextContent() != "one" || got[2].GetTextContent() != "two" { + t.Fatalf("unexpected chained output: %#v", got) + } + }) + + t.Run("returns first processor error", func(t *testing.T) { + expected := errors.New("stop") + p1 := HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + return nil, expected + }) + p2 := HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + t.Fatal("second processor should not be called") + return messages, nil + }) + + _, err := ChainProcessors(p1, p2).Process(context.Background(), []Message{NewTextMessage(RoleUser, "start")}) + if !errors.Is(err, expected) { + t.Fatalf("expected %v, got %v", expected, err) + } + }) +} diff --git a/internal/core/handler_test.go b/internal/core/handler_test.go new file mode 100644 index 0000000..5a2c1cc --- /dev/null +++ b/internal/core/handler_test.go @@ -0,0 +1,24 @@ +package core + +import ( + "strings" + "testing" +) + +func TestFormatToolResult(t *testing.T) { + if got := FormatToolResult(nil); got != "" { + t.Fatalf("expected empty string for nil result, got %q", got) + } + + if got := FormatToolResult("plain text"); got != "plain text" { + t.Fatalf("expected raw string, got %q", got) + } + + if got := FormatToolResult(map[string]any{"ok": true}); !strings.Contains(got, `"ok":true`) { + t.Fatalf("expected JSON output, got %q", got) + } + + if got := FormatToolResult(make(chan int)); !strings.Contains(got, "Error formatting result") { + t.Fatalf("expected formatting error message, got %q", got) + } +} diff --git a/internal/core/message_test.go b/internal/core/message_test.go index 1dc8b07..fb77ce3 100644 --- a/internal/core/message_test.go +++ b/internal/core/message_test.go @@ -69,3 +69,28 @@ func TestGetTextContentEmpty(t *testing.T) { t.Errorf("expected empty string, got %q", got) } } + +func TestThinkingHelpers(t *testing.T) { + redacted := &ThinkingBlock{ID: "redacted_thinking"} + if !redacted.IsRedacted() { + t.Fatal("expected redacted thinking block to be detected") + } + + visible := &ThinkingBlock{ID: "visible"} + if visible.IsRedacted() { + t.Fatal("expected non-redacted thinking block") + } + + msg := Message{ + Role: RoleAssistant, + Content: []Part{ + {Type: ContentThinking, Thinking: &ThinkingBlock{Text: "first "}}, + {Type: ContentText, Text: "ignored"}, + {Type: ContentThinking, Thinking: &ThinkingBlock{Text: "second"}}, + }, + } + + if got := msg.GetThinkingContent(); got != "first second" { + t.Fatalf("expected combined thinking content, got %q", got) + } +} diff --git a/internal/core/stream_test.go b/internal/core/stream_test.go new file mode 100644 index 0000000..850ee1e --- /dev/null +++ b/internal/core/stream_test.go @@ -0,0 +1,45 @@ +package core + +import ( + "errors" + "testing" +) + +func TestStreamResultTextAndWait(t *testing.T) { + ch := make(chan StreamEvent, 3) + ch <- StreamEvent{Type: StreamEventTextDelta, Delta: "hello "} + ch <- StreamEvent{Type: StreamEventTextDelta, Delta: "world"} + close(ch) + + result := NewStreamResult(ch) + text, err := result.Text() + if err != nil { + t.Fatalf("Text: %v", err) + } + if text != "hello world" { + t.Fatalf("expected accumulated text, got %q", text) + } + if err := result.Wait(); err != nil { + t.Fatalf("Wait after Text should be nil, got %v", err) + } +} + +func TestStreamResultWaitReturnsError(t *testing.T) { + expected := errors.New("stream failed") + ch := make(chan StreamEvent, 2) + ch <- StreamEvent{Type: StreamEventTextDelta, Delta: "partial"} + ch <- StreamEvent{Type: StreamEventError, Error: expected} + close(ch) + + result := NewStreamResult(ch) + if err := result.Wait(); !errors.Is(err, expected) { + t.Fatalf("expected %v, got %v", expected, err) + } + text, err := result.Text() + if !errors.Is(err, expected) { + t.Fatalf("expected repeated Text call to return %v, got %v", expected, err) + } + if text != "partial" { + t.Fatalf("expected accumulated partial text, got %q", text) + } +} diff --git a/internal/testutil/model.go b/internal/testutil/model.go new file mode 100644 index 0000000..b94e77c --- /dev/null +++ b/internal/testutil/model.go @@ -0,0 +1,74 @@ +package testutil + +import ( + "context" + "errors" + + "github.com/regularkevvv/agentic-go/internal/core" +) + +// StubModel is a small request spy for tests that only need a canned response +// or error and want to inspect the requests that were made. +type StubModel struct { + NameValue string + Response *core.ChatResponse + Err error + Requests []*core.ChatRequest +} + +func (m *StubModel) Request(ctx context.Context, req *core.ChatRequest) (*core.ChatResponse, error) { + m.Requests = append(m.Requests, req) + if m.Err != nil { + return nil, m.Err + } + if m.Response != nil { + return m.Response, nil + } + return &core.ChatResponse{Model: m.Name()}, nil +} + +func (m *StubModel) Name() string { + if m.NameValue != "" { + return m.NameValue + } + return "stub-model" +} + +// ScriptedStreamModel is a request spy for streaming tests that replays a fixed +// sequence of stream event slices across successive RequestStream calls. +type ScriptedStreamModel struct { + NameValue string + Requests []*core.ChatRequest + Streams [][]core.StreamEvent +} + +func (m *ScriptedStreamModel) Request(ctx context.Context, req *core.ChatRequest) (*core.ChatResponse, error) { + return nil, errors.New("Request should not be called for streaming tests") +} + +func (m *ScriptedStreamModel) RequestStream(ctx context.Context, req *core.ChatRequest) (*core.StreamResult, error) { + m.Requests = append(m.Requests, req) + if len(m.Streams) == 0 { + return nil, errors.New("no scripted stream available") + } + + events := m.Streams[0] + m.Streams = m.Streams[1:] + return NewScriptedStream(events...), nil +} + +func (m *ScriptedStreamModel) Name() string { + if m.NameValue != "" { + return m.NameValue + } + return "scripted-stream" +} + +func NewScriptedStream(events ...core.StreamEvent) *core.StreamResult { + ch := make(chan core.StreamEvent, len(events)) + for _, event := range events { + ch <- event + } + close(ch) + return core.NewStreamResult(ch) +} diff --git a/mcp/client.go b/mcp/client.go index 9ad737e..d3f250e 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -152,6 +152,11 @@ func (c *Client) Connect(ctx context.Context) error { } } + if err = c.inner.Start(ctx); err != nil { + _ = c.inner.Close() + return fmt.Errorf("mcp start: %w", err) + } + // Perform initialization handshake initReq := mcptypes.InitializeRequest{} initReq.Params.ProtocolVersion = mcptypes.LATEST_PROTOCOL_VERSION diff --git a/mcp/mcp_internal_test.go b/mcp/mcp_internal_test.go new file mode 100644 index 0000000..133402b --- /dev/null +++ b/mcp/mcp_internal_test.go @@ -0,0 +1,84 @@ +package mcp + +import ( + "context" + "strings" + "testing" + + mcptypes "github.com/mark3labs/mcp-go/mcp" +) + +func TestClientConstructorsAndHelpers(t *testing.T) { + stdio := NewStdioClient("filesystem", "npx", []string{"server"}, WithEnv(map[string]string{"A": "1"})) + if stdio.Name() != "filesystem" { + t.Fatalf("unexpected stdio client name %q", stdio.Name()) + } + if stdio.kind != kindStdio || stdio.stdioCmd != "npx" || len(stdio.stdioArgs) != 1 || len(stdio.stdioEnv) != 1 { + t.Fatalf("unexpected stdio client config %#v", stdio) + } + + sse := NewSSEClient("remote", "http://localhost/sse", WithHeaders(map[string]string{"Authorization": "Bearer token"})) + if sse.kind != kindSSE || sse.httpURL != "http://localhost/sse" || sse.httpHdrs["Authorization"] != "Bearer token" { + t.Fatalf("unexpected SSE client config %#v", sse) + } + + http := NewHTTPClient("remote-http", "http://localhost/mcp", WithHeaders(map[string]string{"X-Test": "1"})) + if http.kind != kindHTTP || http.httpURL != "http://localhost/mcp" || http.httpHdrs["X-Test"] != "1" { + t.Fatalf("unexpected HTTP client config %#v", http) + } + + if err := http.Close(); err != nil { + t.Fatalf("Close on disconnected client should be nil, got %v", err) + } +} + +func TestToolsetConversionAndErrorPaths(t *testing.T) { + client := NewHTTPClient("remote-http", "http://localhost/mcp") + + if _, err := NewToolset(client); err == nil || !strings.Contains(err.Error(), "call Connect first") { + t.Fatalf("expected disconnected client error from NewToolset, got %v", err) + } + if _, err := NewToolsetWithContext(context.Background(), client); err == nil || !strings.Contains(err.Error(), "call Connect first") { + t.Fatalf("expected disconnected client error from NewToolsetWithContext, got %v", err) + } + + mcpTool := mcptypes.Tool{ + Name: "lookup_weather", + Description: "Lookup weather", + InputSchema: mcptypes.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "city": map[string]any{"type": "string"}, + }, + Required: []string{"city"}, + }, + } + + tool, handler := convertMCPTool(client, mcpTool) + if tool.Function.Name != "lookup_weather" { + t.Fatalf("unexpected converted tool name %q", tool.Function.Name) + } + if handler.Name() != "lookup_weather" { + t.Fatalf("unexpected handler name %q", handler.Name()) + } + + text := extractTextContent([]mcptypes.Content{ + mcptypes.NewTextContent("hello"), + &mcptypes.TextContent{Type: "text", Text: "world"}, + }) + if text != "hello\nworld" { + t.Fatalf("unexpected extracted text %q", text) + } + + withPlaceholder := extractTextContent([]mcptypes.Content{ + mcptypes.NewTextContent("hello"), + &mcptypes.ImageContent{Type: "image", MIMEType: "image/png", Data: "AQID"}, + }) + if !strings.Contains(withPlaceholder, "[*mcp.ImageContent content]") { + t.Fatalf("expected placeholder for non-text content, got %q", withPlaceholder) + } + + if _, err := handler.Execute(context.Background(), map[string]interface{}{"city": "Lima"}, nil); err == nil || !strings.Contains(err.Error(), `mcp tool "lookup_weather"`) { + t.Fatalf("expected wrapped MCP execution error, got %v", err) + } +} diff --git a/mcp/mcp_transport_test.go b/mcp/mcp_transport_test.go new file mode 100644 index 0000000..5c4c534 --- /dev/null +++ b/mcp/mcp_transport_test.go @@ -0,0 +1,149 @@ +package mcp + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + "time" + + mcpgo "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/regularkevvv/agentic-go/internal/core" +) + +func TestClientConnectListToolsCallToolAndToolsetOverSSE(t *testing.T) { + mcpServer := server.NewMCPServer( + "test-server", + "1.0.0", + server.WithToolCapabilities(true), + ) + mcpServer.AddTool( + mcpgo.NewTool( + "echo", + mcpgo.WithDescription("Echo a name"), + mcpgo.WithString("name"), + ), + func(ctx context.Context, request mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { + return mcpgo.NewToolResultText(fmt.Sprintf("hello %s", request.GetArguments()["name"])), nil + }, + ) + mcpServer.AddTool( + mcpgo.NewTool("fail"), + func(ctx context.Context, request mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { + return mcpgo.NewToolResultError("boom"), nil + }, + ) + + sseServer := server.NewTestServer(mcpServer) + defer sseServer.Close() + + client := NewSSEClient("remote-tools", sseServer.URL+"/sse") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect: %v", err) + } + if err := client.Connect(ctx); err != nil { + t.Fatalf("expected repeated Connect to be a no-op, got %v", err) + } + + tools, err := client.listTools(ctx) + if err != nil { + t.Fatalf("listTools: %v", err) + } + if len(tools) != 2 { + t.Fatalf("expected 2 tools, got %#v", tools) + } + + result, err := client.callTool(ctx, "echo", map[string]any{"name": "Lima"}) + if err != nil { + t.Fatalf("callTool: %v", err) + } + if got := extractTextContent(result.Content); got != "hello Lima" { + t.Fatalf("unexpected tool result %q", got) + } + + toolset, err := NewToolset(client) + if err != nil { + t.Fatalf("NewToolset: %v", err) + } + + convertedTools, handlers := toolset.ToolsAndHandlers() + if len(convertedTools) != 2 || len(handlers) != 2 { + t.Fatalf("unexpected toolset contents: tools=%d handlers=%d", len(convertedTools), len(handlers)) + } + + var echoHandler, failHandler core.ToolHandler + for _, handler := range handlers { + switch handler.Name() { + case "echo": + echoHandler = handler + case "fail": + failHandler = handler + } + } + if echoHandler == nil || failHandler == nil { + t.Fatalf("expected both handlers to be present, got %#v", handlers) + } + + out, err := echoHandler.Execute(ctx, map[string]any{"name": "Cusco"}, nil) + if err != nil { + t.Fatalf("echo handler execute: %v", err) + } + if out != "hello Cusco" { + t.Fatalf("unexpected echo handler output %#v", out) + } + + out, err = failHandler.Execute(ctx, map[string]any{}, nil) + if err == nil || out != "boom" { + t.Fatalf("expected MCP error result, got output=%#v err=%v", out, err) + } + + if err := client.Close(); err != nil { + t.Fatalf("Close: %v", err) + } +} + +func TestClientConnectOverHTTPAndStdioFailure(t *testing.T) { + mcpServer := server.NewMCPServer( + "http-server", + "1.0.0", + server.WithToolCapabilities(true), + ) + mcpServer.AddTool( + mcpgo.NewTool("echo_http"), + func(ctx context.Context, request mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { + return mcpgo.NewToolResultText("ok"), nil + }, + ) + + httpHandler := server.NewStreamableHTTPServer(mcpServer) + httpServer := httptest.NewServer(httpHandler) + defer httpServer.Close() + + client := NewHTTPClient("remote-http", httpServer.URL+"/mcp") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Connect(ctx); err != nil { + t.Fatalf("HTTP Connect: %v", err) + } + tools, err := client.listTools(ctx) + if err != nil { + t.Fatalf("HTTP listTools: %v", err) + } + if len(tools) != 1 || tools[0].Name != "echo_http" { + t.Fatalf("unexpected HTTP tool list %#v", tools) + } + if err := client.Close(); err != nil { + t.Fatalf("HTTP Close: %v", err) + } + + stdioClient := NewStdioClient("missing", "/definitely/not/a/real-command", nil) + if err := stdioClient.Connect(ctx); err == nil { + t.Fatal("expected stdio connect to fail for a missing command") + } +} diff --git a/multimodal_file_test.go b/multimodal_file_test.go new file mode 100644 index 0000000..3b46e63 --- /dev/null +++ b/multimodal_file_test.go @@ -0,0 +1,39 @@ +package agentic_test + +import ( + "os" + "path/filepath" + "testing" + + agentic "github.com/regularkevvv/agentic-go" +) + +func TestNewImageFileMessageAdditionalMediaTypes(t *testing.T) { + dir := t.TempDir() + tests := []struct { + name string + extension string + mediaType string + }{ + {name: "bmp", extension: ".bmp", mediaType: "image/bmp"}, + {name: "svg", extension: ".svg", mediaType: "image/svg+xml"}, + {name: "pdf", extension: ".pdf", mediaType: "application/pdf"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := filepath.Join(dir, tt.name+tt.extension) + if err := os.WriteFile(path, []byte("data"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + msg, err := agentic.NewImageFileMessage("inspect", path) + if err != nil { + t.Fatalf("NewImageFileMessage: %v", err) + } + if got := msg.Content[1].ImageData.MediaType; got != tt.mediaType { + t.Fatalf("expected media type %q, got %q", tt.mediaType, got) + } + }) + } +} diff --git a/output_extra_test.go b/output_internal_test.go similarity index 74% rename from output_extra_test.go rename to output_internal_test.go index 20405a0..8342171 100644 --- a/output_extra_test.go +++ b/output_internal_test.go @@ -2,9 +2,14 @@ package agentic import ( "context" + "strings" "testing" ) +type outputCoverageValue struct { + Value string `json:"value"` +} + func TestTextOutputSpec(t *testing.T) { spec := &TextOutputSpec{} @@ -119,3 +124,30 @@ func TestNoopToolHandler(t *testing.T) { t.Errorf("expected key=value, got %v", resultMap["key"]) } } + +func TestToolOutputSpecErrors(t *testing.T) { + t.Run("panics when output tool schema is invalid", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for empty output tool description") + } + }() + NewToolOutput[outputCoverageValue]("").Tools() + }) + + t.Run("returns no output found", func(t *testing.T) { + spec := NewToolOutput[outputCoverageValue]("desc") + _, err := spec.Parse(Message{Role: RoleAssistant}) + if err == nil || !strings.Contains(err.Error(), "no output found") { + t.Fatalf("expected no output error, got %v", err) + } + }) + + t.Run("returns invalid json error", func(t *testing.T) { + spec := NewToolOutput[outputCoverageValue]("desc") + _, err := spec.Parse(NewTextMessage(RoleAssistant, "not json")) + if err == nil || !strings.Contains(err.Error(), "not valid JSON") { + t.Fatalf("expected invalid JSON error, got %v", err) + } + }) +} diff --git a/output_mode_internal_test.go b/output_mode_internal_test.go new file mode 100644 index 0000000..92a248a --- /dev/null +++ b/output_mode_internal_test.go @@ -0,0 +1,45 @@ +package agentic + +import ( + "strings" + "testing" +) + +type outputModeCoverageValue struct { + Value string `json:"value"` +} + +func TestOutputModeAdditionalBranches(t *testing.T) { + t.Run("native json schema adds object defaults", func(t *testing.T) { + spec := NewNativeOutput[any]("generic", "generic output") + schema := spec.jsonSchema() + + if got := schema["type"]; got != "object" { + t.Fatalf("expected default object type, got %#v", got) + } + if _, ok := schema["properties"].(map[string]interface{}); !ok { + t.Fatalf("expected default properties map, got %#v", schema["properties"]) + } + }) + + t.Run("prompted parse errors", func(t *testing.T) { + type validated struct { + Name string `json:"name" validate:"required"` + } + + emptySpec := NewPromptedOutput[outputModeCoverageValue]() + if _, err := emptySpec.Parse(Message{Role: RoleAssistant}); err == nil || err.Error() != "no text content in response for prompted output parsing" { + t.Fatalf("expected empty text error, got %v", err) + } + + invalidSpec := NewPromptedOutput[outputModeCoverageValue]() + if _, err := invalidSpec.Parse(NewTextMessage(RoleAssistant, "{")); err == nil || !strings.Contains(err.Error(), "parse prompted JSON output") { + t.Fatalf("expected JSON parse error, got %v", err) + } + + validatedSpec := NewPromptedOutput[validated]() + if _, err := validatedSpec.Parse(NewTextMessage(RoleAssistant, `{}`)); err == nil || !IsValidationError(err) { + t.Fatalf("expected validation error, got %v", err) + } + }) +} diff --git a/output_test.go b/output_test.go index 54218c0..6a4286a 100644 --- a/output_test.go +++ b/output_test.go @@ -54,6 +54,45 @@ func TestRunOutputStructured(t *testing.T) { } } +func TestRunOutputStructuredFull(t *testing.T) { + type Result struct { + Name string `json:"name" description:"Name"` + Score int `json:"score" description:"Score"` + } + + outputSpec := agentic.NewToolOutput[Result]("Provide result") + + model := test.NewTestModel( + test.ModelResponse{ + ToolCalls: []agentic.ToolUse{ + { + ID: "c1", + Name: "__output__", + Input: map[string]interface{}{ + "name": "test", + "score": float64(100), + }, + }, + }, + }, + ) + + agent := agentic.NewAgent[any]("test", model) + + result, err := agentic.RunOutput[any, Result]( + context.Background(), agent, "go", nil, outputSpec, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Output.Name != "test" { + t.Errorf("expected name %q, got %q", "test", result.Output.Name) + } + if result.Output.Score != 100 { + t.Errorf("expected score 100, got %d", result.Output.Score) + } +} + func TestToolOutputSpecParse(t *testing.T) { type Result struct { Value int `json:"value"` diff --git a/provider/anthropic/anthropic_transport_test.go b/provider/anthropic/anthropic_transport_test.go new file mode 100644 index 0000000..3adbd90 --- /dev/null +++ b/provider/anthropic/anthropic_transport_test.go @@ -0,0 +1,211 @@ +package anthropic + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + + "github.com/regularkevvv/agentic-go/internal/core" +) + +func TestAnthropicRequestAndStreamWithLocalServer(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.NotFound(w, r) + return + } + + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request body: %v", err) + } + + stream, _ := body["stream"].(bool) + if !stream { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "id":"msg_req", + "type":"message", + "role":"assistant", + "model":"claude-sonnet", + "content":[{"type":"text","text":"hello from claude"}], + "stop_reason":"end_turn", + "usage":{"input_tokens":4,"output_tokens":2,"cache_read_input_tokens":1,"cache_creation_input_tokens":3} + }`) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_stream","type":"message","role":"assistant","content":[],"model":"claude-sonnet","usage":{"input_tokens":7,"cache_read_input_tokens":1,"cache_creation_input_tokens":2}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"call_1","name":"lookup","input":{}}}`, + ``, + `event: content_block_delta`, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"city\":\"Lima\"}"}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":1,"content_block":{"type":"thinking","thinking":""}}`, + ``, + `event: content_block_delta`, + `data: {"type":"content_block_delta","index":1,"delta":{"type":"thinking_delta","thinking":"think"}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}`, + ``, + `event: content_block_delta`, + `data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","usage":{"output_tokens":5}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n")) + })) + defer server.Close() + + model, err := New("claude-sonnet", WithAPIKey("test-key"), WithBaseURL(server.URL)) + if err != nil { + t.Fatalf("New: %v", err) + } + + req := &core.ChatRequest{ + Model: "claude-sonnet", + Messages: []core.Message{core.NewTextMessage(core.RoleUser, "hello")}, + } + + resp, err := model.Request(context.Background(), req) + if err != nil { + t.Fatalf("Request: %v", err) + } + if resp.Choices[0].Message.GetTextContent() != "hello from claude" { + t.Fatalf("unexpected request output %q", resp.Choices[0].Message.GetTextContent()) + } + if resp.Usage.CacheReadTokens != 1 || resp.Usage.CacheCreationTokens != 3 { + t.Fatalf("unexpected request usage %#v", resp.Usage) + } + + stream, err := model.RequestStream(context.Background(), req) + if err != nil { + t.Fatalf("RequestStream: %v", err) + } + + var events []core.StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(events) != 5 { + t.Fatalf("expected 5 stream events, got %d: %#v", len(events), events) + } + if events[0].Type != core.StreamEventToolCallStart || events[0].ToolUse == nil || events[0].ToolUse.Name != "lookup" { + t.Fatalf("unexpected tool-call start %#v", events[0]) + } + if events[1].Type != core.StreamEventToolCallDelta || events[1].ToolCallID != "call_1" || events[1].Delta != `{"city":"Lima"}` { + t.Fatalf("unexpected tool-call delta %#v", events[1]) + } + if events[2].Type != core.StreamEventThinkingDelta || events[2].Delta != "think" { + t.Fatalf("unexpected thinking delta %#v", events[2]) + } + if events[3].Type != core.StreamEventTextDelta || events[3].Delta != "hello" { + t.Fatalf("unexpected text delta %#v", events[3]) + } + if events[4].Type != core.StreamEventDone || events[4].Usage == nil || events[4].Usage.TotalTokens != 12 { + t.Fatalf("unexpected done event %#v", events[4]) + } +} + +func TestAnthropicConversionHelpersCoverAdditionalBranches(t *testing.T) { + param := convertMessage(core.Message{ + Role: core.RoleUser, + Content: []core.Part{ + { + Type: core.ContentThinking, + Thinking: &core.ThinkingBlock{ID: "redacted_thinking", Signature: "secret"}, + }, + {Type: core.ContentText, Text: "cached text"}, + {Type: core.ContentCachePoint, CachePoint: &core.CachePoint{TTL: "1h"}}, + { + Type: core.ContentImageData, + ImageData: &core.ImageData{Data: "AQID", MediaType: "image/png"}, + }, + { + Type: core.ContentImageURL, + ImageURL: &core.ImageURL{URL: "https://example.com/image.png"}, + }, + { + Type: core.ContentDocumentURL, + DocumentURL: &core.DocumentURL{URL: "https://example.com/file.pdf"}, + }, + { + Type: core.ContentToolUse, + ToolUse: &core.ToolUse{ID: "call_1", Name: "lookup", Input: map[string]any{"city": "Lima"}}, + }, + { + Type: core.ContentToolResult, + ToolResult: &core.ToolResult{ToolUseID: "call_1", Content: `{"ok":true}`, IsError: true}, + }, + { + Type: core.ContentUploadedFile, + UploadedFile: &core.UploadedFile{FileID: "file_123"}, + }, + }, + }) + + if len(param.Content) != 7 { + t.Fatalf("expected 7 content blocks, got %#v", param.Content) + } + if cc := param.Content[1].GetCacheControl(); cc == nil || cc.TTL != anthropic.CacheControlEphemeralTTLTTL1h { + t.Fatalf("expected cache control to be attached to previous block, got %#v", cc) + } + + jsonSchema := convertResponseFormat(&core.ResponseFormat{ + Type: "json_schema", + JSONSchema: &core.JSONSchemaFormat{Schema: map[string]any{"type": "object"}}, + }) + if jsonSchema.Format.Schema["type"] != "object" { + t.Fatalf("expected JSON schema response format, got %#v", jsonSchema) + } + + unsupported := convertResponseFormat(&core.ResponseFormat{Type: "json_object"}) + if unsupported.Format.Schema != nil { + t.Fatalf("expected unsupported format to be empty, got %#v", unsupported) + } + + mustBlock := func(raw string) anthropic.ContentBlockUnion { + t.Helper() + var block anthropic.ContentBlockUnion + if err := json.Unmarshal([]byte(raw), &block); err != nil { + t.Fatalf("unmarshal content block: %v", err) + } + return block + } + + msg := convertResponseMessage([]anthropic.ContentBlockUnion{ + mustBlock(`{"type":"thinking","thinking":"reasoning","signature":"sig"}`), + mustBlock(`{"type":"redacted_thinking","data":"encrypted"}`), + mustBlock(`{"type":"text","text":"hello"}`), + mustBlock(`{"type":"tool_use","id":"call_1","name":"lookup","input":{"city":"Lima"}}`), + }, "assistant") + + if msg.GetThinkingContent() != "reasoning" { + t.Fatalf("unexpected thinking content %q", msg.GetThinkingContent()) + } + toolUses := msg.GetToolUses() + if len(toolUses) != 1 || toolUses[0].Input["city"] != "Lima" { + t.Fatalf("unexpected tool uses %#v", toolUses) + } + if msg.Content[1].Thinking == nil || !msg.Content[1].Thinking.IsRedacted() || msg.Content[1].Thinking.Signature != "encrypted" { + t.Fatalf("unexpected redacted thinking block %#v", msg.Content[1]) + } +} diff --git a/provider/bedrock/bedrock.go b/provider/bedrock/bedrock.go index 5e4da01..edf7379 100644 --- a/provider/bedrock/bedrock.go +++ b/provider/bedrock/bedrock.go @@ -27,10 +27,45 @@ import ( // Model implements the core.Model and core.StreamModel interfaces // using the AWS Bedrock Runtime Converse API. type Model struct { - client *bedrockruntime.Client + client runtimeClient modelID string } +type runtimeClient interface { + Converse(ctx context.Context, params *bedrockruntime.ConverseInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.ConverseOutput, error) + ConverseStream(ctx context.Context, params *bedrockruntime.ConverseStreamInput, optFns ...func(*bedrockruntime.Options)) (converseStream, error) +} + +type converseStream interface { + Events() <-chan types.ConverseStreamOutput + Close() error + Err() error +} + +type sdkRuntimeClient struct { + client *bedrockruntime.Client +} + +func (c *sdkRuntimeClient) Converse( + ctx context.Context, + params *bedrockruntime.ConverseInput, + optFns ...func(*bedrockruntime.Options), +) (*bedrockruntime.ConverseOutput, error) { + return c.client.Converse(ctx, params, optFns...) +} + +func (c *sdkRuntimeClient) ConverseStream( + ctx context.Context, + params *bedrockruntime.ConverseStreamInput, + optFns ...func(*bedrockruntime.Options), +) (converseStream, error) { + resp, err := c.client.ConverseStream(ctx, params, optFns...) + if err != nil { + return nil, err + } + return resp.GetStream(), nil +} + // Option configures the Bedrock Model. type Option func(*config) @@ -40,7 +75,7 @@ type config struct { accessKeyID string secretAccessKey string sessionToken string - client *bedrockruntime.Client + client runtimeClient } // WithRegion sets the AWS region. If not set, the AWS_DEFAULT_REGION env var is used. @@ -65,7 +100,7 @@ func WithCredentials(accessKeyID, secretAccessKey, sessionToken string) Option { // WithClient sets a pre-configured Bedrock Runtime client. // When set, all other connection options are ignored. func WithClient(client *bedrockruntime.Client) Option { - return func(c *config) { c.client = client } + return func(c *config) { c.client = &sdkRuntimeClient{client: client} } } // New creates a new Bedrock Model. @@ -127,7 +162,7 @@ func New(modelID string, opts ...Option) (*Model, error) { client := bedrockruntime.NewFromConfig(awsCfg) return &Model{ - client: client, + client: &sdkRuntimeClient{client: client}, modelID: modelID, }, nil } @@ -165,7 +200,7 @@ func (m *Model) RequestStream(ctx context.Context, req *core.ChatRequest) (*core streamInput := m.buildStreamInput(req) - resp, err := m.client.ConverseStream(ctx, streamInput) + stream, err := m.client.ConverseStream(ctx, streamInput) if err != nil { return nil, fmt.Errorf("bedrock: %w", err) } @@ -176,7 +211,6 @@ func (m *Model) RequestStream(ctx context.Context, req *core.ChatRequest) (*core go func() { defer close(ch) - stream := resp.GetStream() defer func() { _ = stream.Close() }() var usage core.Usage diff --git a/provider/bedrock/bedrock_internal_test.go b/provider/bedrock/bedrock_internal_test.go new file mode 100644 index 0000000..d6cf359 --- /dev/null +++ b/provider/bedrock/bedrock_internal_test.go @@ -0,0 +1,486 @@ +package bedrock + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + + agentic "github.com/regularkevvv/agentic-go" + "github.com/regularkevvv/agentic-go/internal/core" +) + +type mockRuntimeClient struct { + converseInput *bedrockruntime.ConverseInput + converseOutput *bedrockruntime.ConverseOutput + converseErr error + + streamInput *bedrockruntime.ConverseStreamInput + stream converseStream + streamErr error +} + +func (m *mockRuntimeClient) Converse( + ctx context.Context, + params *bedrockruntime.ConverseInput, + optFns ...func(*bedrockruntime.Options), +) (*bedrockruntime.ConverseOutput, error) { + m.converseInput = params + return m.converseOutput, m.converseErr +} + +func (m *mockRuntimeClient) ConverseStream( + ctx context.Context, + params *bedrockruntime.ConverseStreamInput, + optFns ...func(*bedrockruntime.Options), +) (converseStream, error) { + m.streamInput = params + return m.stream, m.streamErr +} + +type mockConverseStream struct { + events chan types.ConverseStreamOutput + err error + closed bool +} + +func (s *mockConverseStream) Events() <-chan types.ConverseStreamOutput { + return s.events +} + +func (s *mockConverseStream) Close() error { + s.closed = true + return nil +} + +func (s *mockConverseStream) Err() error { + return s.err +} + +func newMockConverseStream(events ...types.ConverseStreamOutput) *mockConverseStream { + ch := make(chan types.ConverseStreamOutput, len(events)) + for _, event := range events { + ch <- event + } + close(ch) + return &mockConverseStream{events: ch} +} + +func TestWithProfileOption(t *testing.T) { + cfg := &config{} + WithProfile("dev-profile")(cfg) + if cfg.profile != "dev-profile" { + t.Fatalf("expected profile to be set, got %#v", cfg) + } +} + +func TestBedrockRequestValidationErrors(t *testing.T) { + model := &Model{modelID: "anthropic.test"} + + if _, err := model.Request(context.Background(), &core.ChatRequest{}); err == nil { + t.Fatal("expected Request to fail validation") + } + if _, err := model.RequestStream(context.Background(), &core.ChatRequest{}); err == nil { + t.Fatal("expected RequestStream to fail validation") + } +} + +func TestBedrockRequestUsesRuntimeClient(t *testing.T) { + mock := &mockRuntimeClient{ + converseOutput: &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{Value: "answer"}, + }, + }, + }, + StopReason: types.StopReasonEndTurn, + Usage: &types.TokenUsage{ + InputTokens: aws.Int32(10), + OutputTokens: aws.Int32(5), + TotalTokens: aws.Int32(15), + }, + }, + } + model := &Model{client: mock, modelID: "anthropic.test"} + + resp, err := model.Request(context.Background(), &core.ChatRequest{ + Model: "anthropic.test", + Messages: []core.Message{ + core.NewTextMessage(core.RoleUser, "hello"), + }, + }) + if err != nil { + t.Fatalf("Request: %v", err) + } + if mock.converseInput == nil || aws.ToString(mock.converseInput.ModelId) != "anthropic.test" { + t.Fatalf("unexpected converse input %#v", mock.converseInput) + } + if len(mock.converseInput.Messages) != 1 { + t.Fatalf("expected a single converted message, got %#v", mock.converseInput.Messages) + } + if resp.Choices[0].Message.GetTextContent() != "answer" { + t.Fatalf("unexpected response message %#v", resp.Choices[0].Message) + } + if resp.Usage.TotalTokens != 15 { + t.Fatalf("unexpected usage %#v", resp.Usage) + } +} + +func TestBedrockRequestWrapsRuntimeError(t *testing.T) { + model := &Model{ + client: &mockRuntimeClient{converseErr: errors.New("boom")}, + modelID: "anthropic.test", + } + + _, err := model.Request(context.Background(), &core.ChatRequest{ + Model: "anthropic.test", + Messages: []core.Message{ + core.NewTextMessage(core.RoleUser, "hello"), + }, + }) + if err == nil || err.Error() != "bedrock: boom" { + t.Fatalf("expected wrapped runtime error, got %v", err) + } +} + +func TestBuildParamsAndInputs(t *testing.T) { + model := &Model{modelID: "anthropic.test"} + temperature := 0.6 + maxTokens := 128 + topP := 0.9 + toolChoice := core.ToolChoiceRequired + + req := &core.ChatRequest{ + Model: "anthropic.test", + Messages: []core.Message{ + core.NewTextMessage(core.RoleSystem, "system"), + { + Role: core.RoleUser, + Content: []core.Part{ + {Type: core.ContentText, Text: "hello"}, + agentic.ImageDataPart([]byte("img"), "image/png"), + {Type: core.ContentImageURL, ImageURL: &core.ImageURL{URL: "https://example.com/image.png"}}, + {Type: core.ContentDocumentURL, DocumentURL: &core.DocumentURL{URL: "https://example.com/file.pdf"}}, + }, + }, + { + Role: core.RoleAssistant, + Content: []core.Part{ + {Type: core.ContentText, Text: "working"}, + {Type: core.ContentToolUse, ToolUse: &core.ToolUse{ID: "call_1", Name: "lookup", Input: map[string]interface{}{"city": "Lima"}}}, + {Type: core.ContentThinking, Thinking: &core.ThinkingBlock{Text: "reasoning", Signature: "sig"}}, + }, + }, + core.NewToolResultMessage("call_1", `{"temp":72}`, false), + }, + Temperature: &temperature, + MaxTokens: &maxTokens, + TopP: &topP, + Tools: []core.Tool{{ + Type: core.ToolTypeFunction, + Function: core.Function{ + Name: "lookup", + Description: "Lookup weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }}, + ToolChoice: &toolChoice, + Thinking: &core.ThinkingConfig{Enabled: true, BudgetTokens: 300}, + } + + params := model.buildParams(req) + if len(params.messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(params.messages)) + } + if len(params.system) != 1 { + t.Fatalf("expected 1 system block, got %#v", params.system) + } + if params.inferenceConf == nil || params.inferenceConf.Temperature == nil || *params.inferenceConf.Temperature != float32(temperature) { + t.Fatalf("unexpected inference config %#v", params.inferenceConf) + } + if params.toolConfig == nil || len(params.toolConfig.Tools) != 1 { + t.Fatalf("expected tool config, got %#v", params.toolConfig) + } + if params.additionalReq == nil { + t.Fatal("expected additional model request fields for thinking config") + } + + input := model.buildInput(req) + if input.ModelId == nil || *input.ModelId != "anthropic.test" { + t.Fatalf("unexpected converse input model %#v", input.ModelId) + } + streamInput := model.buildStreamInput(req) + if streamInput.ModelId == nil || *streamInput.ModelId != "anthropic.test" { + t.Fatalf("unexpected converse stream input model %#v", streamInput.ModelId) + } +} + +func TestConvertSystemBlocksAndMessage(t *testing.T) { + if got := convertSystemBlocks(core.Message{}); got != nil { + t.Fatalf("expected nil system blocks for empty message, got %#v", got) + } + + converted := convertMessage(core.Message{ + Role: core.RoleAssistant, + Content: []core.Part{ + {Type: core.ContentText, Text: "hello"}, + {Type: core.ContentToolUse, ToolUse: &core.ToolUse{ID: "call_1", Name: "lookup", Input: map[string]interface{}{"city": "Lima"}}}, + {Type: core.ContentToolResult, ToolResult: &core.ToolResult{ToolUseID: "call_1", Content: `{"temp":72}`, IsError: true}}, + agentic.ImageDataPart([]byte("img"), "image/png"), + {Type: core.ContentThinking, Thinking: &core.ThinkingBlock{Text: "reasoning", Signature: "sig"}}, + {Type: core.ContentAudioURL, AudioURL: &core.AudioURL{URL: "https://example.com/audio.mp3"}}, + {Type: core.ContentCachePoint, CachePoint: &core.CachePoint{}}, + }, + }) + if converted == nil || len(converted.Content) != 5 { + t.Fatalf("expected 5 supported content blocks, got %#v", converted) + } + + if got := convertMessage(core.Message{ + Role: core.RoleUser, + Content: []core.Part{ + {Type: core.ContentAudioURL, AudioURL: &core.AudioURL{URL: "https://example.com/audio.mp3"}}, + {Type: core.ContentVideoURL, VideoURL: &core.VideoURL{URL: "https://example.com/video.mp4"}}, + }, + }); got != nil { + t.Fatalf("expected unsupported-only message to be skipped, got %#v", got) + } +} + +func TestBedrockRequestStreamEmitsEventsAndUsage(t *testing.T) { + stream := newMockConverseStream( + &types.ConverseStreamOutputMemberContentBlockStart{ + Value: types.ContentBlockStartEvent{ + ContentBlockIndex: aws.Int32(0), + Start: &types.ContentBlockStartMemberToolUse{ + Value: types.ToolUseBlockStart{ + ToolUseId: aws.String("call_1"), + Name: aws.String("lookup"), + }, + }, + }, + }, + &types.ConverseStreamOutputMemberContentBlockDelta{ + Value: types.ContentBlockDeltaEvent{ + ContentBlockIndex: aws.Int32(0), + Delta: &types.ContentBlockDeltaMemberToolUse{ + Value: types.ToolUseBlockDelta{Input: aws.String(`{"city":"Lima"}`)}, + }, + }, + }, + &types.ConverseStreamOutputMemberContentBlockDelta{ + Value: types.ContentBlockDeltaEvent{ + ContentBlockIndex: aws.Int32(1), + Delta: &types.ContentBlockDeltaMemberReasoningContent{ + Value: &types.ReasoningContentBlockDeltaMemberText{Value: "thinking"}, + }, + }, + }, + &types.ConverseStreamOutputMemberContentBlockDelta{ + Value: types.ContentBlockDeltaEvent{ + ContentBlockIndex: aws.Int32(2), + Delta: &types.ContentBlockDeltaMemberText{Value: "answer"}, + }, + }, + &types.UnknownUnionMember{Tag: "ignored"}, + &types.ConverseStreamOutputMemberMetadata{ + Value: types.ConverseStreamMetadataEvent{ + Usage: &types.TokenUsage{ + InputTokens: aws.Int32(10), + OutputTokens: aws.Int32(5), + TotalTokens: aws.Int32(15), + }, + }, + }, + &types.ConverseStreamOutputMemberMessageStop{ + Value: types.MessageStopEvent{StopReason: types.StopReasonEndTurn}, + }, + ) + mock := &mockRuntimeClient{stream: stream} + model := &Model{client: mock, modelID: "anthropic.test"} + + result, err := model.RequestStream(context.Background(), &core.ChatRequest{ + Model: "anthropic.test", + Messages: []core.Message{ + core.NewTextMessage(core.RoleUser, "hello"), + }, + }) + if err != nil { + t.Fatalf("RequestStream: %v", err) + } + + var events []core.StreamEvent + for event := range result.Events { + events = append(events, event) + } + + if mock.streamInput == nil || aws.ToString(mock.streamInput.ModelId) != "anthropic.test" { + t.Fatalf("unexpected stream input %#v", mock.streamInput) + } + if !stream.closed { + t.Fatal("expected stream to be closed by RequestStream") + } + if len(events) != 5 { + t.Fatalf("expected 5 emitted events, got %#v", events) + } + if events[0].Type != core.StreamEventToolCallStart || events[0].ToolUse == nil || events[0].ToolUse.Name != "lookup" { + t.Fatalf("unexpected first event %#v", events[0]) + } + if events[1].Type != core.StreamEventToolCallDelta || events[1].ToolCallID != "call_1" { + t.Fatalf("unexpected second event %#v", events[1]) + } + if events[2].Type != core.StreamEventThinkingDelta || events[2].Delta != "thinking" { + t.Fatalf("unexpected thinking event %#v", events[2]) + } + if events[3].Type != core.StreamEventTextDelta || events[3].Delta != "answer" { + t.Fatalf("unexpected text event %#v", events[3]) + } + if events[4].Type != core.StreamEventDone || events[4].Usage == nil || events[4].Usage.TotalTokens != 15 { + t.Fatalf("unexpected done event %#v", events[4]) + } +} + +func TestBedrockRequestStreamErrors(t *testing.T) { + t.Run("connect error", func(t *testing.T) { + model := &Model{ + client: &mockRuntimeClient{streamErr: errors.New("stream failed")}, + modelID: "anthropic.test", + } + + _, err := model.RequestStream(context.Background(), &core.ChatRequest{ + Model: "anthropic.test", + Messages: []core.Message{ + core.NewTextMessage(core.RoleUser, "hello"), + }, + }) + if err == nil || err.Error() != "bedrock: stream failed" { + t.Fatalf("expected wrapped stream creation error, got %v", err) + } + }) + + t.Run("reader error", func(t *testing.T) { + stream := newMockConverseStream() + stream.err = errors.New("reader failed") + model := &Model{ + client: &mockRuntimeClient{stream: stream}, + modelID: "anthropic.test", + } + + result, err := model.RequestStream(context.Background(), &core.ChatRequest{ + Model: "anthropic.test", + Messages: []core.Message{ + core.NewTextMessage(core.RoleUser, "hello"), + }, + }) + if err != nil { + t.Fatalf("RequestStream: %v", err) + } + + if waitErr := result.Wait(); waitErr == nil || waitErr.Error() != "bedrock stream: reader failed" { + t.Fatalf("expected reader error from stream, got %v", waitErr) + } + if !stream.closed { + t.Fatal("expected errored stream to be closed") + } + }) +} + +func TestConvertToolConfigChoices(t *testing.T) { + tools := []core.Tool{{ + Type: core.ToolTypeFunction, + Function: core.Function{ + Name: "lookup", + Description: "Lookup weather", + Parameters: map[string]interface{}{"type": "object"}, + }, + }} + + none := core.ToolChoiceNone + if got := convertToolConfig(tools, &none); got != nil { + t.Fatalf("expected nil tool config for none choice, got %#v", got) + } + + required := core.ToolChoiceRequired + if got := convertToolConfig(tools, &required); got == nil || got.ToolChoice == nil { + t.Fatalf("expected required tool config, got %#v", got) + } + + auto := core.ToolChoiceAuto + if got := convertToolConfig(tools, &auto); got == nil || got.ToolChoice == nil { + t.Fatalf("expected auto tool config, got %#v", got) + } +} + +func TestBedrockConvertResponseAndUsage(t *testing.T) { + model := &Model{modelID: "anthropic.test"} + resp := &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{Value: "answer"}, + &types.ContentBlockMemberToolUse{ + Value: types.ToolUseBlock{ + ToolUseId: aws.String("call_1"), + Name: aws.String("lookup"), + Input: document.NewLazyDocument(map[string]any{"city": "Lima"}), + }, + }, + &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberReasoningText{ + Value: types.ReasoningTextBlock{ + Text: aws.String("reasoning"), + Signature: aws.String("sig"), + }, + }, + }, + }, + }, + }, + StopReason: types.StopReasonToolUse, + Usage: &types.TokenUsage{ + InputTokens: aws.Int32(100), + OutputTokens: aws.Int32(40), + TotalTokens: aws.Int32(140), + CacheReadInputTokens: aws.Int32(8), + CacheWriteInputTokens: aws.Int32(5), + }, + } + + chatResp := model.convertResponse(resp) + if chatResp.Model != "anthropic.test" { + t.Fatalf("unexpected model %q", chatResp.Model) + } + if chatResp.Choices[0].FinishReason != core.FinishReasonToolCalls { + t.Fatalf("expected tool-calls finish reason, got %q", chatResp.Choices[0].FinishReason) + } + msg := chatResp.Choices[0].Message + if msg.GetTextContent() != "answer" { + t.Fatalf("unexpected text content %q", msg.GetTextContent()) + } + if msg.GetThinkingContent() != "reasoning" { + t.Fatalf("unexpected thinking content %q", msg.GetThinkingContent()) + } + if len(msg.GetToolUses()) != 1 || msg.GetToolUses()[0].Name != "lookup" { + t.Fatalf("unexpected tool uses %#v", msg.GetToolUses()) + } + if chatResp.Usage.CacheReadTokens != 8 || chatResp.Usage.CacheCreationTokens != 5 { + t.Fatalf("unexpected usage %#v", chatResp.Usage) + } +} + +func TestConvertStopReasonDefault(t *testing.T) { + if got := convertStopReason(types.StopReason("other")); got != core.FinishReasonStop { + t.Fatalf("expected default stop finish reason, got %q", got) + } +} diff --git a/provider/gemini/gemini_internal_test.go b/provider/gemini/gemini_internal_test.go new file mode 100644 index 0000000..dd6e35a --- /dev/null +++ b/provider/gemini/gemini_internal_test.go @@ -0,0 +1,206 @@ +package gemini + +import ( + "context" + "testing" + + "google.golang.org/genai" + + agentic "github.com/regularkevvv/agentic-go" + "github.com/regularkevvv/agentic-go/internal/core" +) + +func TestWithVertexAIOption(t *testing.T) { + cfg := &config{} + WithVertexAI("project-1", "us-west1")(cfg) + if !cfg.vertexAI || cfg.project != "project-1" || cfg.location != "us-west1" { + t.Fatalf("unexpected vertex config %#v", cfg) + } +} + +func TestGeminiRequestValidationErrors(t *testing.T) { + model := &Model{model: "gemini-2.5-pro"} + + if _, err := model.Request(context.Background(), &core.ChatRequest{}); err == nil { + t.Fatal("expected Request to fail validation") + } + if _, err := model.RequestStream(context.Background(), &core.ChatRequest{}); err == nil { + t.Fatal("expected RequestStream to fail validation") + } +} + +func TestBuildRequestCoversConversions(t *testing.T) { + model := &Model{model: "gemini-2.5-pro"} + temperature := 0.6 + maxTokens := 128 + topP := 0.9 + frequencyPenalty := 0.1 + presencePenalty := 0.2 + toolChoice := core.ToolChoiceRequired + + contents, cfg := model.buildRequest(&core.ChatRequest{ + Model: "gemini-2.5-pro", + Messages: []core.Message{ + core.NewTextMessage(core.RoleSystem, "system"), + { + Role: core.RoleUser, + Content: []core.Part{ + {Type: core.ContentText, Text: "hello"}, + {Type: core.ContentImageURL, ImageURL: &core.ImageURL{URL: "https://example.com/image.jpg"}}, + agentic.ImageDataPart([]byte("img"), "image/png"), + {Type: core.ContentAudioURL, AudioURL: &core.AudioURL{URL: "https://example.com/audio.mp3", Format: "mp3"}}, + {Type: core.ContentVideoURL, VideoURL: &core.VideoURL{URL: "https://example.com/video.mp4"}}, + {Type: core.ContentDocumentURL, DocumentURL: &core.DocumentURL{URL: "https://example.com/file.pdf"}}, + {Type: core.ContentUploadedFile, UploadedFile: &core.UploadedFile{FileID: "file_123"}}, + {Type: core.ContentCachePoint, CachePoint: &core.CachePoint{}}, + }, + }, + { + Role: core.RoleAssistant, + Content: []core.Part{ + {Type: core.ContentText, Text: "working"}, + {Type: core.ContentToolUse, ToolUse: &core.ToolUse{ID: "call_1", Name: "lookup", Input: map[string]interface{}{"city": "Lima"}}}, + {Type: core.ContentThinking, Thinking: &core.ThinkingBlock{Text: "reasoning"}}, + }, + }, + core.NewToolResultMessage("call_1", `{"temp":72}`, false), + }, + Temperature: &temperature, + MaxTokens: &maxTokens, + TopP: &topP, + FrequencyPenalty: &frequencyPenalty, + PresencePenalty: &presencePenalty, + Tools: []core.Tool{{ + Type: core.ToolTypeFunction, + Function: core.Function{ + Name: "lookup", + Description: "Lookup weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }}, + ToolChoice: &toolChoice, + ResponseFormat: &core.ResponseFormat{ + Type: "json_schema", + JSONSchema: &core.JSONSchemaFormat{ + Name: "weather", + Schema: map[string]interface{}{"type": "object"}, + }, + }, + Thinking: &core.ThinkingConfig{Enabled: true, BudgetTokens: 123}, + }) + + if len(contents) != 3 { + t.Fatalf("expected 3 converted messages, got %d", len(contents)) + } + if cfg.SystemInstruction == nil || cfg.SystemInstruction.Parts[0].Text != "system" { + t.Fatalf("expected system instruction, got %#v", cfg.SystemInstruction) + } + if cfg.Temperature == nil || *cfg.Temperature != float32(temperature) { + t.Fatalf("unexpected temperature %#v", cfg.Temperature) + } + if cfg.MaxOutputTokens != int32(maxTokens) { + t.Fatalf("expected max output tokens %d, got %d", maxTokens, cfg.MaxOutputTokens) + } + if cfg.TopP == nil || *cfg.TopP != float32(topP) { + t.Fatalf("unexpected top_p %#v", cfg.TopP) + } + if cfg.FrequencyPenalty == nil || *cfg.FrequencyPenalty != float32(frequencyPenalty) { + t.Fatalf("unexpected frequency penalty %#v", cfg.FrequencyPenalty) + } + if cfg.PresencePenalty == nil || *cfg.PresencePenalty != float32(presencePenalty) { + t.Fatalf("unexpected presence penalty %#v", cfg.PresencePenalty) + } + if len(cfg.Tools) != 1 || len(cfg.Tools[0].FunctionDeclarations) != 1 { + t.Fatalf("expected one function declaration, got %#v", cfg.Tools) + } + if cfg.ToolConfig == nil || cfg.ToolConfig.FunctionCallingConfig.Mode != genai.FunctionCallingConfigModeAny { + t.Fatalf("unexpected tool config %#v", cfg.ToolConfig) + } + if cfg.ResponseMIMEType != "application/json" || cfg.ResponseSchema == nil { + t.Fatalf("expected JSON response schema, got %#v", cfg) + } + if cfg.ThinkingConfig == nil || !cfg.ThinkingConfig.IncludeThoughts || cfg.ThinkingConfig.ThinkingBudget == nil || *cfg.ThinkingConfig.ThinkingBudget != 123 { + t.Fatalf("unexpected thinking config %#v", cfg.ThinkingConfig) + } +} + +func TestConvertMessageFallbackAndToolConfig(t *testing.T) { + fallback := convertMessage(core.Message{ + Role: core.RoleUser, + Content: []core.Part{ + {Type: core.ContentCachePoint, CachePoint: &core.CachePoint{}}, + {Type: core.ContentUploadedFile, UploadedFile: &core.UploadedFile{FileID: "file_123"}}, + {Type: core.ContentImageData, ImageData: &core.ImageData{Data: "!!!", MediaType: "image/png"}}, + }, + }) + if len(fallback.Parts) != 1 || fallback.Parts[0].Text != "" { + t.Fatalf("expected empty text fallback, got %#v", fallback.Parts) + } + + if got := convertToolConfig(core.ToolChoiceNone); got.FunctionCallingConfig.Mode != genai.FunctionCallingConfigModeNone { + t.Fatalf("expected none tool config, got %#v", got) + } + if got := convertToolConfig(core.ToolChoiceAuto); got.FunctionCallingConfig.Mode != genai.FunctionCallingConfigModeAuto { + t.Fatalf("expected auto tool config, got %#v", got) + } +} + +func TestGeminiConvertResponseAndUsage(t *testing.T) { + model := &Model{model: "gemini-2.5-pro"} + resp := &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{ + Content: &genai.Content{ + Role: "model", + Parts: []*genai.Part{ + genai.NewPartFromText("answer"), + {Thought: true, Text: "thinking"}, + genai.NewPartFromFunctionCall("lookup", map[string]any{"city": "Lima"}), + }, + }, + FinishReason: genai.FinishReasonMaxTokens, + }}, + UsageMetadata: &genai.GenerateContentResponseUsageMetadata{ + PromptTokenCount: 100, + CandidatesTokenCount: 40, + CachedContentTokenCount: 8, + }, + } + + chatResp := model.convertResponse(resp) + if chatResp.Model != "gemini-2.5-pro" { + t.Fatalf("unexpected model %q", chatResp.Model) + } + if chatResp.Choices[0].FinishReason != core.FinishReasonLength { + t.Fatalf("expected length finish reason, got %q", chatResp.Choices[0].FinishReason) + } + if chatResp.Usage.CacheReadTokens != 8 || chatResp.Usage.TotalTokens != 140 { + t.Fatalf("unexpected usage %#v", chatResp.Usage) + } + msg := chatResp.Choices[0].Message + if msg.GetTextContent() != "answer" { + t.Fatalf("unexpected text content %q", msg.GetTextContent()) + } + if msg.GetThinkingContent() != "thinking" { + t.Fatalf("unexpected thinking content %q", msg.GetThinkingContent()) + } + if len(msg.GetToolUses()) != 1 || msg.GetToolUses()[0].Name != "lookup" { + t.Fatalf("unexpected tool uses %#v", msg.GetToolUses()) + } +} + +func TestGeminiFinishReasonMapping(t *testing.T) { + if got := convertFinishReason(genai.FinishReasonStop); got != core.FinishReasonStop { + t.Fatalf("expected stop finish reason, got %q", got) + } + if got := convertFinishReason(genai.FinishReasonSafety); got != core.FinishReasonContentFilter { + t.Fatalf("expected content filter finish reason, got %q", got) + } + if got := convertFinishReason(genai.FinishReason("other")); got != core.FinishReasonStop { + t.Fatalf("expected default stop finish reason, got %q", got) + } +} diff --git a/provider/gemini/gemini_transport_test.go b/provider/gemini/gemini_transport_test.go new file mode 100644 index 0000000..3638305 --- /dev/null +++ b/provider/gemini/gemini_transport_test.go @@ -0,0 +1,137 @@ +package gemini + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "google.golang.org/genai" + + "github.com/regularkevvv/agentic-go/internal/core" +) + +func TestGeminiNewRequestAndStreamWithLocalServer(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, ":generateContent"): + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "candidates":[ + { + "content":{"role":"model","parts":[{"text":"hello from gemini"}]}, + "finishReason":"STOP" + } + ], + "usageMetadata":{"promptTokenCount":3,"candidatesTokenCount":2} + }`) + case strings.HasSuffix(r.URL.Path, ":streamGenerateContent"): + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, strings.Join([]string{ + `data:{"candidates":[{"content":{"role":"model","parts":[{"text":"hello "},{"thought":true,"text":"think"},{"functionCall":{"name":"lookup","args":{"city":"Lima"}}}]}}],"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":4,"cachedContentTokenCount":1}}`, + ``, + }, "\n")) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + t.Setenv("GOOGLE_GEMINI_BASE_URL", server.URL) + + model, err := New("gemini-2.5-pro", WithAPIKey("test-key")) + if err != nil { + t.Fatalf("New: %v", err) + } + + req := &core.ChatRequest{ + Model: "gemini-2.5-pro", + Messages: []core.Message{core.NewTextMessage(core.RoleUser, "hello")}, + } + + resp, err := model.Request(context.Background(), req) + if err != nil { + t.Fatalf("Request: %v", err) + } + if resp.Choices[0].Message.GetTextContent() != "hello from gemini" { + t.Fatalf("unexpected request output %q", resp.Choices[0].Message.GetTextContent()) + } + if resp.Usage.TotalTokens != 5 { + t.Fatalf("unexpected request usage %#v", resp.Usage) + } + + stream, err := model.RequestStream(context.Background(), req) + if err != nil { + t.Fatalf("RequestStream: %v", err) + } + + var events []core.StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(events) != 5 { + t.Fatalf("expected 5 stream events, got %d: %#v", len(events), events) + } + if events[0].Type != core.StreamEventTextDelta || events[0].Delta != "hello " { + t.Fatalf("unexpected text delta %#v", events[0]) + } + if events[1].Type != core.StreamEventThinkingDelta || events[1].Delta != "think" { + t.Fatalf("unexpected thinking delta %#v", events[1]) + } + if events[2].Type != core.StreamEventToolCallStart || events[2].ToolUse == nil || events[2].ToolUse.Name != "lookup" { + t.Fatalf("unexpected tool-call start %#v", events[2]) + } + if events[3].Type != core.StreamEventToolCallDelta || events[3].ToolCallID == "" || events[3].Delta != `{"city":"Lima"}` { + t.Fatalf("unexpected tool-call delta %#v", events[3]) + } + if events[4].Type != core.StreamEventDone || events[4].Usage == nil || events[4].Usage.TotalTokens != 11 { + t.Fatalf("unexpected done event %#v", events[4]) + } +} + +func TestGeminiNewVertexAIWithCustomBaseURL(t *testing.T) { + server := httptest.NewServer(http.NotFoundHandler()) + defer server.Close() + + t.Setenv("GOOGLE_VERTEX_BASE_URL", server.URL) + + model, err := New("gemini-2.5-pro", WithVertexAI("", "")) + if err != nil { + t.Fatalf("expected Vertex AI client creation to succeed with custom base URL, got %v", err) + } + if model.Name() != "gemini-2.5-pro" { + t.Fatalf("unexpected model name %q", model.Name()) + } +} + +func TestGeminiSchemaAndCandidateHelpersCoverRemainingBranches(t *testing.T) { + schema := convertSchema(map[string]any{ + "type": "array", + "description": "outer", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "kind": map[string]any{ + "type": "string", + "enum": []any{"a", "b"}, + }, + }, + "required": []any{"kind"}, + }, + }) + + if schema.Items == nil || len(schema.Items.Required) != 1 { + t.Fatalf("unexpected converted schema %#v", schema) + } + if len(schema.Items.Properties["kind"].Enum) != 2 { + t.Fatalf("expected enum values on nested property, got %#v", schema.Items.Properties["kind"]) + } + + msg := convertCandidateMessage(&genai.Candidate{}) + if msg.Role != core.RoleAssistant || len(msg.Content) != 0 { + t.Fatalf("expected empty assistant message for nil content, got %#v", msg) + } +} diff --git a/provider/openai/openai_internal_test.go b/provider/openai/openai_internal_test.go new file mode 100644 index 0000000..c62b5c0 --- /dev/null +++ b/provider/openai/openai_internal_test.go @@ -0,0 +1,208 @@ +package openai + +import ( + "context" + "testing" + + sdkopenai "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/shared" + + "github.com/regularkevvv/agentic-go/internal/core" +) + +func TestWithRequestOptions(t *testing.T) { + cfg := &config{} + WithRequestOptions(option.WithHeader("X-Test", "1"))(cfg) + if len(cfg.extraOpts) != 1 { + t.Fatalf("expected 1 extra option, got %d", len(cfg.extraOpts)) + } +} + +func TestOpenAIRequestValidationErrors(t *testing.T) { + model := &Model{model: "gpt-4o"} + + if _, err := model.Request(context.Background(), &core.ChatRequest{}); err == nil { + t.Fatal("expected Request to fail validation") + } + if _, err := model.RequestStream(context.Background(), &core.ChatRequest{}); err == nil { + t.Fatal("expected RequestStream to fail validation") + } +} + +func TestBuildParamsAppliesOptionalFields(t *testing.T) { + model := &Model{model: "gpt-4o"} + temperature := 0.6 + maxTokens := 128 + topP := 0.9 + frequencyPenalty := 0.1 + presencePenalty := 0.2 + toolChoice := core.ToolChoiceRequired + strict := true + + params := model.buildParams(&core.ChatRequest{ + Model: "gpt-4o", + Messages: []core.Message{ + core.NewTextMessage(core.RoleSystem, "system"), + core.NewTextMessage(core.RoleUser, "hello"), + }, + Temperature: &temperature, + MaxTokens: &maxTokens, + TopP: &topP, + FrequencyPenalty: &frequencyPenalty, + PresencePenalty: &presencePenalty, + Tools: []core.Tool{{ + Type: core.ToolTypeFunction, + Function: core.Function{ + Name: "lookup_weather", + Description: "Lookup weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }}, + ToolChoice: &toolChoice, + ResponseFormat: &core.ResponseFormat{ + Type: "json_schema", + JSONSchema: &core.JSONSchemaFormat{ + Name: "weather", + Description: "Weather response", + Schema: map[string]interface{}{"type": "object"}, + Strict: &strict, + }, + }, + Thinking: &core.ThinkingConfig{Enabled: true, BudgetTokens: 3000}, + }) + + if params.Model != shared.ChatModel("gpt-4o") { + t.Fatalf("unexpected model %q", params.Model) + } + if got := params.Temperature.Or(0); got != temperature { + t.Fatalf("expected temperature %.1f, got %.1f", temperature, got) + } + if got := params.MaxCompletionTokens.Or(0); got != int64(maxTokens) { + t.Fatalf("expected max tokens %d, got %d", maxTokens, got) + } + if got := params.TopP.Or(0); got != topP { + t.Fatalf("expected top_p %.1f, got %.1f", topP, got) + } + if got := params.FrequencyPenalty.Or(0); got != frequencyPenalty { + t.Fatalf("expected frequency penalty %.1f, got %.1f", frequencyPenalty, got) + } + if got := params.PresencePenalty.Or(0); got != presencePenalty { + t.Fatalf("expected presence penalty %.1f, got %.1f", presencePenalty, got) + } + if len(params.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(params.Tools)) + } + if params.ToolChoice.OfAuto.Or("") != string(sdkopenai.ChatCompletionToolChoiceOptionAutoRequired) { + t.Fatalf("unexpected tool choice %#v", params.ToolChoice) + } + if params.ResponseFormat.OfJSONSchema == nil { + t.Fatalf("expected JSON schema response format") + } + if params.ReasoningEffort != shared.ReasoningEffortLow { + t.Fatalf("expected low reasoning effort, got %q", params.ReasoningEffort) + } +} + +func TestConvertResponseFormat(t *testing.T) { + t.Run("json_object", func(t *testing.T) { + got := convertResponseFormat(&core.ResponseFormat{Type: "json_object"}) + if got.OfJSONObject == nil { + t.Fatal("expected json object response format") + } + }) + + t.Run("json_schema", func(t *testing.T) { + strict := true + got := convertResponseFormat(&core.ResponseFormat{ + Type: "json_schema", + JSONSchema: &core.JSONSchemaFormat{ + Name: "output", + Description: "Structured output", + Schema: map[string]interface{}{"type": "object"}, + Strict: &strict, + }, + }) + if got.OfJSONSchema == nil { + t.Fatal("expected json schema response format") + } + }) + + t.Run("json_schema_without_schema_falls_back", func(t *testing.T) { + got := convertResponseFormat(&core.ResponseFormat{Type: "json_schema"}) + if got.OfJSONObject == nil { + t.Fatal("expected fallback to json object response format") + } + }) + + t.Run("default_text", func(t *testing.T) { + got := convertResponseFormat(&core.ResponseFormat{Type: "text"}) + if got.OfText == nil { + t.Fatal("expected text response format") + } + }) +} + +func TestConvertContentPartAdditionalModalities(t *testing.T) { + imageData := convertContentPart(core.Part{ + Type: core.ContentImageData, + ImageData: &core.ImageData{ + Data: "AQID", + MediaType: "image/png", + VendorMetadata: map[string]interface{}{"detail": "high"}, + }, + }) + if imageData.OfImageURL == nil { + t.Fatal("expected image data to convert to image URL part") + } + + audio := convertContentPart(core.Part{ + Type: core.ContentAudioURL, + AudioURL: &core.AudioURL{URL: "https://example.com/audio.mp3", Format: "mp3"}, + }) + if audio.OfInputAudio == nil { + t.Fatal("expected audio URL to convert to input audio part") + } + + file := convertContentPart(core.Part{ + Type: core.ContentUploadedFile, + UploadedFile: &core.UploadedFile{FileID: "file_123"}, + }) + if file.OfFile == nil { + t.Fatal("expected uploaded file to convert to file part") + } + + cachePoint := convertContentPart(core.Part{Type: core.ContentCachePoint}) + if cachePoint.OfText == nil { + t.Fatal("expected cache point to be skipped as an empty text part") + } +} + +func TestExtractOpenAIUsageIncludesReasoningAndCache(t *testing.T) { + usage := extractOpenAIUsage(sdkopenai.CompletionUsage{ + PromptTokens: 100, + CompletionTokens: 40, + TotalTokens: 140, + CompletionTokensDetails: sdkopenai.CompletionUsageCompletionTokensDetails{ + ReasoningTokens: 12, + }, + PromptTokensDetails: sdkopenai.CompletionUsagePromptTokensDetails{ + CachedTokens: 8, + }, + }) + + if usage.PromptTokens != 100 || usage.CompletionTokens != 40 || usage.TotalTokens != 140 { + t.Fatalf("unexpected usage totals: %#v", usage) + } + if usage.ReasoningTokens != 12 { + t.Fatalf("expected reasoning tokens, got %#v", usage) + } + if usage.CacheReadTokens != 8 { + t.Fatalf("expected cache read tokens, got %#v", usage) + } +} diff --git a/provider/openai/openai_transport_test.go b/provider/openai/openai_transport_test.go new file mode 100644 index 0000000..ae630ec --- /dev/null +++ b/provider/openai/openai_transport_test.go @@ -0,0 +1,273 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/openai/openai-go/option" + sdkopenairesponses "github.com/openai/openai-go/responses" + + "github.com/regularkevvv/agentic-go/internal/core" +) + +func TestOpenAIChatRequestAndStreamWithLocalServer(t *testing.T) { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + http.NotFound(w, r) + return + } + + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request body: %v", err) + } + + stream, _ := body["stream"].(bool) + if !stream { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "id":"chatcmpl_req", + "object":"chat.completion", + "created":123, + "model":"gpt-4o", + "choices":[{"index":0,"message":{"role":"assistant","content":"hello from chat"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5} + }`) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, strings.Join([]string{ + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":123,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"hello "}}]}`, + ``, + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":123,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"lookup","arguments":"{\"city\":\""}}]}}]}`, + ``, + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":123,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Lima\"}"}}]}}]}`, + ``, + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":123,"model":"gpt-4o","choices":[],"usage":{"prompt_tokens":7,"completion_tokens":4,"total_tokens":11}}`, + ``, + `data: [DONE]`, + ``, + }, "\n")) + })) + defer server.Close() + + model, err := New( + "gpt-4o", + WithAPIKey("test-key"), + WithBaseURL(server.URL+"/v1"), + WithRequestOptions(option.WithHTTPClient(server.Client())), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + + req := &core.ChatRequest{ + Model: "gpt-4o", + Messages: []core.Message{core.NewTextMessage(core.RoleUser, "hello")}, + } + + resp, err := model.Request(context.Background(), req) + if err != nil { + t.Fatalf("Request: %v", err) + } + if resp.Choices[0].Message.GetTextContent() != "hello from chat" { + t.Fatalf("unexpected request output %q", resp.Choices[0].Message.GetTextContent()) + } + if resp.Usage.TotalTokens != 5 { + t.Fatalf("unexpected request usage %#v", resp.Usage) + } + + stream, err := model.RequestStream(context.Background(), req) + if err != nil { + t.Fatalf("RequestStream: %v", err) + } + + var events []core.StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(events) != 5 { + t.Fatalf("expected 5 stream events, got %d: %#v", len(events), events) + } + if events[0].Type != core.StreamEventTextDelta || events[0].Delta != "hello " { + t.Fatalf("unexpected text delta %#v", events[0]) + } + if events[1].Type != core.StreamEventToolCallStart || events[1].ToolUse == nil || events[1].ToolUse.Name != "lookup" { + t.Fatalf("unexpected tool-call start %#v", events[1]) + } + if events[2].Type != core.StreamEventToolCallDelta || events[2].ToolCallID != "call_1" || events[2].Delta != `{"city":"` { + t.Fatalf("unexpected first tool-call delta %#v", events[2]) + } + if events[3].Type != core.StreamEventToolCallDelta || events[3].ToolCallID != "call_1" || events[3].Delta != `Lima"}` { + t.Fatalf("unexpected second tool-call delta %#v", events[3]) + } + if events[4].Type != core.StreamEventDone || events[4].Usage == nil || events[4].Usage.TotalTokens != 11 { + t.Fatalf("unexpected done event %#v", events[4]) + } +} + +func TestOpenAIResponsesRequestAndStreamWithLocalServer(t *testing.T) { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/responses" { + http.NotFound(w, r) + return + } + + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request body: %v", err) + } + + stream, _ := body["stream"].(bool) + if !stream { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "id":"resp_req", + "object":"response", + "created_at":123, + "status":"completed", + "model":"gpt-4.1", + "output":[ + { + "id":"msg_1", + "type":"message", + "role":"assistant", + "status":"completed", + "content":[{"type":"output_text","text":"hello from responses"}] + } + ], + "usage":{"input_tokens":5,"output_tokens":4,"total_tokens":9} + }`) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, strings.Join([]string{ + `data: {"type":"response.output_item.added","sequence_number":1,"output_index":0,"item":{"id":"fc_1","type":"function_call","call_id":"call_1","name":"lookup","arguments":"{}","status":"in_progress"}}`, + ``, + `data: {"type":"response.function_call_arguments.delta","sequence_number":2,"output_index":0,"item_id":"fc_1","delta":"{\"city\":\"Lima\"}"}`, + ``, + `data: {"type":"response.output_text.delta","sequence_number":3,"output_index":1,"item_id":"msg_1","content_index":0,"delta":"hello "}`, + ``, + `data: {"type":"response.reasoning_summary_text.delta","sequence_number":4,"output_index":2,"item_id":"rs_1","summary_index":0,"text":"think"}`, + ``, + `data: {"type":"response.refusal.delta","sequence_number":5,"output_index":1,"item_id":"msg_1","content_index":1,"refusal":"no"}`, + ``, + `data: {"type":"response.completed","sequence_number":6,"response":{"id":"resp_stream","object":"response","created_at":123,"status":"completed","model":"gpt-4.1","output":[],"usage":{"input_tokens":7,"output_tokens":4,"total_tokens":11}}}`, + ``, + `data: [DONE]`, + ``, + }, "\n")) + })) + defer server.Close() + + model, err := NewResponses( + "gpt-4.1", + WithAPIKey("test-key"), + WithBaseURL(server.URL+"/v1"), + WithRequestOptions(option.WithHTTPClient(server.Client())), + ) + if err != nil { + t.Fatalf("NewResponses: %v", err) + } + + req := &core.ChatRequest{ + Model: "gpt-4.1", + Messages: []core.Message{core.NewTextMessage(core.RoleUser, "hello")}, + } + + resp, err := model.Request(context.Background(), req) + if err != nil { + t.Fatalf("Request: %v", err) + } + if resp.Choices[0].Message.GetTextContent() != "hello from responses" { + t.Fatalf("unexpected request output %q", resp.Choices[0].Message.GetTextContent()) + } + if resp.Usage.TotalTokens != 9 { + t.Fatalf("unexpected request usage %#v", resp.Usage) + } + + stream, err := model.RequestStream(context.Background(), req) + if err != nil { + t.Fatalf("RequestStream: %v", err) + } + + var events []core.StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(events) != 6 { + t.Fatalf("expected 6 stream events, got %d: %#v", len(events), events) + } + if events[0].Type != core.StreamEventToolCallStart || events[0].ToolUse == nil || events[0].ToolUse.ID != "call_1" { + t.Fatalf("unexpected tool-call start %#v", events[0]) + } + if events[1].Type != core.StreamEventToolCallDelta || events[1].ToolCallID != "call_1" || events[1].Delta != `{"city":"Lima"}` { + t.Fatalf("unexpected tool-call delta %#v", events[1]) + } + if events[2].Type != core.StreamEventTextDelta || events[2].Delta != "hello " { + t.Fatalf("unexpected text delta %#v", events[2]) + } + if events[3].Type != core.StreamEventThinkingDelta || events[3].Delta != "think" { + t.Fatalf("unexpected thinking delta %#v", events[3]) + } + if events[4].Type != core.StreamEventTextDelta || events[4].Delta != "no" { + t.Fatalf("unexpected refusal delta %#v", events[4]) + } + if events[5].Type != core.StreamEventDone || events[5].Usage == nil || events[5].Usage.TotalTokens != 11 { + t.Fatalf("unexpected done event %#v", events[5]) + } +} + +func TestResponsesHelpersCoverRemainingBranches(t *testing.T) { + schema := ensureAdditionalPropertiesFalse(map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "tags": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{"type": "string"}, + }, + }, + }, + }, + }) + + if schema["additionalProperties"] != false { + t.Fatalf("expected top-level additionalProperties=false, got %#v", schema) + } + required, _ := schema["required"].([]string) + if len(required) != 2 { + t.Fatalf("expected both properties to become required, got %#v", schema["required"]) + } + tags := schema["properties"].(map[string]any)["tags"].(map[string]any) + items := tags["items"].(map[string]any) + if items["additionalProperties"] != false { + t.Fatalf("expected nested object additionalProperties=false, got %#v", items) + } + + if got := convertResponsesFinishReason(&sdkopenairesponses.Response{Status: "failed"}); got != core.FinishReasonStop { + t.Fatalf("expected failed status to map to stop, got %q", got) + } + if got := convertResponsesFinishReason(&sdkopenairesponses.Response{Status: "canceled"}); got != core.FinishReasonStop { + t.Fatalf("expected canceled status to map to stop, got %q", got) + } + if got := convertTextConfig(&core.ResponseFormat{Type: "json_schema"}); got.Format.OfText != nil || got.Format.OfJSONSchema != nil { + t.Fatalf("expected empty text config when schema details are missing, got %#v", got) + } +} diff --git a/provider/openai/responses_internal_test.go b/provider/openai/responses_internal_test.go new file mode 100644 index 0000000..a2427ce --- /dev/null +++ b/provider/openai/responses_internal_test.go @@ -0,0 +1,208 @@ +package openai + +import ( + "context" + "testing" + + "github.com/openai/openai-go/responses" + "github.com/openai/openai-go/shared" + + "github.com/regularkevvv/agentic-go/internal/core" +) + +func TestResponsesRequestValidationErrors(t *testing.T) { + model := &ResponsesModel{model: "gpt-4.1"} + + if _, err := model.Request(context.Background(), &core.ChatRequest{}); err == nil { + t.Fatal("expected Request to fail validation") + } + if _, err := model.RequestStream(context.Background(), &core.ChatRequest{}); err == nil { + t.Fatal("expected RequestStream to fail validation") + } +} + +func TestResponsesBuildParams(t *testing.T) { + model := &ResponsesModel{model: "gpt-4.1"} + temperature := 0.6 + maxTokens := 256 + topP := 0.8 + toolChoice := core.ToolChoiceRequired + strict := true + + params := model.buildParams(&core.ChatRequest{ + Model: "gpt-4.1", + Messages: []core.Message{ + core.NewTextMessage(core.RoleSystem, "system"), + { + Role: core.RoleUser, + Content: []core.Part{ + {Type: core.ContentText, Text: "hello"}, + {Type: core.ContentImageURL, ImageURL: &core.ImageURL{URL: "https://example.com/image.png", Detail: "high"}}, + }, + }, + { + Role: core.RoleAssistant, + Content: []core.Part{ + {Type: core.ContentText, Text: "working"}, + {Type: core.ContentToolUse, ToolUse: &core.ToolUse{ID: "call_1", Name: "lookup", Input: map[string]interface{}{"city": "Lima"}}}, + {Type: core.ContentThinking, Thinking: &core.ThinkingBlock{Signature: "enc"}}, + }, + }, + core.NewToolResultMessage("call_1", `{"temp":72}`, false), + }, + Temperature: &temperature, + MaxTokens: &maxTokens, + TopP: &topP, + Tools: []core.Tool{{ + Type: core.ToolTypeFunction, + Function: core.Function{ + Name: "lookup", + Description: "Lookup a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }}, + ToolChoice: &toolChoice, + ResponseFormat: &core.ResponseFormat{ + Type: "json_schema", + JSONSchema: &core.JSONSchemaFormat{ + Name: "city_lookup", + Schema: map[string]interface{}{"type": "object"}, + Strict: &strict, + }, + }, + Thinking: &core.ThinkingConfig{Enabled: true, BudgetTokens: 25000}, + }) + + if params.Instructions.Or("") != "system" { + t.Fatalf("expected system instructions, got %#v", params.Instructions) + } + if params.Temperature.Or(0) != temperature { + t.Fatalf("expected temperature %.1f, got %.1f", temperature, params.Temperature.Or(0)) + } + if params.MaxOutputTokens.Or(0) != int64(maxTokens) { + t.Fatalf("expected max output tokens %d, got %d", maxTokens, params.MaxOutputTokens.Or(0)) + } + if params.TopP.Or(0) != topP { + t.Fatalf("expected top_p %.1f, got %.1f", topP, params.TopP.Or(0)) + } + if len(params.Input.OfInputItemList) != 5 { + t.Fatalf("expected 5 input items, got %d", len(params.Input.OfInputItemList)) + } + if len(params.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(params.Tools)) + } + if params.ToolChoice.OfToolChoiceMode.Or("") != responses.ToolChoiceOptionsRequired { + t.Fatalf("unexpected tool choice %#v", params.ToolChoice) + } + if params.Text.Format.OfJSONSchema == nil { + t.Fatalf("expected JSON schema text format") + } + if params.Reasoning.Effort != shared.ReasoningEffortHigh { + t.Fatalf("expected high reasoning effort, got %q", params.Reasoning.Effort) + } +} + +func TestConvertInputContentMultipart(t *testing.T) { + content := convertInputContent(core.Message{ + Role: core.RoleUser, + Content: []core.Part{ + {Type: core.ContentText, Text: "hello"}, + {Type: core.ContentImageURL, ImageURL: &core.ImageURL{URL: "https://example.com/image.png"}}, + {Type: core.ContentImageData, ImageData: &core.ImageData{MediaType: "image/png", Data: "AQID"}}, + {Type: core.ContentAudioURL, AudioURL: &core.AudioURL{URL: "https://example.com/audio.mp3"}}, + {Type: core.ContentDocumentURL, DocumentURL: &core.DocumentURL{URL: "https://example.com/file.pdf"}}, + {Type: core.ContentUploadedFile, UploadedFile: &core.UploadedFile{FileID: "file_123"}}, + }, + }) + + if len(content.OfInputItemContentList) != 6 { + t.Fatalf("expected 6 content parts, got %d", len(content.OfInputItemContentList)) + } + if content.OfInputItemContentList[1].OfInputImage == nil { + t.Fatalf("expected image URL content part") + } + if content.OfInputItemContentList[2].OfInputImage == nil { + t.Fatalf("expected inline image content part") + } + if content.OfInputItemContentList[3].OfInputFile == nil || content.OfInputItemContentList[4].OfInputFile == nil || content.OfInputItemContentList[5].OfInputFile == nil { + t.Fatalf("expected file-based content parts") + } +} + +func TestResponsesConvertResponseAndFinishReason(t *testing.T) { + model := &ResponsesModel{model: "gpt-4.1"} + resp := &responses.Response{ + ID: "resp_123", + Model: shared.ResponsesModel("gpt-4.1"), + CreatedAt: 123, + Status: "completed", + Output: []responses.ResponseOutputItemUnion{ + { + Type: "message", + Content: []responses.ResponseOutputMessageContentUnion{ + {Type: "output_text", Text: "answer"}, + {Type: "refusal", Refusal: " but limited"}, + }, + }, + { + Type: "function_call", + CallID: "call_1", + Name: "lookup", + Arguments: `{"city":"Lima"}`, + }, + { + Type: "reasoning", + EncryptedContent: "enc", + Summary: []responses.ResponseReasoningItemSummary{{ + Text: "reasoning summary", + }}, + }, + }, + Usage: responses.ResponseUsage{ + InputTokens: 100, + OutputTokens: 40, + TotalTokens: 140, + }, + } + + chatResp := model.convertResponse(resp) + if chatResp.ID != "resp_123" || chatResp.Model != "gpt-4.1" { + t.Fatalf("unexpected response metadata: %#v", chatResp) + } + if chatResp.Choices[0].FinishReason != core.FinishReasonToolCalls { + t.Fatalf("expected tool-calls finish reason, got %q", chatResp.Choices[0].FinishReason) + } + msg := chatResp.Choices[0].Message + if msg.GetTextContent() != "answer but limited" { + t.Fatalf("unexpected text content %q", msg.GetTextContent()) + } + if len(msg.GetToolUses()) != 1 || msg.GetToolUses()[0].Name != "lookup" { + t.Fatalf("unexpected tool uses %#v", msg.GetToolUses()) + } + if msg.GetThinkingContent() != "reasoning summary" { + t.Fatalf("unexpected thinking content %q", msg.GetThinkingContent()) + } +} + +func TestConvertResponsesFinishReason(t *testing.T) { + if got := convertResponsesFinishReason(&responses.Response{ + IncompleteDetails: responses.ResponseIncompleteDetails{Reason: "max_output_tokens"}, + }); got != core.FinishReasonLength { + t.Fatalf("expected length finish reason, got %q", got) + } + + if got := convertResponsesFinishReason(&responses.Response{ + IncompleteDetails: responses.ResponseIncompleteDetails{Reason: "content_filter"}, + }); got != core.FinishReasonContentFilter { + t.Fatalf("expected content filter finish reason, got %q", got) + } + + if got := convertResponsesFinishReason(&responses.Response{Status: "completed"}); got != core.FinishReasonStop { + t.Fatalf("expected stop finish reason, got %q", got) + } +} diff --git a/result_internal_test.go b/result_internal_test.go new file mode 100644 index 0000000..31c34d5 --- /dev/null +++ b/result_internal_test.go @@ -0,0 +1,17 @@ +package agentic + +import "testing" + +func TestRunResultNewMessagesNilWhenHistoryConsumesAll(t *testing.T) { + result := &RunResult{ + Messages: []Message{ + NewTextMessage(RoleUser, "history"), + NewTextMessage(RoleAssistant, "answer"), + }, + historyLen: 2, + } + + if got := result.NewMessages(); got != nil { + t.Fatalf("expected nil new messages, got %#v", got) + } +} diff --git a/run_loop_internal_test.go b/run_loop_internal_test.go new file mode 100644 index 0000000..acd88ef --- /dev/null +++ b/run_loop_internal_test.go @@ -0,0 +1,153 @@ +package agentic + +import ( + "context" + "errors" + "testing" + + "github.com/regularkevvv/agentic-go/internal/testutil" +) + +func TestPrepareLoopAndBuildRequestAppliesOverrides(t *testing.T) { + type input struct { + Value int `json:"value"` + } + + toolA, handlerA := MustToolPlain("tool_a", "tool a", func(input input) (string, error) { + return "a", nil + }) + toolB, handlerB := MustToolPlain("tool_b", "tool b", func(input input) (string, error) { + return "b", nil + }) + + seenToolCount := 0 + limits := UsageLimits{MaxRequests: IntPtr(2)} + agent := NewAgent[any]( + "base system", + &testutil.StubModel{NameValue: "coverage-model"}, + WithUsageLimits[any](limits), + WithToolPrepare[any](func(ctx RunContext[any], tools []Tool) ([]Tool, error) { + seenToolCount = len(tools) + return tools[:1], nil + }), + ).AddTool(toolA, handlerA).AddTool(toolB, handlerB) + + agent.systemPromptSuffix = "schema suffix" + agent.responseFormat = &ResponseFormat{Type: "json_object"} + agent.config.thinking = &ThinkingConfig{Enabled: true, BudgetTokens: 9} + + ls, err := agent.prepareLoop( + context.Background(), + "current prompt", + nil, + WithMessages(NewTextMessage(RoleUser, "history")), + WithRunHistoryProcessor(HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + return messages[1:], nil + })), + WithRunTemperature(0.3), + WithRunMaxTokens(55), + WithRunMaxIterations(7), + WithRunEndStrategy(EndStrategyEarly), + ) + if err != nil { + t.Fatalf("prepareLoop: %v", err) + } + + if ls.maxIterations != 7 { + t.Fatalf("expected max iterations override, got %d", ls.maxIterations) + } + if ls.endStrategy != EndStrategyEarly { + t.Fatalf("expected end strategy override, got %v", ls.endStrategy) + } + if ls.usageLimits == nil || ls.usageLimits.MaxRequests == nil || *ls.usageLimits.MaxRequests != 2 { + t.Fatalf("expected inherited usage limits, got %#v", ls.usageLimits) + } + if len(ls.messages) != 3 { + t.Fatalf("expected system + history + prompt, got %#v", ls.messages) + } + if got := ls.messages[0].GetTextContent(); got != "base system\n\nschema suffix" { + t.Fatalf("unexpected system prompt %q", got) + } + + req, err := agent.buildRequest(ls, true) + if err != nil { + t.Fatalf("buildRequest: %v", err) + } + + if !req.Stream { + t.Fatal("expected streaming request") + } + if req.Temperature == nil || *req.Temperature != 0.3 { + t.Fatalf("expected run temperature override, got %#v", req.Temperature) + } + if req.MaxTokens == nil || *req.MaxTokens != 55 { + t.Fatalf("expected run max tokens override, got %#v", req.MaxTokens) + } + if len(req.Messages) != 2 { + t.Fatalf("expected history processor to trim one message, got %#v", req.Messages) + } + if req.Messages[0].GetTextContent() != "history" || req.Messages[1].GetTextContent() != "current prompt" { + t.Fatalf("unexpected processed messages %#v", req.Messages) + } + if seenToolCount != 2 { + t.Fatalf("expected tool prepare to see both tools, got %d", seenToolCount) + } + if len(req.Tools) != 1 { + t.Fatalf("expected tool prepare to filter tools, got %#v", req.Tools) + } + if req.ResponseFormat == nil || req.ResponseFormat.Type != "json_object" { + t.Fatalf("expected response format on request, got %#v", req.ResponseFormat) + } + if req.Thinking == nil || req.Thinking.BudgetTokens != 9 { + t.Fatalf("expected thinking config on request, got %#v", req.Thinking) + } +} + +func TestBuildRequestWrapsProcessorAndToolPrepareErrors(t *testing.T) { + t.Run("history processor error", func(t *testing.T) { + agent := NewAgent[any]("system", &testutil.StubModel{NameValue: "coverage-model"}) + ls, err := agent.prepareLoop( + context.Background(), + "prompt", + nil, + WithRunHistoryProcessor(HistoryProcessorFunc(func(ctx context.Context, messages []Message) ([]Message, error) { + return nil, errors.New("boom") + })), + ) + if err != nil { + t.Fatalf("prepareLoop: %v", err) + } + + _, err = agent.buildRequest(ls, false) + if err == nil || err.Error() != "history processor: boom" { + t.Fatalf("expected wrapped history processor error, got %v", err) + } + }) + + t.Run("tool prepare error", func(t *testing.T) { + type input struct { + Value int `json:"value"` + } + + toolDef, handler := MustToolPlain("tool_a", "tool a", func(input input) (string, error) { + return "ok", nil + }) + agent := NewAgent[any]( + "system", + &testutil.StubModel{NameValue: "coverage-model"}, + WithToolPrepare[any](func(ctx RunContext[any], tools []Tool) ([]Tool, error) { + return nil, errors.New("no tools today") + }), + ).AddTool(toolDef, handler) + + ls, err := agent.prepareLoop(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("prepareLoop: %v", err) + } + + _, err = agent.buildRequest(ls, false) + if err == nil || err.Error() != "tool prepare: no tools today" { + t.Fatalf("expected wrapped tool prepare error, got %v", err) + } + }) +} diff --git a/stream_internal_test.go b/stream_internal_test.go new file mode 100644 index 0000000..8e98480 --- /dev/null +++ b/stream_internal_test.go @@ -0,0 +1,293 @@ +package agentic + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/regularkevvv/agentic-go/internal/testutil" +) + +func TestRunStreamTrueText(t *testing.T) { + model := &testutil.ScriptedStreamModel{ + Streams: [][]StreamEvent{{ + {Type: StreamEventTextDelta, Delta: "hello"}, + {Type: StreamEventDone, Usage: &Usage{PromptTokens: 3, CompletionTokens: 2, TotalTokens: 5}}, + }}, + } + + agent := NewAgent[any]("system", model) + stream, err := agent.RunStream(context.Background(), "say hello", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + var events []StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(model.Requests) != 1 || !model.Requests[0].Stream { + t.Fatalf("expected one streaming request, got %#v", model.Requests) + } + if len(events) != 2 { + t.Fatalf("expected 2 events, got %d", len(events)) + } + if events[0].Type != StreamEventTextDelta || events[0].Delta != "hello" { + t.Fatalf("unexpected first event: %#v", events[0]) + } + if events[1].Type != StreamEventDone || events[1].Usage == nil || events[1].Usage.TotalTokens != 5 { + t.Fatalf("unexpected done event: %#v", events[1]) + } +} + +func TestRunStreamTrueWithToolCall(t *testing.T) { + type doubleInput struct { + X int `json:"x"` + } + type doubleOutput struct { + Y int `json:"y"` + } + + tool, handler := MustToolPlain("double", "Double a number", func(input doubleInput) (doubleOutput, error) { + return doubleOutput{Y: input.X * 2}, nil + }) + + model := &testutil.ScriptedStreamModel{ + Streams: [][]StreamEvent{ + { + {Type: StreamEventToolCallStart, ToolUse: &ToolUse{ID: "call_1", Name: "double"}}, + {Type: StreamEventToolCallDelta, ToolCallID: "call_1", Delta: `{"x":5}`}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 3}}, + }, + { + {Type: StreamEventTextDelta, Delta: "done"}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 2}}, + }, + }, + } + + agent := NewAgent[any]("system", model).AddTool(tool, handler) + stream, err := agent.RunStream(context.Background(), "double 5", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + var toolResultSeen bool + var doneSeen bool + for event := range stream.Events { + if event.Type == StreamEventToolResult { + toolResultSeen = true + if event.ToolCallID != "call_1" || event.Delta != `{"y":10}` { + t.Fatalf("unexpected tool result event: %#v", event) + } + } + if event.Type == StreamEventDone { + doneSeen = true + } + } + + if len(model.Requests) != 2 { + t.Fatalf("expected 2 model requests, got %d", len(model.Requests)) + } + if !toolResultSeen || !doneSeen { + t.Fatalf("expected tool result and done events, got toolResult=%v done=%v", toolResultSeen, doneSeen) + } +} + +func TestRunStreamTrueErrorsWhenToolsAreMissing(t *testing.T) { + model := &testutil.ScriptedStreamModel{ + Streams: [][]StreamEvent{{ + {Type: StreamEventToolCallStart, ToolUse: &ToolUse{ID: "call_1", Name: "missing"}}, + {Type: StreamEventToolCallDelta, ToolCallID: "call_1", Delta: `{}`}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }}, + } + + agent := NewAgent[any]("system", model) + stream, err := agent.RunStream(context.Background(), "run missing tool", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if err == nil || err.Error() != "model requested tool calls but no tools are registered" { + t.Fatalf("expected missing tool registry error, got %v", err) + } +} + +func TestRunStreamTrueValidationErrorStopsStream(t *testing.T) { + model := &testutil.ScriptedStreamModel{ + Streams: [][]StreamEvent{{ + {Type: StreamEventTextDelta, Delta: "bad output"}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }}, + } + + agent := NewAgent[any]( + "system", + model, + WithOutputValidatorFunc[any](func(ctx RunContext[any], output string) error { + return errors.New("output rejected") + }), + WithMaxValidationRetries[any](0), + ) + + stream, err := agent.RunStream(context.Background(), "validate this", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + err = stream.Wait() + if err == nil || err.Error() != "output validation failed after 0 retries: output rejected" { + t.Fatalf("expected validation failure, got %v", err) + } +} + +func TestRunStreamTrueValidationRetryAndOutputTool(t *testing.T) { + t.Run("validation retry continues with another streamed request", func(t *testing.T) { + model := &testutil.ScriptedStreamModel{ + Streams: [][]StreamEvent{ + { + {Type: StreamEventTextDelta, Delta: "bad"}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 1}}, + }, + { + {Type: StreamEventTextDelta, Delta: "good"}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 2}}, + }, + }, + } + + agent := NewAgent[any]( + "system", + model, + WithOutputValidatorFunc[any](func(ctx RunContext[any], output string) error { + if output == "bad" { + return NewValidationError("retry please") + } + return nil + }), + WithMaxValidationRetries[any](1), + ) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + var events []StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(model.Requests) != 2 { + t.Fatalf("expected 2 streamed requests after validation retry, got %d", len(model.Requests)) + } + if len(events) != 3 { + t.Fatalf("expected 3 events (two deltas and done), got %#v", events) + } + last := events[len(events)-1] + if last.Type != StreamEventDone || last.Usage == nil || last.Usage.TotalTokens != 3 { + t.Fatalf("unexpected final event %#v", last) + } + + foundValidationMessage := false + for _, msg := range model.Requests[1].Messages { + if msg.Role == RoleUser && strings.Contains(msg.GetTextContent(), "Output validation error: retry please") { + foundValidationMessage = true + break + } + } + if !foundValidationMessage { + t.Fatalf("expected validation retry message in second request, got %#v", model.Requests[1].Messages) + } + }) + + t.Run("output tool ends streaming loop immediately", func(t *testing.T) { + model := &testutil.ScriptedStreamModel{ + Streams: [][]StreamEvent{{ + {Type: StreamEventToolCallStart, ToolUse: &ToolUse{ID: "out_1", Name: "__output__"}}, + {Type: StreamEventToolCallDelta, ToolCallID: "out_1", Delta: `{"value":"ok"}`}, + {Type: StreamEventDone, Usage: &Usage{TotalTokens: 4}}, + }}, + } + + agent := NewAgent[any]("system", model).SetOutputToolNames(map[string]bool{"__output__": true}) + + stream, err := agent.RunStream(context.Background(), "prompt", nil) + if err != nil { + t.Fatalf("RunStream: %v", err) + } + + var events []StreamEvent + for event := range stream.Events { + events = append(events, event) + } + + if len(model.Requests) != 1 { + t.Fatalf("expected a single request, got %d", len(model.Requests)) + } + if len(events) != 3 { + t.Fatalf("expected streamed tool call events plus done, got %#v", events) + } + if events[len(events)-1].Type != StreamEventDone { + t.Fatalf("expected done event, got %#v", events[len(events)-1]) + } + }) +} + +func TestConsumeAndForward(t *testing.T) { + t.Run("reconstructs text thinking tool calls and forwards events", func(t *testing.T) { + out := make(chan StreamEvent, 8) + stream := testutil.NewScriptedStream( + StreamEvent{Type: StreamEventTextDelta, Delta: "hello"}, + StreamEvent{Type: StreamEventThinkingDelta, Delta: "considering"}, + StreamEvent{Type: StreamEventToolCallStart, ToolUse: &ToolUse{ID: "call_1", Name: "calc"}}, + StreamEvent{Type: StreamEventToolCallDelta, ToolCallID: "call_1", Delta: `{"x":1}`}, + StreamEvent{Type: StreamEventToolResult, Delta: "side-effect"}, + StreamEvent{Type: StreamEventDone, Usage: &Usage{TotalTokens: 7}}, + ) + + msg, usage, err := (&Agent[any]{}).consumeAndForward(stream, out) + if err != nil { + t.Fatalf("consumeAndForward: %v", err) + } + if usage.TotalTokens != 7 { + t.Fatalf("expected usage to be preserved, got %#v", usage) + } + if msg.GetTextContent() != "hello" { + t.Fatalf("expected text content %q, got %q", "hello", msg.GetTextContent()) + } + if msg.GetThinkingContent() != "considering" { + t.Fatalf("expected thinking content %q, got %q", "considering", msg.GetThinkingContent()) + } + toolUses := msg.GetToolUses() + if len(toolUses) != 1 || toolUses[0].Name != "calc" || toolUses[0].Input["x"] != float64(1) { + t.Fatalf("unexpected tool uses: %#v", toolUses) + } + if len(out) != 5 { + t.Fatalf("expected 5 forwarded events, got %d", len(out)) + } + }) + + t.Run("returns streamed errors immediately", func(t *testing.T) { + expected := errors.New("stream failed") + out := make(chan StreamEvent, 1) + stream := testutil.NewScriptedStream(StreamEvent{Type: StreamEventError, Error: expected}) + + _, _, err := (&Agent[any]{}).consumeAndForward(stream, out) + if !errors.Is(err, expected) { + t.Fatalf("expected %v, got %v", expected, err) + } + }) +} + +func TestNewScriptedStreamHelperUsesStableTimestamps(t *testing.T) { + stream := testutil.NewScriptedStream(StreamEvent{Type: StreamEventDone, Usage: &Usage{TotalTokens: int(time.Unix(0, 0).Unix())}}) + if stream == nil { + t.Fatal("expected helper to create a stream") + } +} diff --git a/system_prompt_test.go b/system_prompt_test.go index 14bf883..f75e19a 100644 --- a/system_prompt_test.go +++ b/system_prompt_test.go @@ -3,6 +3,7 @@ package agentic_test import ( "context" "fmt" + "strings" "testing" agentic "github.com/regularkevvv/agentic-go" @@ -129,7 +130,7 @@ func TestDynamicPromptErrorPropagation(t *testing.T) { if err == nil { t.Fatal("expected error from dynamic prompt") } - if !containsStr(err.Error(), "dynamic prompt failed") { + if !strings.Contains(err.Error(), "dynamic prompt failed") { t.Errorf("expected error to contain 'dynamic prompt failed', got %q", err.Error()) } } diff --git a/tool/deferred_test.go b/tool/deferred_test.go new file mode 100644 index 0000000..4e63e7b --- /dev/null +++ b/tool/deferred_test.go @@ -0,0 +1,300 @@ +package tool + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/regularkevvv/agentic-go/internal/core" +) + +type deferredInput struct { + Value int `json:"value"` +} + +func TestDeferredToolExecute(t *testing.T) { + tool, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + ch := make(chan string, 1) + ch <- "value" + close(ch) + return ch, nil + }) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + if tool.Function.Name != "async_value" { + t.Fatalf("unexpected tool name %q", tool.Function.Name) + } + + dh := handler.(*deferredHandler[deferredInput, string]) + if dh.Name() != "async_value" { + t.Fatalf("unexpected handler name %q", dh.Name()) + } + if dh.ToolConfig() != nil { + t.Fatalf("expected nil tool config, got %#v", dh.ToolConfig()) + } + + out, err := handler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if out.(string) != "value" { + t.Fatalf("unexpected output %#v", out) + } +} + +func TestDeferredToolExecuteErrors(t *testing.T) { + t.Run("marshal input error", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + ch := make(chan string, 1) + ch <- "ok" + close(ch) + return ch, nil + }) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": make(chan int)}, nil) + if execErr == nil || execErr.Error()[:14] != "marshal input:" { + t.Fatalf("expected marshal input error, got %v", execErr) + } + }) + + t.Run("unmarshal input error", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + ch := make(chan string, 1) + ch <- "ok" + close(ch) + return ch, nil + }) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": "bad"}, nil) + if execErr == nil || execErr.Error()[:24] != "unmarshal to tool.deferred"[:24] && execErr.Error()[:13] != "unmarshal to " { + t.Fatalf("expected unmarshal error, got %v", execErr) + } + }) + + t.Run("approval rejected", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + ch := make(chan string, 1) + ch <- "value" + close(ch) + return ch, nil + }, WithApproval(func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return false, nil + })) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil) + var retryErr *modelRetry + if !errors.As(execErr, &retryErr) { + t.Fatalf("expected modelRetry, got %T: %v", execErr, execErr) + } + if retryErr.Error() != `Tool "async_value" was rejected by approval` { + t.Fatalf("unexpected retry error %q", retryErr.Error()) + } + }) + + t.Run("approval function error", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + ch := make(chan string, 1) + ch <- "value" + close(ch) + return ch, nil + }, WithApproval(func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return false, errors.New("approval failed") + })) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil) + if execErr == nil || execErr.Error() != "approval: approval failed" { + t.Fatalf("expected wrapped approval error, got %v", execErr) + } + }) + + t.Run("handler returns error", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + return nil, errors.New("boom") + }) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil) + if execErr == nil || execErr.Error() != "boom" { + t.Fatalf("expected handler error, got %v", execErr) + } + }) + + t.Run("closed channel without result", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + ch := make(chan string) + close(ch) + return ch, nil + }) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil) + if execErr == nil || !strings.Contains(execErr.Error(), "channel closed without result") { + t.Fatalf("expected closed channel error, got %v", execErr) + } + }) + + t.Run("timeout", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + return make(chan string), nil + }, WithDeferredTimeout(10*time.Millisecond)) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil) + if execErr == nil || !strings.Contains(execErr.Error(), "timed out after") { + t.Fatalf("expected timeout error, got %v", execErr) + } + }) + + t.Run("context canceled", func(t *testing.T) { + _, handler, err := DeferredTool("async_value", "async value", func(ctx context.Context, input deferredInput) (<-chan string, error) { + return make(chan string), nil + }) + if err != nil { + t.Fatalf("DeferredTool: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, execErr := handler.Execute(ctx, map[string]interface{}{"value": 7}, nil) + if !errors.Is(execErr, context.Canceled) { + t.Fatalf("expected context cancellation, got %v", execErr) + } + }) +} + +func TestMustDeferredTool(t *testing.T) { + if _, handler := MustDeferredTool("must_async", "must async", func(ctx context.Context, input deferredInput) (<-chan string, error) { + ch := make(chan string, 1) + ch <- "ok" + close(ch) + return ch, nil + }); handler == nil { + t.Fatal("expected handler") + } + + defer func() { + if r := recover(); r == nil { + t.Fatal("expected MustDeferredTool to panic on invalid tool") + } + }() + MustDeferredTool("", "desc", func(ctx context.Context, input deferredInput) (<-chan string, error) { + return nil, nil + }) +} + +func TestDeferredToolWithApproval(t *testing.T) { + tool, handler, err := DeferredToolWithApproval("sync_approval", "sync approval", func(ctx context.Context, input deferredInput) (string, error) { + return "approved", nil + }, func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return true, nil + }) + if err != nil { + t.Fatalf("DeferredToolWithApproval: %v", err) + } + if tool.Function.Name != "sync_approval" { + t.Fatalf("unexpected tool name %q", tool.Function.Name) + } + + dh := handler.(*deferredApprovalHandler[deferredInput, string]) + if dh.Name() != "sync_approval" { + t.Fatalf("unexpected handler name %q", dh.Name()) + } + if dh.ToolConfig() != nil { + t.Fatalf("expected nil tool config, got %#v", dh.ToolConfig()) + } + + out, err := handler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if out.(string) != "approved" { + t.Fatalf("unexpected output %#v", out) + } + + _, rejectingHandler, err := DeferredToolWithApproval("sync_reject", "sync reject", func(ctx context.Context, input deferredInput) (string, error) { + return "rejected", nil + }, func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return false, nil + }) + if err != nil { + t.Fatalf("DeferredToolWithApproval: %v", err) + } + if _, execErr := rejectingHandler.Execute(context.Background(), map[string]interface{}{"value": 7}, nil); execErr == nil { + t.Fatal("expected rejection error") + } + + t.Run("approval constructor options are applied", func(t *testing.T) { + _, handler, err := DeferredToolWithApproval("sync_approval", "sync approval", func(ctx context.Context, input deferredInput) (string, error) { + return "approved", nil + }, func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return true, nil + }, WithDeferredTimeout(time.Second)) + if err != nil { + t.Fatalf("DeferredToolWithApproval: %v", err) + } + + dh := handler.(*deferredApprovalHandler[deferredInput, string]) + if dh.config == nil || dh.config.timeout != time.Second { + t.Fatalf("expected approval config timeout to be applied, got %#v", dh.config) + } + }) + + t.Run("approval handler unmarshal error", func(t *testing.T) { + _, handler, err := DeferredToolWithApproval("sync_approval", "sync approval", func(ctx context.Context, input deferredInput) (string, error) { + return "approved", nil + }, func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return true, nil + }) + if err != nil { + t.Fatalf("DeferredToolWithApproval: %v", err) + } + + _, execErr := handler.Execute(context.Background(), map[string]interface{}{"value": "bad"}, nil) + if execErr == nil || execErr.Error()[:13] != "unmarshal to " { + t.Fatalf("expected unmarshal error, got %v", execErr) + } + }) +} + +func TestMustDeferredToolWithApproval(t *testing.T) { + if _, handler := MustDeferredToolWithApproval("must_sync_approval", "must sync approval", func(ctx context.Context, input deferredInput) (string, error) { + return "approved", nil + }, func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return true, nil + }); handler == nil { + t.Fatal("expected handler") + } + + defer func() { + if r := recover(); r == nil { + t.Fatal("expected MustDeferredToolWithApproval to panic on invalid tool") + } + }() + MustDeferredToolWithApproval("", "desc", func(ctx context.Context, input deferredInput) (string, error) { + return "", nil + }, func(ctx context.Context, toolCall core.ToolUse) (bool, error) { + return true, nil + }) +} diff --git a/tool/option_test.go b/tool/option_test.go new file mode 100644 index 0000000..4d83db6 --- /dev/null +++ b/tool/option_test.go @@ -0,0 +1,14 @@ +package tool + +import "testing" + +func TestWithToolMaxRetriesAndApplyToolOptions(t *testing.T) { + if cfg := applyToolOptions(nil); cfg != nil { + t.Fatalf("expected nil config when no options are provided, got %#v", cfg) + } + + cfg := applyToolOptions([]ToolOption{WithToolMaxRetries(3)}) + if cfg == nil || cfg.MaxRetries == nil || *cfg.MaxRetries != 3 { + t.Fatalf("expected max retries to be set to 3, got %#v", cfg) + } +} diff --git a/tool/registry_test.go b/tool/registry_test.go index e4141f2..872eafb 100644 --- a/tool/registry_test.go +++ b/tool/registry_test.go @@ -55,6 +55,83 @@ func TestRegistryExecuteBatch(t *testing.T) { } } +func TestRegistryBasic(t *testing.T) { + reg := NewRegistry() + + type Input struct { + X int `json:"x"` + } + type Output struct { + Y int `json:"y"` + } + + tool, handler := MustToolPlain("double", "Double a number", func(input Input) (Output, error) { + return Output{Y: input.X * 2}, nil + }) + + if err := reg.Register(tool, handler); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reg.Has("double") { + t.Error("expected registry to have 'double'") + } + if reg.Count() != 1 { + t.Errorf("expected count 1, got %d", reg.Count()) + } + + result, err := reg.Execute(context.Background(), core.ToolUse{ + ID: "call_1", Name: "double", Input: map[string]interface{}{"x": float64(5)}, + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.IsError { + t.Errorf("unexpected tool error: %v", result.Error) + } + + output, ok := result.Content.(Output) + if !ok { + t.Fatalf("expected Output, got %T", result.Content) + } + if output.Y != 10 { + t.Errorf("expected 10, got %d", output.Y) + } +} + +func TestRegistryDuplicate(t *testing.T) { + reg := NewRegistry() + + type Input struct{} + type Output struct{} + + tool, handler := MustToolPlain("test", "test tool", func(input Input) (Output, error) { + return Output{}, nil + }) + + if err := reg.Register(tool, handler); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err := reg.Register(tool, handler); err == nil { + t.Error("expected error for duplicate registration") + } +} + +func TestRegistryUnknownTool(t *testing.T) { + reg := NewRegistry() + + result, err := reg.Execute(context.Background(), core.ToolUse{ + ID: "call_1", Name: "unknown", Input: map[string]interface{}{}, + }, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Error("expected error for unknown tool") + } +} + func TestRegistryExecuteBatchUnknownTool(t *testing.T) { reg := NewRegistry() @@ -143,3 +220,38 @@ func TestRegistryExecuteToolError(t *testing.T) { t.Error("expected non-nil Error field") } } + +func TestRegistryTools(t *testing.T) { + type input struct { + Value int `json:"value"` + } + + plainTool, plainHandler, err := ToolPlain("plain", "plain tool", func(input input) (string, error) { + return "ok", nil + }) + if err != nil { + t.Fatalf("ToolPlain: %v", err) + } + + depsTool, depsHandler, err := ToolWithDeps[input, string, int]("deps", "deps tool", func(ctx core.RunContext[int], input input) (string, error) { + return "ok", nil + }) + if err != nil { + t.Fatalf("ToolWithDeps: %v", err) + } + + reg := NewRegistry() + if err := reg.Register(plainTool, plainHandler); err != nil { + t.Fatalf("Register plain tool: %v", err) + } + if err := reg.Register(depsTool, depsHandler); err != nil { + t.Fatalf("Register deps tool: %v", err) + } + + if !reg.Has("plain") || !reg.Has("deps") || reg.Count() != 2 { + t.Fatalf("unexpected registry state count=%d", reg.Count()) + } + if got := reg.Tools(); len(got) != 2 { + t.Fatalf("expected 2 registered tools, got %#v", got) + } +} diff --git a/tool/tool_test.go b/tool/tool_test.go index db9c93e..9ccfca2 100644 --- a/tool/tool_test.go +++ b/tool/tool_test.go @@ -131,93 +131,41 @@ func TestNewToolFromStructValidation(t *testing.T) { } } -func TestRegistryBasic(t *testing.T) { - reg := NewRegistry() - - type Input struct { - X int `json:"x"` - } - type Output struct { - Y int `json:"y"` - } - - tool, handler := MustToolPlain("double", "Double a number", func(input Input) (Output, error) { - return Output{Y: input.X * 2}, nil - }) - - if err := reg.Register(tool, handler); err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if !reg.Has("double") { - t.Error("expected registry to have 'double'") - } - if reg.Count() != 1 { - t.Errorf("expected count 1, got %d", reg.Count()) - } - - // Execute - result, err := reg.Execute(context.Background(), core.ToolUse{ - ID: "call_1", Name: "double", Input: map[string]interface{}{"x": float64(5)}, - }, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.IsError { - t.Errorf("unexpected tool error: %v", result.Error) +func TestFormatToolResult(t *testing.T) { + if got := core.FormatToolResult(nil); got != "" { + t.Errorf("expected empty, got %q", got) } - - output, ok := result.Content.(Output) - if !ok { - t.Fatalf("expected Output, got %T", result.Content) + if got := core.FormatToolResult("hello"); got != "hello" { + t.Errorf("expected 'hello', got %q", got) } - if output.Y != 10 { - t.Errorf("expected 10, got %d", output.Y) + if got := core.FormatToolResult(map[string]int{"x": 1}); got != `{"x":1}` { + t.Errorf("expected JSON, got %q", got) } } -func TestRegistryDuplicate(t *testing.T) { - reg := NewRegistry() - - type Input struct{} - type Output struct{} - - tool, handler := MustToolPlain("test", "test tool", func(input Input) (Output, error) { - return Output{}, nil - }) - - if err := reg.Register(tool, handler); err != nil { - t.Fatalf("unexpected error: %v", err) +func TestHandlersExposeToolConfig(t *testing.T) { + type input struct { + Value int `json:"value"` } - // Second registration should fail - if err := reg.Register(tool, handler); err == nil { - t.Error("expected error for duplicate registration") + _, plainHandler, err := ToolPlain("plain", "plain tool", func(input input) (string, error) { + return "ok", nil + }, WithToolMaxRetries(2)) + if err != nil { + t.Fatalf("ToolPlain: %v", err) } -} - -func TestRegistryUnknownTool(t *testing.T) { - reg := NewRegistry() - result, err := reg.Execute(context.Background(), core.ToolUse{ - ID: "call_1", Name: "unknown", Input: map[string]interface{}{}, - }, nil) + _, depsHandler, err := ToolWithDeps[input, string, int]("deps", "deps tool", func(ctx core.RunContext[int], input input) (string, error) { + return "ok", nil + }, WithToolMaxRetries(3)) if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !result.IsError { - t.Error("expected error for unknown tool") + t.Fatalf("ToolWithDeps: %v", err) } -} -func TestFormatToolResult(t *testing.T) { - if got := core.FormatToolResult(nil); got != "" { - t.Errorf("expected empty, got %q", got) + if cfg := plainHandler.(*PlainToolHandler[input, string]).ToolConfig(); cfg == nil || cfg.MaxRetries == nil || *cfg.MaxRetries != 2 { + t.Fatalf("unexpected plain tool config %#v", cfg) } - if got := core.FormatToolResult("hello"); got != "hello" { - t.Errorf("expected 'hello', got %q", got) - } - if got := core.FormatToolResult(map[string]int{"x": 1}); got != `{"x":1}` { - t.Errorf("expected JSON, got %q", got) + if cfg := depsHandler.(*DepsToolHandler[input, string, int]).ToolConfig(); cfg == nil || cfg.MaxRetries == nil || *cfg.MaxRetries != 3 { + t.Fatalf("unexpected deps tool config %#v", cfg) } } diff --git a/tool/toolset_extra_test.go b/tool/toolset_extra_test.go deleted file mode 100644 index 928d689..0000000 --- a/tool/toolset_extra_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package tool - -import ( - "context" - "testing" -) - -func TestPrefixToolsetExecutesWithRenamedHandler(t *testing.T) { - type In struct { - X int `json:"x"` - } - type Out struct { - Y int `json:"y"` - } - - tool, handler := MustToolPlain("original", "test", func(in In) (Out, error) { - return Out{Y: in.X * 2}, nil - }) - - prefixed := PrefixToolset(NewToolset().Add(tool, handler), "prefixed") - tools, handlers := prefixed.ToolsAndHandlers() - if len(tools) != 1 || len(handlers) != 1 { - t.Fatalf("expected 1 tool and handler, got %d and %d", len(tools), len(handlers)) - } - if tools[0].Function.Name != "prefixed__original" { - t.Fatalf("expected prefixed tool name, got %q", tools[0].Function.Name) - } - if handlers[0].Name() != "prefixed__original" { - t.Fatalf("expected handler name %q, got %q", "prefixed__original", handlers[0].Name()) - } - - result, err := handlers[0].Execute(context.Background(), map[string]interface{}{"x": float64(5)}, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - out, ok := result.(Out) - if !ok { - t.Fatalf("expected Out, got %T", result) - } - if out.Y != 10 { - t.Errorf("expected 10, got %d", out.Y) - } -} diff --git a/tool/toolset_test.go b/tool/toolset_test.go index 3ff509c..7836591 100644 --- a/tool/toolset_test.go +++ b/tool/toolset_test.go @@ -1,6 +1,7 @@ package tool import ( + "context" "testing" ) @@ -102,6 +103,43 @@ func TestPrefixToolset(t *testing.T) { } } +func TestPrefixToolsetExecutesWithRenamedHandler(t *testing.T) { + type In struct { + X int `json:"x"` + } + type Out struct { + Y int `json:"y"` + } + + tool, handler := MustToolPlain("original", "test", func(in In) (Out, error) { + return Out{Y: in.X * 2}, nil + }) + + prefixed := PrefixToolset(NewToolset().Add(tool, handler), "prefixed") + tools, handlers := prefixed.ToolsAndHandlers() + if len(tools) != 1 || len(handlers) != 1 { + t.Fatalf("expected 1 tool and handler, got %d and %d", len(tools), len(handlers)) + } + if tools[0].Function.Name != "prefixed__original" { + t.Fatalf("expected prefixed tool name, got %q", tools[0].Function.Name) + } + if handlers[0].Name() != "prefixed__original" { + t.Fatalf("expected handler name %q, got %q", "prefixed__original", handlers[0].Name()) + } + + result, err := handlers[0].Execute(context.Background(), map[string]interface{}{"x": float64(5)}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out, ok := result.(Out) + if !ok { + t.Fatalf("expected Out, got %T", result) + } + if out.Y != 10 { + t.Errorf("expected 10, got %d", out.Y) + } +} + func TestRegisterToolset(t *testing.T) { type In struct { X int `json:"x"` diff --git a/typed_agent_internal_test.go b/typed_agent_internal_test.go new file mode 100644 index 0000000..e163433 --- /dev/null +++ b/typed_agent_internal_test.go @@ -0,0 +1,105 @@ +package agentic + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/regularkevvv/agentic-go/internal/testutil" +) + +type typedAgentCoverageValue struct { + Value string `json:"value"` +} + +type wrongTypeOutputSpec struct{} + +func (wrongTypeOutputSpec) Tools() []Tool { + return nil +} + +func (wrongTypeOutputSpec) Parse(Message) (any, error) { + return 123, nil +} + +func TestTypedAgentHelpersAndTextProcessorError(t *testing.T) { + type input struct { + Value int `json:"value"` + } + + auxTool, auxHandler := MustToolPlain("aux", "aux tool", func(input input) (string, error) { + return "ok", nil + }) + + ta := NewTypedAgentWithMode[any, typedAgentCoverageValue]( + "system", + &testutil.StubModel{NameValue: "typed-model"}, + NewToolOutput[typedAgentCoverageValue]("desc"), + ) + + if got := ta.AddToolset(NewToolset().Add(auxTool, auxHandler)); got != ta { + t.Fatal("expected AddToolset to return the same typed agent") + } + if ta.agent.registry == nil || !ta.agent.registry.Has("aux") || !ta.agent.registry.Has("__output__") { + t.Fatalf("expected registry to contain toolset and output tools, got %#v", ta.agent.registry) + } + + newRegistry := NewRegistry() + if got := ta.SetRegistry(newRegistry); got != ta { + t.Fatal("expected SetRegistry to return the same typed agent") + } + if !newRegistry.Has("__output__") { + t.Fatal("expected output tool to be re-registered on replacement registry") + } + if newRegistry.Has("aux") { + t.Fatal("expected replacement registry to contain only re-registered output tools") + } + + last := ta.lastAssistantMsg([]Message{ + NewTextMessage(RoleUser, "question"), + NewTextMessage(RoleAssistant, "answer"), + NewTextMessage(RoleUser, "follow-up"), + }) + if last == nil || last.GetTextContent() != "answer" { + t.Fatalf("unexpected last assistant message %#v", last) + } + if ta.lastAssistantMsg([]Message{NewTextMessage(RoleUser, "only user")}) != nil { + t.Fatal("expected nil when there is no assistant message") + } + + if _, err := ta.parseResult(&RunResult{Messages: []Message{NewTextMessage(RoleUser, "question")}}); err == nil || err.Error() != "no assistant message found in result" { + t.Fatalf("expected missing assistant message error, got %v", err) + } + + wrong := &TypedAgent[any, typedAgentCoverageValue]{ + agent: NewAgent[any]("system", &testutil.StubModel{NameValue: "typed-model"}), + outputSpec: wrongTypeOutputSpec{}, + } + _, err := wrong.parseResult(&RunResult{Messages: []Message{NewTextMessage(RoleAssistant, "answer")}}) + if err == nil || !strings.Contains(err.Error(), "unexpected output type") { + t.Fatalf("expected type assertion error, got %v", err) + } + + model := &testutil.StubModel{ + NameValue: "typed-model", + Response: &ChatResponse{ + Choices: []Choice{{ + Message: NewTextMessage(RoleAssistant, "bad"), + FinishReason: FinishReasonStop, + }}, + }, + } + textAgent := NewTypedAgentWithMode[any, int]( + "system", + model, + NewTextProcessorOutput(func(text string) (int, error) { + return 0, errors.New("cannot convert") + }), + ) + + _, err = textAgent.Run(context.Background(), "prompt", nil) + if err == nil || err.Error() != "text processor: cannot convert" { + t.Fatalf("expected text processor error, got %v", err) + } +} diff --git a/usage_limits_internal_test.go b/usage_limits_internal_test.go new file mode 100644 index 0000000..aa612df --- /dev/null +++ b/usage_limits_internal_test.go @@ -0,0 +1,42 @@ +package agentic + +import ( + "errors" + "testing" +) + +func TestUsageLimitsCheckBeforeRequestBranches(t *testing.T) { + t.Run("requests", func(t *testing.T) { + limits := UsageLimits{MaxRequests: IntPtr(1)} + err := limits.checkBeforeRequest(Usage{Requests: 1}) + var exceeded *UsageLimitExceededError + if !errors.As(err, &exceeded) || exceeded.LimitName != "requests" { + t.Fatalf("expected request limit error, got %v", err) + } + }) + + t.Run("request tokens", func(t *testing.T) { + limits := UsageLimits{MaxRequestTokens: IntPtr(5)} + err := limits.checkBeforeRequest(Usage{PromptTokens: 6}) + var exceeded *UsageLimitExceededError + if !errors.As(err, &exceeded) || exceeded.LimitName != "request_tokens" { + t.Fatalf("expected request token limit error, got %v", err) + } + }) + + t.Run("total tokens", func(t *testing.T) { + limits := UsageLimits{MaxTotalTokens: IntPtr(5)} + err := limits.checkBeforeRequest(Usage{TotalTokens: 6}) + var exceeded *UsageLimitExceededError + if !errors.As(err, &exceeded) || exceeded.LimitName != "total_tokens" { + t.Fatalf("expected total token limit error, got %v", err) + } + }) + + t.Run("success", func(t *testing.T) { + limits := UsageLimits{MaxRequests: IntPtr(2), MaxRequestTokens: IntPtr(6), MaxTotalTokens: IntPtr(10)} + if err := limits.checkBeforeRequest(Usage{Requests: 1, PromptTokens: 6, TotalTokens: 10}); err != nil { + t.Fatalf("expected limits to pass, got %v", err) + } + }) +} diff --git a/validator_internal_test.go b/validator_internal_test.go new file mode 100644 index 0000000..cd89cc8 --- /dev/null +++ b/validator_internal_test.go @@ -0,0 +1,103 @@ +package agentic + +import ( + "context" + "errors" + "strings" + "testing" +) + +type validatorCoverageValue struct { + Value string `json:"value"` +} + +func TestValidateStructAdditionalTagsAndTypedValidator(t *testing.T) { + tests := []struct { + name string + value any + substr string + }{ + { + name: "len", + value: struct { + Code string `json:"code" validate:"len=3"` + }{Code: "abcd"}, + substr: "code must have exactly 3 elements", + }, + { + name: "email", + value: struct { + Email string `json:"email" validate:"email"` + }{Email: "not-an-email"}, + substr: "email must be a valid email address", + }, + { + name: "url", + value: struct { + URL string `json:"url" validate:"url"` + }{URL: "not-a-url"}, + substr: "url must be a valid URL", + }, + { + name: "contains", + value: struct { + Body string `json:"body" validate:"contains=ok"` + }{Body: "missing"}, + substr: "body must contain 'ok'", + }, + { + name: "gt", + value: struct { + Count int `json:"count" validate:"gt=5"` + }{Count: 5}, + substr: "count must be greater than 5", + }, + { + name: "gte", + value: struct { + Count int `json:"count" validate:"gte=5"` + }{Count: 4}, + substr: "count must be greater than or equal to 5", + }, + { + name: "lt", + value: struct { + Count int `json:"count" validate:"lt=5"` + }{Count: 5}, + substr: "count must be less than 5", + }, + { + name: "lte", + value: struct { + Count int `json:"count" validate:"lte=5"` + }{Count: 6}, + substr: "count must be less than or equal to 5", + }, + { + name: "default branch", + value: struct { + Name string `json:"name" validate:"startswith=ab"` + }{Name: "zz"}, + substr: "name failed 'startswith' validation", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateStruct(tt.value) + if err == nil || !strings.Contains(err.Error(), tt.substr) { + t.Fatalf("expected error containing %q, got %v", tt.substr, err) + } + }) + } + + validator := TypedOutputValidatorFunc[any, validatorCoverageValue](func(ctx RunContext[any], output validatorCoverageValue) error { + if output.Value != "ok" { + return errors.New("unexpected output") + } + return nil + }) + if err := validator.ValidateTyped(RunContext[any]{Ctx: context.Background()}, validatorCoverageValue{Value: "ok"}); err != nil { + t.Fatalf("expected typed validator to pass, got %v", err) + } +}