Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion lib/rubyai/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ class Client
attr_accessor :configuration

def initialize(config = {})
@configuration = RubyAI.config(config)
if RubyAI.instance_variable_defined?(:@config)
RubyAI.config.set_instance_variables(config)
return @configuration = RubyAI.config
end
@configuration = RubyAI.config(config)
end

def call
Expand Down
56 changes: 34 additions & 22 deletions lib/rubyai/configuration.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,52 @@ class Configuration
:provider

def initialize(config = nil)
@openai ||= Providers::OpenAI.new
@anthropic ||= Providers::Anthropic.new
@gemini ||= Providers::Gemini.new
@bedrock_anthropic ||= Providers::BedrockAnthropic.new
set_instances_of_providers
return if config.nil? || config.empty?

set_instance_variables(config)
end

def self.config(config = {})
@config ||= RubyAI.config(config)
end

def self.configure
yield(config)
end

provider_key = if config.fetch(:model,
nil).nil?
"openai"
else
Provider.models.find do |_provider, models|
models.include?(config[:model])
end.first
end
def set_instance_variables(config)
@provider = find_provider(config)

raise ArgumentError, "Unknown model: #{config[:model]}" unless provider_key
raise ArgumentError, "Unknown model: #{config[:model]}" unless @provider

@model = config.fetch(:model, Provider.models["openai"]["gpt-3.5-turbo"])
provider = send(provider_key)
provider.api = config[:api_key]
@api_key = provider.api
provider.messages = config[:messages]
provider = send(@provider)
provider.api = config.fetch(:api_key) if provider.api.nil?
provider.messages = config.fetch(:messages, nil)
@messages = provider.messages
@provider = provider_key
@api_key = provider.api
provider.temperature = config.fetch(:temperature, 0.7)
@temperature = provider.temperature
end

def self.config
@config ||= RubyAI.config({})
private

def set_instances_of_providers
@openai ||= Providers::OpenAI.new
@anthropic ||= Providers::Anthropic.new
@gemini ||= Providers::Gemini.new
@bedrock_anthropic ||= Providers::BedrockAnthropic.new
end

def self.configure
yield(config)
def find_provider(config)
if config.fetch(:model, nil).nil?
"openai"
else
Provider.models.find do |_provider, models|
models.include?(config[:model])
end.first
end
end
end
end
6 changes: 5 additions & 1 deletion lib/rubyai/providers/anthropic.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module RubyAI
module Providers
class Anthropic < Base
attr_accessor :api, :messages, :temperature, :max_tokens
attr_accessor :max_tokens

def initialize(api: nil, messages: nil, temperature: 0.7, model: "claude-2")
@api = api
Expand All @@ -21,6 +21,10 @@ def models
}.freeze
end

def default_model
models['claude-1.3-sonnet']
end

def build_http_body(messages = nil, model = nil, temperature = nil)
{
"model" => RubyAI::Provider.models["anthropic"][model || @model],
Expand Down
2 changes: 2 additions & 0 deletions lib/rubyai/providers/base.rb
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module RubyAI
module Providers
class Base
attr_accessor :api, :messages, :model, :temperature

def configure
yield self
end
Expand Down
9 changes: 5 additions & 4 deletions lib/rubyai/providers/bedrock_anthropic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
module RubyAI
module Providers
class BedrockAnthropic < Base
attr_accessor :messages,
:temperature,
:max_tokens,
:model,
attr_accessor :max_tokens,
:region,
:aws_access_key_id,
:aws_secret_access_key,
Expand Down Expand Up @@ -55,6 +52,10 @@ def models
}.freeze
end

def default_model
"claude-instant"
end

def build_http_body(messages = nil, _model = nil, temperature = nil)
{
"anthropic_version" => "bedrock-2023-05-31",
Expand Down
6 changes: 5 additions & 1 deletion lib/rubyai/providers/gemini.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module RubyAI
module Providers
class Gemini < Base
attr_accessor :api, :messages, :temperature, :max_tokens, :model
attr_accessor :max_tokens

def initialize(api: nil, messages: nil, temperature: 0.7, max_tokens: 1000)
@api = api
Expand All @@ -18,6 +18,10 @@ def models
}
end

def default_model
models["gemini-1.5-flash"]
end

def build_http_body(messages = nil, _model, temperature, max_tokens: 1000)
{
contents: [
Expand Down
6 changes: 4 additions & 2 deletions lib/rubyai/providers/openai.rb
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module RubyAI
module Providers
class OpenAI < Base
attr_accessor :api, :messages, :temperature, :model

def initialize(api: nil, messages: nil, temperature: 0.7)
@api = api
@messages = messages
Expand All @@ -22,6 +20,10 @@ def models
}
end

def default_model
models["gpt-3.5-turbo"]
end

def build_http_body(messages = nil, model = "gpt-3.5-turbo", temperature = nil)
{
model: RubyAI::Provider.models["openai"][model],
Expand Down