diff --git a/lib/rubyai/client.rb b/lib/rubyai/client.rb index d17a1cd..4f4c0da 100644 --- a/lib/rubyai/client.rb +++ b/lib/rubyai/client.rb @@ -3,7 +3,11 @@ class Client attr_accessor :configuration def initialize(config = {}) - @configuration = RubyAI.config(config) + if RubyAI.instance_variable_defined?(:@config) + RubyAI.config.set_instance_variables(config) + return @configuration = RubyAI.config + end + @configuration = RubyAI.config(config) end def call diff --git a/lib/rubyai/configuration.rb b/lib/rubyai/configuration.rb index bcf6ce4..b500124 100644 --- a/lib/rubyai/configuration.rb +++ b/lib/rubyai/configuration.rb @@ -11,40 +11,52 @@ class Configuration :provider def initialize(config = nil) - @openai ||= Providers::OpenAI.new - @anthropic ||= Providers::Anthropic.new - @gemini ||= Providers::Gemini.new - @bedrock_anthropic ||= Providers::BedrockAnthropic.new + set_instances_of_providers return if config.nil? || config.empty? + + set_instance_variables(config) + end + + def self.config(config = {}) + @config ||= RubyAI.config(config) + end + + def self.configure + yield(config) + end - provider_key = if config.fetch(:model, - nil).nil? - "openai" - else - Provider.models.find do |_provider, models| - models.include?(config[:model]) - end.first - end + def set_instance_variables(config) + @provider = find_provider(config) - raise ArgumentError, "Unknown model: #{config[:model]}" unless provider_key + raise ArgumentError, "Unknown model: #{config[:model]}" unless @provider @model = config.fetch(:model, Provider.models["openai"]["gpt-3.5-turbo"]) - provider = send(provider_key) - provider.api = config[:api_key] - @api_key = provider.api - provider.messages = config[:messages] + provider = send(@provider) + provider.api = config.fetch(:api_key) if provider.api.nil? + provider.messages = config.fetch(:messages, nil) @messages = provider.messages - @provider = provider_key + @api_key = provider.api provider.temperature = config.fetch(:temperature, 0.7) @temperature = provider.temperature end - def self.config - @config ||= RubyAI.config({}) + private + + def set_instances_of_providers + @openai ||= Providers::OpenAI.new + @anthropic ||= Providers::Anthropic.new + @gemini ||= Providers::Gemini.new + @bedrock_anthropic ||= Providers::BedrockAnthropic.new end - def self.configure - yield(config) + def find_provider(config) + if config.fetch(:model, nil).nil? + "openai" + else + Provider.models.find do |_provider, models| + models.include?(config[:model]) + end.first + end end end end diff --git a/lib/rubyai/providers/anthropic.rb b/lib/rubyai/providers/anthropic.rb index 945cc25..8a7ff0a 100644 --- a/lib/rubyai/providers/anthropic.rb +++ b/lib/rubyai/providers/anthropic.rb @@ -1,7 +1,7 @@ module RubyAI module Providers class Anthropic < Base - attr_accessor :api, :messages, :temperature, :max_tokens + attr_accessor :max_tokens def initialize(api: nil, messages: nil, temperature: 0.7, model: "claude-2") @api = api @@ -21,6 +21,10 @@ def models }.freeze end + def default_model + models['claude-1.3-sonnet'] + end + def build_http_body(messages = nil, model = nil, temperature = nil) { "model" => RubyAI::Provider.models["anthropic"][model || @model], diff --git a/lib/rubyai/providers/base.rb b/lib/rubyai/providers/base.rb index 4e5b9ab..5431567 100644 --- a/lib/rubyai/providers/base.rb +++ b/lib/rubyai/providers/base.rb @@ -1,6 +1,8 @@ module RubyAI module Providers class Base + attr_accessor :api, :messages, :model, :temperature + def configure yield self end diff --git a/lib/rubyai/providers/bedrock_anthropic.rb b/lib/rubyai/providers/bedrock_anthropic.rb index fe4491c..9be18a8 100644 --- a/lib/rubyai/providers/bedrock_anthropic.rb +++ b/lib/rubyai/providers/bedrock_anthropic.rb @@ -4,10 +4,7 @@ module RubyAI module Providers class BedrockAnthropic < Base - attr_accessor :messages, - :temperature, - :max_tokens, - :model, + attr_accessor :max_tokens, :region, :aws_access_key_id, :aws_secret_access_key, @@ -55,6 +52,10 @@ def models }.freeze end + def default_model + "claude-instant" + end + def build_http_body(messages = nil, _model = nil, temperature = nil) { "anthropic_version" => "bedrock-2023-05-31", diff --git a/lib/rubyai/providers/gemini.rb b/lib/rubyai/providers/gemini.rb index de955d2..5a99474 100644 --- a/lib/rubyai/providers/gemini.rb +++ b/lib/rubyai/providers/gemini.rb @@ -1,7 +1,7 @@ module RubyAI module Providers class Gemini < Base - attr_accessor :api, :messages, :temperature, :max_tokens, :model + attr_accessor :max_tokens def initialize(api: nil, messages: nil, temperature: 0.7, max_tokens: 1000) @api = api @@ -18,6 +18,10 @@ def models } end + def default_model + models["gemini-1.5-flash"] + end + def build_http_body(messages = nil, _model, temperature, max_tokens: 1000) { contents: [ diff --git a/lib/rubyai/providers/openai.rb b/lib/rubyai/providers/openai.rb index 0d3675f..289590f 100644 --- a/lib/rubyai/providers/openai.rb +++ b/lib/rubyai/providers/openai.rb @@ -1,8 +1,6 @@ module RubyAI module Providers class OpenAI < Base - attr_accessor :api, :messages, :temperature, :model - def initialize(api: nil, messages: nil, temperature: 0.7) @api = api @messages = messages @@ -22,6 +20,10 @@ def models } end + def default_model + models["gpt-3.5-turbo"] + end + def build_http_body(messages = nil, model = "gpt-3.5-turbo", temperature = nil) { model: RubyAI::Provider.models["openai"][model],