diff --git a/pkg/llm/provider/bedrock/bedrock.go b/pkg/llm/provider/bedrock/bedrock.go
new file mode 100644
index 0000000..b7408a0
--- /dev/null
+++ b/pkg/llm/provider/bedrock/bedrock.go
@@ -0,0 +1,195 @@
+// Package bedrock implements the Provider interface for AWS Bedrock's
+// InvokeModel API with Anthropic Claude models.
+//
+// When using Bedrock's InvokeModel endpoint, the request body follows the
+// Anthropic Messages API format with the model specified in the URL path
+// rather than the request body, and an anthropic_version field in the body.
+// The response format is identical to the native Anthropic Messages API.
+package bedrock
+
+import (
+	"encoding/json"
+	"errors"
+	"strings"
+	"time"
+
+	"github.com/papercomputeco/tapes/pkg/llm"
+)
+
+// Provider implements the Provider interface for AWS Bedrock.
+type Provider struct{}
+
+// New creates a new Bedrock provider.
+func New() *Provider { return &Provider{} }
+
+// Name returns the provider name.
+func (p *Provider) Name() string {
+	return "bedrock"
+}
+
+// DefaultStreaming returns false. Bedrock's InvokeModel does not stream by
+// default; streaming requires InvokeModelWithResponseStream.
+func (p *Provider) DefaultStreaming() bool {
+	return false
+}
+
+// ParseRequest converts a Bedrock InvokeModel request body into an llm.ChatRequest.
+func (p *Provider) ParseRequest(payload []byte) (*llm.ChatRequest, error) {
+	var req bedrockRequest
+	if err := json.Unmarshal(payload, &req); err != nil {
+		return nil, err
+	}
+
+	system := parseBedrockSystem(req.System)
+	messages := make([]llm.Message, 0, len(req.Messages))
+	for _, msg := range req.Messages {
+		converted := llm.Message{Role: msg.Role}
+
+		switch content := msg.Content.(type) {
+		case string:
+			converted.Content = []llm.ContentBlock{{Type: "text", Text: content}}
+		case []any:
+			for _, item := range content {
+				if block, ok := item.(map[string]any); ok {
+					cb := llm.ContentBlock{}
+					if t, ok := block["type"].(string); ok {
+						cb.Type = t
+					}
+					if text, ok := block["text"].(string); ok {
+						cb.Text = text
+					}
+					if source, ok := block["source"].(map[string]any); ok {
+						if mt, ok := source["media_type"].(string); ok {
+							cb.MediaType = mt
+						}
+						if data, ok := source["data"].(string); ok {
+							cb.ImageBase64 = data
+						}
+					}
+
+					// Tool use
+					if id, ok := block["id"].(string); ok {
+						cb.ToolUseID = id
+					}
+					if name, ok := block["name"].(string); ok {
+						cb.ToolName = name
+					}
+					if input, ok := block["input"].(map[string]any); ok {
+						cb.ToolInput = input
+					}
+					converted.Content = append(converted.Content, cb)
+				}
+			}
+		}
+
+		messages = append(messages, converted)
+	}
+
+	result := &llm.ChatRequest{
+		Model:       req.Model,
+		Messages:    messages,
+		System:      system,
+		MaxTokens:   &req.MaxTokens,
+		Temperature: req.Temperature,
+		TopP:        req.TopP,
+		TopK:        req.TopK,
+		Stop:        req.Stop,
+		Stream:      req.Stream,
+		RawRequest:  payload,
+	}
+
+	if req.AnthropicVersion != "" {
+		result.Extra = map[string]any{
+			"anthropic_version": req.AnthropicVersion,
+		}
+	}
+
+	return result, nil
+}
+
+// parseBedrockSystem flattens the "system" field, which may be a plain string or an array of text content blocks, into a single newline-joined string.
+func parseBedrockSystem(system any) string {
+	if system == nil {
+		return ""
+	}
+
+	switch value := system.(type) {
+	case string:
+		return value
+	case []any:
+		var builder strings.Builder
+		for _, item := range value {
+			block, ok := item.(map[string]any)
+			if !ok {
+				continue
+			}
+			blockType, _ := block["type"].(string)
+			text, _ := block["text"].(string)
+			if blockType == "text" && text != "" {
+				if builder.Len() > 0 {
+					builder.WriteString("\n")
+				}
+				builder.WriteString(text)
+			}
+		}
+		return builder.String()
+	default:
+		return ""
+	}
+}
+
+// ParseResponse converts a Bedrock InvokeModel response body into an llm.ChatResponse.
+func (p *Provider) ParseResponse(payload []byte) (*llm.ChatResponse, error) {
+	var resp bedrockResponse
+	if err := json.Unmarshal(payload, &resp); err != nil {
+		return nil, err
+	}
+
+	content := make([]llm.ContentBlock, 0, len(resp.Content))
+	for _, block := range resp.Content {
+		cb := llm.ContentBlock{Type: block.Type}
+		switch block.Type {
+		case "text":
+			cb.Text = block.Text
+		case "tool_use":
+			cb.ToolUseID = block.ID
+			cb.ToolName = block.Name
+			cb.ToolInput = block.Input
+		}
+		content = append(content, cb)
+	}
+
+	var usage *llm.Usage
+	if resp.Usage != nil {
+		usage = &llm.Usage{
+			PromptTokens:     resp.Usage.InputTokens,
+			CompletionTokens: resp.Usage.OutputTokens,
+			TotalTokens:      resp.Usage.InputTokens + resp.Usage.OutputTokens,
+		}
+	}
+
+	result := &llm.ChatResponse{
+		Model: resp.Model,
+		Message: llm.Message{
+			Role:    resp.Role,
+			Content: content,
+		},
+		Done:        true,
+		StopReason:  resp.StopReason,
+		Usage:       usage,
+		CreatedAt:   time.Now(),
+		RawResponse: payload,
+		Extra: map[string]any{
+			"id":   resp.ID,
+			"type": resp.Type,
+		},
+	}
+
+	return result, nil
+}
+
+// ParseStreamChunk is not supported for Bedrock's InvokeModel API; streaming
+// requires InvokeModelWithResponseStream. Returns an error instead of panicking.
+func (p *Provider) ParseStreamChunk(_ []byte) (*llm.StreamChunk, error) {
+	return nil, errors.New("bedrock: stream chunk parsing not implemented")
+}
diff --git a/pkg/llm/provider/bedrock/bedrock_suite_test.go b/pkg/llm/provider/bedrock/bedrock_suite_test.go
new file mode 100644
index 0000000..3759192
--- /dev/null
+++ b/pkg/llm/provider/bedrock/bedrock_suite_test.go
@@ -0,0 +1,13 @@
+package bedrock_test
+
+import (
+	"testing"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+func TestBedrock(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "Bedrock Provider Suite")
+}
diff --git a/pkg/llm/provider/bedrock/bedrock_test.go b/pkg/llm/provider/bedrock/bedrock_test.go
new file mode 100644
index 0000000..b0a7d7c
--- /dev/null
+++ b/pkg/llm/provider/bedrock/bedrock_test.go
@@ -0,0 +1,404 @@
+package bedrock_test
+
+import (
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+
+	"github.com/papercomputeco/tapes/pkg/llm/provider"
+	"github.com/papercomputeco/tapes/pkg/llm/provider/bedrock"
+)
+
+var _ = Describe("Bedrock Provider", func() {
+	var p provider.Provider
+
+	BeforeEach(func() {
+		p = bedrock.New()
+	})
+
+	Describe("Name", func() {
+		It("returns 'bedrock'", func() {
+			Expect(p.Name()).To(Equal("bedrock"))
+		})
+	})
+
+	Describe("ParseRequest", func() {
+		Context("with a simple text request", func() {
+			It("parses messages correctly", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"messages": [
+						{"role": "user", "content": "Hello, Claude!"}
+					]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.Model).To(Equal(""))
+				Expect(*req.MaxTokens).To(Equal(1024))
+				Expect(req.Messages).To(HaveLen(1))
+				Expect(req.Messages[0].Role).To(Equal("user"))
+				Expect(req.Messages[0].GetText()).To(Equal("Hello, Claude!"))
+			})
+		})
+
+		Context("with anthropic_version field", func() {
+			It("stores anthropic_version in Extra", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.Extra).To(HaveKeyWithValue("anthropic_version", "bedrock-2023-05-31"))
+			})
+		})
+
+		Context("with model field present", func() {
+			It("parses the model field when provided", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"model": "anthropic.claude-3-sonnet-20240229-v1:0",
+					"max_tokens": 1024,
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.Model).To(Equal("anthropic.claude-3-sonnet-20240229-v1:0"))
+			})
+		})
+
+		Context("with content block array format", func() {
+			It("parses text content blocks", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"messages": [
+						{
+							"role": "user",
+							"content": [
+								{"type": "text", "text": "What's in this image?"}
+							]
+						}
+					]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.Messages).To(HaveLen(1))
+				Expect(req.Messages[0].Content).To(HaveLen(1))
+				Expect(req.Messages[0].Content[0].Type).To(Equal("text"))
+				Expect(req.Messages[0].Content[0].Text).To(Equal("What's in this image?"))
+			})
+
+			It("parses image content blocks with base64 source", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"messages": [
+						{
+							"role": "user",
+							"content": [
+								{"type": "text", "text": "What's in this image?"},
+								{
+									"type": "image",
+									"source": {
+										"type": "base64",
+										"media_type": "image/png",
+										"data": "iVBORw0KGgo..."
+									}
+								}
+							]
+						}
+					]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.Messages[0].Content).To(HaveLen(2))
+				Expect(req.Messages[0].Content[1].MediaType).To(Equal("image/png"))
+				Expect(req.Messages[0].Content[1].ImageBase64).To(Equal("iVBORw0KGgo..."))
+			})
+		})
+
+		Context("with system prompt", func() {
+			It("parses the system field as string", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"system": "You are a helpful coding assistant.",
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.System).To(Equal("You are a helpful coding assistant."))
+			})
+
+			It("parses the system field as array of content blocks", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"system": [
+						{"type": "text", "text": "You are helpful."},
+						{"type": "text", "text": "Be concise."}
+					],
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.System).To(Equal("You are helpful.\nBe concise."))
+			})
+		})
+
+		Context("with generation parameters", func() {
+			It("parses temperature, top_p, top_k", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"temperature": 0.7,
+					"top_p": 0.9,
+					"top_k": 40,
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(*req.Temperature).To(BeNumerically("~", 0.7, 0.001))
+				Expect(*req.TopP).To(BeNumerically("~", 0.9, 0.001))
+				Expect(*req.TopK).To(Equal(40))
+			})
+
+			It("parses stop_sequences", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"stop_sequences": ["END", "STOP"],
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.Stop).To(ConsistOf("END", "STOP"))
+			})
+		})
+
+		Context("with streaming flag", func() {
+			It("parses stream: true", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"stream": true,
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(*req.Stream).To(BeTrue())
+			})
+
+			It("parses stream: false", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"stream": false,
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(*req.Stream).To(BeFalse())
+			})
+		})
+
+		Context("with tool use in messages", func() {
+			It("parses tool_use content blocks", func() {
+				payload := []byte(`{
+					"anthropic_version": "bedrock-2023-05-31",
+					"max_tokens": 1024,
+					"messages": [
+						{
+							"role": "assistant",
+							"content": [
+								{
+									"type": "tool_use",
+									"id": "toolu_123",
+									"name": "get_weather",
+									"input": {"location": "San Francisco"}
+								}
+							]
+						}
+					]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.Messages[0].Content).To(HaveLen(1))
+				Expect(req.Messages[0].Content[0].Type).To(Equal("tool_use"))
+				Expect(req.Messages[0].Content[0].ToolUseID).To(Equal("toolu_123"))
+				Expect(req.Messages[0].Content[0].ToolName).To(Equal("get_weather"))
+				Expect(req.Messages[0].Content[0].ToolInput).To(HaveKeyWithValue("location", "San Francisco"))
+			})
+		})
+
+		Context("with invalid payload", func() {
+			It("returns an error for invalid JSON", func() {
+				payload := []byte(`not valid json`)
+				_, err := p.ParseRequest(payload)
+				Expect(err).To(HaveOccurred())
+			})
+		})
+
+		Context("preserves raw request", func() {
+			It("stores the original payload in RawRequest", func() {
+				payload := []byte(`{"anthropic_version": "bedrock-2023-05-31", "max_tokens": 1024, "messages": []}`)
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect([]byte(req.RawRequest)).To(Equal(payload))
+			})
+		})
+
+		Context("without anthropic_version", func() {
+			It("does not set Extra when anthropic_version is absent", func() {
+				payload := []byte(`{
+					"max_tokens": 1024,
+					"messages": [{"role": "user", "content": "Hello"}]
+				}`)
+
+				req, err := p.ParseRequest(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(req.Extra).To(BeNil())
+			})
+		})
+	})
+
+	Describe("ParseResponse", func() {
+		Context("with a simple text response", func() {
+			It("parses the response correctly", func() {
+				payload := []byte(`{
+					"id": "msg_01234567890",
+					"type": "message",
+					"role": "assistant",
+					"content": [
+						{"type": "text", "text": "Hello! How can I help you today?"}
+					],
+					"model": "anthropic.claude-3-sonnet-20240229-v1:0",
+					"stop_reason": "end_turn",
+					"usage": {
+						"input_tokens": 10,
+						"output_tokens": 25
+					}
+				}`)
+
+				resp, err := p.ParseResponse(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(resp.Model).To(Equal("anthropic.claude-3-sonnet-20240229-v1:0"))
+				Expect(resp.Message.Role).To(Equal("assistant"))
+				Expect(resp.Message.GetText()).To(Equal("Hello! How can I help you today?"))
+				Expect(resp.StopReason).To(Equal("end_turn"))
+				Expect(resp.Done).To(BeTrue())
+			})
+		})
+
+		Context("with usage metrics", func() {
+			It("parses token counts correctly", func() {
+				payload := []byte(`{
+					"id": "msg_123",
+					"type": "message",
+					"role": "assistant",
+					"content": [{"type": "text", "text": "Hi"}],
+					"model": "anthropic.claude-3-sonnet-20240229-v1:0",
+					"stop_reason": "end_turn",
+					"usage": {
+						"input_tokens": 100,
+						"output_tokens": 50
+					}
+				}`)
+
+				resp, err := p.ParseResponse(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(resp.Usage).NotTo(BeNil())
+				Expect(resp.Usage.PromptTokens).To(Equal(100))
+				Expect(resp.Usage.CompletionTokens).To(Equal(50))
+				Expect(resp.Usage.TotalTokens).To(Equal(150))
+			})
+		})
+
+		Context("with tool_use response", func() {
+			It("parses tool_use content blocks", func() {
+				payload := []byte(`{
+					"id": "msg_123",
+					"type": "message",
+					"role": "assistant",
+					"content": [
+						{"type": "text", "text": "I'll check the weather for you."},
+						{
+							"type": "tool_use",
+							"id": "toolu_456",
+							"name": "get_weather",
+							"input": {"location": "NYC", "unit": "celsius"}
+						}
+					],
+					"model": "anthropic.claude-3-sonnet-20240229-v1:0",
+					"stop_reason": "tool_use"
+				}`)
+
+				resp, err := p.ParseResponse(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(resp.Message.Content).To(HaveLen(2))
+				Expect(resp.Message.Content[0].Type).To(Equal("text"))
+				Expect(resp.Message.Content[1].Type).To(Equal("tool_use"))
+				Expect(resp.Message.Content[1].ToolUseID).To(Equal("toolu_456"))
+				Expect(resp.Message.Content[1].ToolName).To(Equal("get_weather"))
+				Expect(resp.StopReason).To(Equal("tool_use"))
+			})
+		})
+
+		Context("with Extra fields", func() {
+			It("stores id and type in Extra", func() {
+				payload := []byte(`{
+					"id": "msg_abc123",
+					"type": "message",
+					"role": "assistant",
+					"content": [{"type": "text", "text": "Hi"}],
+					"model": "anthropic.claude-3-sonnet-20240229-v1:0",
+					"stop_reason": "end_turn"
+				}`)
+
+				resp, err := p.ParseResponse(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect(resp.Extra).To(HaveKeyWithValue("id", "msg_abc123"))
+				Expect(resp.Extra).To(HaveKeyWithValue("type", "message"))
+			})
+		})
+
+		Context("preserves raw response", func() {
+			It("stores the original payload in RawResponse", func() {
+				payload := []byte(`{
+					"id": "msg_123",
+					"type": "message",
+					"role": "assistant",
+					"content": [{"type": "text", "text": "Hi"}],
+					"model": "anthropic.claude-3-sonnet-20240229-v1:0",
+					"stop_reason": "end_turn"
+				}`)
+
+				resp, err := p.ParseResponse(payload)
+				Expect(err).NotTo(HaveOccurred())
+				Expect([]byte(resp.RawResponse)).To(Equal(payload))
+			})
+		})
+
+		Context("with invalid payload", func() {
+			It("returns an error for invalid JSON", func() {
+				payload := []byte(`not valid json`)
+				_, err := p.ParseResponse(payload)
+				Expect(err).To(HaveOccurred())
+			})
+		})
+	})
+})
diff --git a/pkg/llm/provider/bedrock/types.go b/pkg/llm/provider/bedrock/types.go
new file mode 100644
index 0000000..8ce3cc5
--- /dev/null
+++ b/pkg/llm/provider/bedrock/types.go
@@ -0,0 +1,59 @@
+package bedrock
+
+// bedrockRequest represents the AWS Bedrock InvokeModel request format for Claude.
+// This is similar to the Anthropic Messages API but without a model field
+// (the model is specified in the URL path) and with an anthropic_version field.
+type bedrockRequest struct {
+	AnthropicVersion string           `json:"anthropic_version,omitempty"`
+	Model            string           `json:"model,omitempty"`
+	Messages         []bedrockMessage `json:"messages"`
+	System           any              `json:"system,omitempty"`
+	MaxTokens        int              `json:"max_tokens"`
+	Temperature      *float64         `json:"temperature,omitempty"`
+	TopP             *float64         `json:"top_p,omitempty"`
+	TopK             *int             `json:"top_k,omitempty"`
+	Stop             []string         `json:"stop_sequences,omitempty"`
+	Stream           *bool            `json:"stream,omitempty"`
+}
+
+// bedrockMessage represents a message in the Bedrock request/response format.
+type bedrockMessage struct {
+	Role string `json:"role"`
+
+	// Union type: can be "string" or "[]bedrockContentBlock"
+	Content any `json:"content"`
+}
+
+// bedrockContentBlock represents a content block in the Bedrock format.
+type bedrockContentBlock struct {
+	Type   string         `json:"type"`
+	Text   string         `json:"text,omitempty"`
+	Source *bedrockSource `json:"source,omitempty"`
+	ID     string         `json:"id,omitempty"`
+	Name   string         `json:"name,omitempty"`
+	Input  map[string]any `json:"input,omitempty"`
+}
+
+type bedrockSource struct {
+	Type      string `json:"type"`
+	MediaType string `json:"media_type"`
+	Data      string `json:"data"`
+}
+
+// bedrockResponse represents the AWS Bedrock InvokeModel response format for Claude.
+// This is identical to the Anthropic Messages API response format.
+type bedrockResponse struct {
+	ID           string                `json:"id"`
+	Type         string                `json:"type"`
+	Role         string                `json:"role"`
+	Content      []bedrockContentBlock `json:"content"`
+	Model        string                `json:"model"`
+	StopReason   string                `json:"stop_reason"`
+	StopSequence *string               `json:"stop_sequence,omitempty"`
+	Usage        *bedrockUsage         `json:"usage,omitempty"`
+}
+
+type bedrockUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+}
diff --git a/pkg/llm/provider/supported.go b/pkg/llm/provider/supported.go
index 969db3a..efee5b0 100644
--- a/pkg/llm/provider/supported.go
+++ b/pkg/llm/provider/supported.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/papercomputeco/tapes/pkg/llm/provider/anthropic"
+	"github.com/papercomputeco/tapes/pkg/llm/provider/bedrock"
 	"github.com/papercomputeco/tapes/pkg/llm/provider/ollama"
 	"github.com/papercomputeco/tapes/pkg/llm/provider/openai"
 )
@@ -11,13 +12,14 @@
 // Supported provider type constants
 const (
 	Anthropic = "anthropic"
+	Bedrock   = "bedrock"
 	OpenAI    = "openai"
 	Ollama    = "ollama"
 )
 
 // SupportedProviders returns the list of all supported provider type names.
 func SupportedProviders() []string {
-	return []string{Anthropic, OpenAI, Ollama}
+	return []string{Anthropic, Bedrock, OpenAI, Ollama}
 }
 
 // New creates a new Provider instance for the given provider type.
@@ -26,6 +28,8 @@ func New(providerType string) (Provider, error) {
 	switch providerType {
 	case Anthropic:
 		return anthropic.New(), nil
+	case Bedrock:
+		return bedrock.New(), nil
 	case OpenAI:
 		return openai.New(), nil
 	case Ollama: