diff --git a/lib/rubyai.rb b/lib/rubyai.rb index 0712424..f694941 100644 --- a/lib/rubyai.rb +++ b/lib/rubyai.rb @@ -6,6 +6,7 @@ require_relative "rubyai/providers/openai" require_relative "rubyai/providers/anthropic" require_relative "rubyai/providers/gemini" +require_relative "rubyai/providers/bedrock_anthropic" require_relative "rubyai/provider" require_relative "rubyai/configuration" require_relative "rubyai/http" diff --git a/lib/rubyai/configuration.rb b/lib/rubyai/configuration.rb index 2f08b8c..bcf6ce4 100644 --- a/lib/rubyai/configuration.rb +++ b/lib/rubyai/configuration.rb @@ -1,10 +1,9 @@ module RubyAI class Configuration - # providers config accessors attr_accessor :openai, :anthropic, :gemini, - # default api config accessors + :bedrock_anthropic, :api, :model, :messages, @@ -15,6 +14,7 @@ def initialize(config = nil) @openai ||= Providers::OpenAI.new @anthropic ||= Providers::Anthropic.new @gemini ||= Providers::Gemini.new + @bedrock_anthropic ||= Providers::BedrockAnthropic.new return if config.nil? || config.empty? provider_key = if config.fetch(:model, diff --git a/lib/rubyai/http.rb b/lib/rubyai/http.rb index 9444ffe..79ae5b3 100644 --- a/lib/rubyai/http.rb +++ b/lib/rubyai/http.rb @@ -17,7 +17,7 @@ def connect(messages:, provider: , model:, temperature:) response = case provider when "bedrock_anthropic" RubyAI.config.bedrock_anthropic.client.invoke_model(model_id: model, - body: BedrockAnthropic.build_http_body.to_json, + body: RubyAI.config.bedrock_anthropic.build_http_body.to_json, content_type: "application/json") else connection.post do |req| diff --git a/lib/rubyai/providers/bedrock_anthropic.rb b/lib/rubyai/providers/bedrock_anthropic.rb new file mode 100644 index 0000000..fe4491c --- /dev/null +++ b/lib/rubyai/providers/bedrock_anthropic.rb @@ -0,0 +1,112 @@ +require "aws-sdk-bedrockruntime" +require "json" + +module RubyAI + module Providers + class BedrockAnthropic < Base + attr_accessor :messages, + :temperature, + :max_tokens, + :model, + :region, + :aws_access_key_id, + :aws_secret_access_key, + :aws_session_token + + def initialize( + region: "us-east-1", + messages: nil, + temperature: 0.7, + model: "claude-3-sonnet", + max_tokens: 4096, + aws_access_key_id: nil, + aws_secret_access_key: nil, + aws_session_token: nil + ) + @messages = messages + @temperature = temperature + @model = model + @max_tokens = max_tokens + @region = region + + # Initialize AWS Bedrock Runtime client + @client_options = { region: @region } + + @aws_access_key_id = aws_access_key_id + @aws_secret_access_key = aws_secret_access_key + @aws_session_token = aws_session_token + + # Use explicit credentials if provided, otherwise fall back to default credential chain + return unless aws_access_key_id && aws_secret_access_key + + build_credentials + end + + def models + { + "claude-3-opus" => "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-sonnet" => "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-haiku" => "anthropic.claude-3-haiku-20240307-v1:0", + "claude-3-5-sonnet" => "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3-5-haiku" => "anthropic.claude-3-5-haiku-20241022-v1:0", + "claude-instant" => "anthropic.claude-instant-v1", + "claude-v2" => "anthropic.claude-v2", + "claude-v2-1" => "anthropic.claude-v2:1" + }.freeze + end + + def build_http_body(messages = nil, _model = nil, temperature = nil) + { + "anthropic_version" => "bedrock-2023-05-31", + "max_tokens" => @max_tokens, + "messages" => format_messages_for_bedrock(messages || @messages), + "temperature" => temperature || @temperature + } + end + + def build_http_headers(_provider) + # For Bedrock, headers are handled by AWS SDK + # Return empty hash as headers are not needed for direct HTTP calls + {} + end + + def format_messages_for_bedrock(messages) + case messages + when String + [{ "role" => "user", "content" => messages }] + when Array + messages.map do |msg| + case msg + when Hash + # Ensure required keys exist + { + "role" => msg["role"] || msg[:role] || "user", + "content" => msg["content"] || msg[:content] || "" + } + when String + { "role" => "user", "content" => msg } + else + { "role" => "user", "content" => msg.to_s } + end + end + else + [{ "role" => "user", "content" => messages.to_s }] + end + end + + def build_credentials + credentials = Aws::Credentials.new( + aws_access_key_id, + aws_secret_access_key, + aws_session_token + ) + @client_options[:credentials] = credentials + end + + def client + build_credentials + @client ||= Aws::BedrockRuntime::Client.new(@client_options) + end + end + end +end diff --git a/spec/rubyai/providers/bedrock_anthropic_spec.rb b/spec/rubyai/providers/bedrock_anthropic_spec.rb new file mode 100644 index 0000000..5a50b4b --- /dev/null +++ b/spec/rubyai/providers/bedrock_anthropic_spec.rb @@ -0,0 +1,582 @@ +require_relative "../../../lib/rubyai/providers/bedrock_anthropic" +require "spec_helper" +require "aws-sdk-bedrockruntime" + +RSpec.describe RubyAI::Providers::BedrockAnthropic do + let(:aws_access_key_id) { "test-access-key" } + let(:aws_secret_access_key) { "test-secret-key" } + let(:aws_session_token) { "test-session-token" } + let(:region) { "us-west-2" } + let(:messages) { [{ "role" => "user", "content" => "Hello, world!" }] } + let(:temperature) { 0.8 } + let(:model) { "claude-3-sonnet" } + let(:max_tokens) { 2048 } + + # Mock AWS SDK classes + let(:mock_credentials) { instance_double(Aws::Credentials) } + let(:mock_client) { instance_double(Aws::BedrockRuntime::Client) } + + before do + allow(Aws::Credentials).to receive(:new).and_return(mock_credentials) + allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(mock_client) + end + + describe "#initialize" do + context "with default parameters" do + subject { described_class.new } + + it "sets default values" do + expect(subject.region).to eq("us-east-1") + expect(subject.messages).to be_nil + expect(subject.temperature).to eq(0.7) + expect(subject.model).to eq("claude-3-sonnet") + expect(subject.max_tokens).to eq(4096) + expect(subject.aws_access_key_id).to be_nil + expect(subject.aws_secret_access_key).to be_nil + expect(subject.aws_session_token).to be_nil + end + + it "sets up client options with default region" do + expect(subject.instance_variable_get(:@client_options)).to eq({ region: "us-east-1" }) + end + end + + context "with custom parameters" do + subject do + described_class.new( + region: region, + messages: messages, + temperature: temperature, + model: model, + max_tokens: max_tokens, + aws_access_key_id: aws_access_key_id, + aws_secret_access_key: aws_secret_access_key, + aws_session_token: aws_session_token + ) + end + + it "sets custom values" do + expect(subject.region).to eq(region) + expect(subject.messages).to eq(messages) + expect(subject.temperature).to eq(temperature) + expect(subject.model).to eq(model) + expect(subject.max_tokens).to eq(max_tokens) + # NOTE: AWS credentials are stored as instance variables but no attr_accessor exists + expect(subject.instance_variable_get(:@aws_access_key_id)).to eq(aws_access_key_id) + expect(subject.instance_variable_get(:@aws_secret_access_key)).to eq(aws_secret_access_key) + expect(subject.instance_variable_get(:@aws_session_token)).to eq(aws_session_token) + end + it "builds credentials when AWS credentials are provided" do + expect(Aws::Credentials).to receive(:new).with( + aws_access_key_id, + aws_secret_access_key, + aws_session_token + ).and_return(mock_credentials) + + described_class.new( + region: region, + messages: messages, + temperature: temperature, + model: model, + max_tokens: max_tokens, + aws_access_key_id: aws_access_key_id, + aws_secret_access_key: aws_secret_access_key, + aws_session_token: aws_session_token + ) + end + end + + context "with partial AWS credentials" do + subject do + described_class.new( + aws_access_key_id: aws_access_key_id, + aws_secret_access_key: nil + ) + end + + it "does not build credentials when incomplete" do + expect(Aws::Credentials).not_to have_received(:new) + expect(subject.instance_variable_get(:@client_options)[:credentials]).to be_nil + end + end + end + + describe "#models" do + subject { described_class.new.models } + + it "returns a frozen hash of available models" do + expect(subject).to be_frozen + expect(subject).to be_a(Hash) + end + + it "includes all expected Claude models" do + expected_models = { + "claude-3-opus" => "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-sonnet" => "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-haiku" => "anthropic.claude-3-haiku-20240307-v1:0", + "claude-3-5-sonnet" => "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3-5-haiku" => "anthropic.claude-3-5-haiku-20241022-v1:0", + "claude-instant" => "anthropic.claude-instant-v1", + "claude-v2" => "anthropic.claude-v2", + "claude-v2-1" => "anthropic.claude-v2:1" + } + + expect(subject).to eq(expected_models) + end + + it "maps model names to Bedrock model IDs" do + expect(subject["claude-3-sonnet"]).to eq("anthropic.claude-3-sonnet-20240229-v1:0") + expect(subject["claude-3-5-sonnet"]).to eq("anthropic.claude-3-5-sonnet-20240620-v1:0") + end + end + + describe "#build_http_body" do + let(:instance) do + described_class.new( + messages: messages, + temperature: temperature, + max_tokens: max_tokens + ) + end + + context "with default instance values" do + subject { instance.build_http_body } + + it "builds correct HTTP body structure" do + expect(subject).to be_a(Hash) + expect(subject).to have_key("anthropic_version") + expect(subject).to have_key("max_tokens") + expect(subject).to have_key("messages") + expect(subject).to have_key("temperature") + end + + it "uses correct Bedrock API version" do + expect(subject["anthropic_version"]).to eq("bedrock-2023-05-31") + end + + it "uses instance values" do + expect(subject["max_tokens"]).to eq(max_tokens) + expect(subject["messages"]).to eq(messages) + expect(subject["temperature"]).to eq(temperature) + end + end + + context "with method parameters" do + let(:custom_messages) { [{ "role" => "assistant", "content" => "Custom response" }] } + let(:custom_temperature) { 0.9 } + + subject { instance.build_http_body(custom_messages, "ignored_model", custom_temperature) } + + it "uses method parameters over instance values" do + expect(subject["messages"]).to eq(custom_messages) + expect(subject["temperature"]).to eq(custom_temperature) + end + + it "ignores the model parameter" do + # The method accepts _model parameter but doesn't use it + expect { subject }.not_to raise_error + end + end + + context "with nil parameters" do + let(:instance) { described_class.new } + subject { instance.build_http_body } + + it "handles nil values gracefully" do + expect { subject }.not_to raise_error + expect(subject["messages"]).to eq([{ "role" => "user", "content" => "" }]) + expect(subject["temperature"]).to eq(0.7) # default value + end + end + end + + describe "#build_http_headers" do + let(:instance) { described_class.new } + let(:provider) { "bedrock" } + + subject { instance.build_http_headers(provider) } + + it "returns empty hash for Bedrock" do + expect(subject).to eq({}) + end + + it "ignores the provider parameter" do + expect { instance.build_http_headers("different_provider") }.not_to raise_error + expect(subject).to eq({}) + end + end + + describe "#format_messages_for_bedrock" do + let(:instance) { described_class.new } + + context "with string input" do + let(:string_message) { "Hello, world!" } + subject { instance.format_messages_for_bedrock(string_message) } + + it "converts string to message array" do + expect(subject).to be_an(Array) + expect(subject.length).to eq(1) + expect(subject.first).to eq({ "role" => "user", "content" => string_message }) + end + end + + context "with array of hash messages" do + let(:hash_messages) do + [ + { "role" => "user", "content" => "Hello" }, + { role: "assistant", content: "Hi there!" }, + { "role" => "user", "content" => "How are you?" } + ] + end + subject { instance.format_messages_for_bedrock(hash_messages) } + + it "formats hash messages correctly" do + expect(subject).to be_an(Array) + expect(subject.length).to eq(3) + expect(subject[0]).to eq({ "role" => "user", "content" => "Hello" }) + expect(subject[1]).to eq({ "role" => "assistant", "content" => "Hi there!" }) + expect(subject[2]).to eq({ "role" => "user", "content" => "How are you?" }) + end + end + + context "with array of string messages" do + let(:string_messages) { ["Hello", "How are you?", "Goodbye"] } + subject { instance.format_messages_for_bedrock(string_messages) } + + it "converts strings to user messages" do + expect(subject).to be_an(Array) + expect(subject.length).to eq(3) + subject.each_with_index do |msg, index| + expect(msg).to eq({ "role" => "user", "content" => string_messages[index] }) + end + end + end + + context "with array of mixed types" do + let(:mixed_messages) do + [ + { "role" => "user", "content" => "Hello" }, + "How are you?", + 123, + { role: "assistant", content: "I'm good!" } + ] + end + subject { instance.format_messages_for_bedrock(mixed_messages) } + + it "handles mixed types correctly" do + expect(subject).to be_an(Array) + expect(subject.length).to eq(4) + expect(subject[0]).to eq({ "role" => "user", "content" => "Hello" }) + expect(subject[1]).to eq({ "role" => "user", "content" => "How are you?" }) + expect(subject[2]).to eq({ "role" => "user", "content" => "123" }) + expect(subject[3]).to eq({ "role" => "assistant", "content" => "I'm good!" }) + end + end + + context "with hash messages missing keys" do + let(:incomplete_messages) do + [ + { "role" => "user" }, # missing content + { "content" => "Hello" }, # missing role + {} # missing both + ] + end + subject { instance.format_messages_for_bedrock(incomplete_messages) } + + it "provides default values for missing keys" do + expect(subject).to be_an(Array) + expect(subject.length).to eq(3) + expect(subject[0]).to eq({ "role" => "user", "content" => "" }) + expect(subject[1]).to eq({ "role" => "user", "content" => "Hello" }) + expect(subject[2]).to eq({ "role" => "user", "content" => "" }) + end + end + + context "with nil input" do + subject { instance.format_messages_for_bedrock(nil) } + + it "converts nil to default user message" do + expect(subject).to eq([{ "role" => "user", "content" => "" }]) + end + end + + context "with other input types" do + let(:number_input) { 42 } + let(:symbol_input) { :hello } + + it "converts numbers to string content" do + result = instance.format_messages_for_bedrock(number_input) + expect(result).to eq([{ "role" => "user", "content" => "42" }]) + end + + it "converts symbols to string content" do + result = instance.format_messages_for_bedrock(symbol_input) + expect(result).to eq([{ "role" => "user", "content" => "hello" }]) + end + end + end + + describe "#build_credentials" do + context "when called with credentials set" do + let(:instance) do + described_class.new( + region: region, + messages: messages, + temperature: temperature, + model: model, + max_tokens: max_tokens, + aws_access_key_id: aws_access_key_id, + aws_secret_access_key: aws_secret_access_key, + aws_session_token: aws_session_token + ) + end + + it "builds credentials when called" do + instance = described_class.new + instance.instance_variable_set(:@aws_access_key_id, aws_access_key_id) + instance.instance_variable_set(:@aws_secret_access_key, aws_secret_access_key) + instance.instance_variable_set(:@aws_session_token, aws_session_token) + + expect(Aws::Credentials).to receive(:new).with( + aws_access_key_id, + aws_secret_access_key, + aws_session_token + ).and_return(mock_credentials) + + instance.build_credentials + expect(instance.instance_variable_get(:@client_options)[:credentials]).to eq(mock_credentials) + end + + it "adds credentials to client options" do + expect(instance.instance_variable_get(:@client_options)[:credentials]).to eq(mock_credentials) + end + end + + context "when called manually" do + let(:instance) { described_class.new } + + before do + instance.instance_variable_set(:@aws_access_key_id, aws_access_key_id) + instance.instance_variable_set(:@aws_secret_access_key, aws_secret_access_key) + instance.instance_variable_set(:@aws_session_token, aws_session_token) + + # Reset the mock to track new calls + allow(Aws::Credentials).to receive(:new).and_return(mock_credentials) + end + + it "builds credentials when called" do + instance.build_credentials + + expect(Aws::Credentials).to have_received(:new).with( + aws_access_key_id, + aws_secret_access_key, + aws_session_token + ) + expect(instance.instance_variable_get(:@client_options)[:credentials]).to eq(mock_credentials) + end + end + end + + describe "#client" do + context "with AWS credentials" do + let(:instance) do + described_class.new( + region: region, + aws_access_key_id: aws_access_key_id, + aws_secret_access_key: aws_secret_access_key, + aws_session_token: aws_session_token + ) + end + + it "creates and returns a BedrockRuntime client" do + # Reset the mock to track new calls from client method + allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(mock_client) + + client = instance.client + expect(client).to eq(mock_client) + expect(Aws::BedrockRuntime::Client).to have_received(:new).with( + region: region, + credentials: mock_credentials + ) + end + + it "memoizes the client" do + client1 = instance.client + client2 = instance.client + expect(client1).to eq(client2) + # Should only be called once due to memoization + expect(Aws::BedrockRuntime::Client).to have_received(:new).once + end + end + + context "without AWS credentials" do + let(:instance) { described_class.new(region: region) } + + before do + # Reset the mock to track new calls + allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(mock_client) + end + + it "creates client but still calls build_credentials" do + client = instance.client + expect(client).to eq(mock_client) + # The client method calls build_credentials, which tries to build credentials + # even without credentials set, so it still gets credentials in the options + expect(Aws::BedrockRuntime::Client).to have_received(:new).with( + region: region, + credentials: mock_credentials + ) + end + end + end + + describe "attribute accessors" do + let(:instance) { described_class.new } + + it "allows reading and writing accessible attributes" do + # Test messages + new_messages = [{ "role" => "user", "content" => "New message" }] + instance.messages = new_messages + expect(instance.messages).to eq(new_messages) + + # Test temperature + instance.temperature = 0.5 + expect(instance.temperature).to eq(0.5) + + # Test max_tokens + instance.max_tokens = 1024 + expect(instance.max_tokens).to eq(1024) + + # Test model + instance.model = "claude-3-haiku" + expect(instance.model).to eq("claude-3-haiku") + + # Test region + instance.region = "eu-west-1" + expect(instance.region).to eq("eu-west-1") + + # Test AWS credentials + instance.aws_access_key_id = "new-key" + expect(instance.aws_access_key_id).to eq("new-key") + + instance.aws_secret_access_key = "new-secret" + expect(instance.aws_secret_access_key).to eq("new-secret") + + instance.aws_session_token = "new-token" + expect(instance.aws_session_token).to eq("new-token") + end + end + + describe "integration tests" do + let(:instance) do + described_class.new( + region: region, + messages: messages, + temperature: temperature, + model: model, + max_tokens: max_tokens, + aws_access_key_id: aws_access_key_id, + aws_secret_access_key: aws_secret_access_key, + aws_session_token: aws_session_token + ) + end + + describe "complete request building" do + it "builds a complete request structure" do + headers = instance.build_http_headers("bedrock") + body = instance.build_http_body + + expect(headers).to eq({}) + expect(body["anthropic_version"]).to eq("bedrock-2023-05-31") + expect(body["max_tokens"]).to eq(max_tokens) + expect(body["messages"]).to eq(messages) + expect(body["temperature"]).to eq(temperature) + end + end + + describe "message formatting workflow" do + it "formats string messages correctly in the full workflow" do + string_message = "Test message" + formatted = instance.format_messages_for_bedrock(string_message) + body = instance.build_http_body(formatted) + + expect(body["messages"]).to eq([{ "role" => "user", "content" => string_message }]) + end + end + + describe "client initialization workflow" do + let(:instance) do + described_class.new( + region: region, + messages: messages, + temperature: temperature, + model: model, + max_tokens: max_tokens, + aws_access_key_id: aws_access_key_id, + aws_secret_access_key: aws_secret_access_key, + aws_session_token: aws_session_token + ) + end + + it "initializes client with proper credentials" do + # Reset the mock to track new calls from client method + allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(mock_client) + + client = instance.client + expect(client).to eq(mock_client) + + # Check that credentials were built during initialization + expect(Aws::Credentials).to have_received(:new).with( + aws_access_key_id, + aws_secret_access_key, + aws_session_token + ).twice + + # Check that client was created with proper options + expect(Aws::BedrockRuntime::Client).to have_received(:new).with( + region: region, + credentials: mock_credentials + ) + end + end + end + + describe "edge cases and error handling" do + context "when AWS SDK is not available" do + before do + allow(Aws::Credentials).to receive(:new).and_raise(StandardError.new("AWS SDK not available")) + end + + it "propagates AWS SDK errors" do + expect do + described_class.new( + aws_access_key_id: aws_access_key_id, + aws_secret_access_key: aws_secret_access_key + ) + end.to raise_error(StandardError, "AWS SDK not available") + end + end + + context "with empty string credentials" do + let(:instance) do + described_class.new( + aws_access_key_id: "", + aws_secret_access_key: "" + ) + end + + it "does not build credentials for empty strings" do + expect(Aws::Credentials).not_to have_received(:new) + end + end + + context "with very large max_tokens" do + let(:instance) { described_class.new(max_tokens: 1_000_000) } + + it "handles large max_tokens values" do + body = instance.build_http_body + expect(body["max_tokens"]).to eq(1_000_000) + end + end + end +end