diff --git a/.env.example b/.env.example index f9b4728..18fa939 100644 --- a/.env.example +++ b/.env.example @@ -1,7 +1,19 @@ # Bot config -GEMINI_API_KEY= ADDR=:8080 +# AI Provider Configuration +# AI_PROVIDER can be 'gemini' or 'openai'. Defaults to 'gemini'. +AI_PROVIDER=gemini + +# Gemini Configuration (Always required for embeddings) +GEMINI_API_KEY= +# GEMINI_MODEL=gemini-flash-latest + +# OpenAI Configuration (Required if AI_PROVIDER=openai) +OPENAI_API_KEY= +# OPENAI_MODEL=gpt-4o-mini +# OPENAI_BASE_URL=https://api.openai.com/v1 + # Telegram client config BOT_TOKEN= USERNAME_LIMITS= \ No newline at end of file diff --git a/README.md b/README.md index cb6716c..dff7ec7 100644 --- a/README.md +++ b/README.md @@ -67,9 +67,13 @@ You can configure the server using Environment Variables. | Variable | Description | Default | Required | | :--- | :--- | :--- | :---: | -| `GEMINI_API_KEY` | Your Google Gemini API access key. | - | ✅ | +| `AI_PROVIDER` | The AI provider to use (`gemini` or `openai`). | `gemini` | ❌ | +| `GEMINI_API_KEY` | Your Google Gemini API access key (Always required for embedding). | - | ✅ | | `ADDR` | Server listen address. | `:8080` | ❌ | | `GEMINI_MODEL` | The specific model version to use. | `gemini-flash-latest` | ❌ | +| `OPENAI_API_KEY` | Your OpenAI API access key (Required if `AI_PROVIDER` is `openai`). | - | ❌ | +| `OPENAI_MODEL` | The specific OpenAI model to use. | `gpt-4o-mini` | ❌ | +| `OPENAI_BASE_URL` | Base URL for OpenAI compatible APIs. | - | ❌ | | `MCP_SERVERS` | Comma-separated list of MCP HTTP stream servers (e.g., `http://localhost:8081/mcp`). | - | ❌ | | `GEMINI_SEARCH_DISABLED` | Set to `true` or `1` to disable Google Search grounding. Search is **enabled by default**. | `false` | ❌ | | `HISTORY_SUMMARY` | Message count trigger for history summarization (`0` to disable). | `20` | ❌ | diff --git a/cmd/server-bot/main.go b/cmd/server-bot/main.go index cd4436f..689c3bb 100644 --- a/cmd/server-bot/main.go +++ b/cmd/server-bot/main.go @@ -13,11 +13,13 @@ import ( "hairy-botter/internal/ai/agent" "hairy-botter/internal/ai/gemini" gemini_embedding "hairy-botter/internal/ai/gemini-embedding" + "hairy-botter/internal/ai/openai" "hairy-botter/internal/history" "hairy-botter/internal/rag" "hairy-botter/internal/server" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/genkit" ) @@ -102,17 +104,48 @@ func main() { searchEnable = false } - // Initialize the Gemini AI logic + aiProvider := os.Getenv("AI_PROVIDER") + if aiProvider == "" { + aiProvider = "gemini" + } + + plugins := make([]api.Plugin, 0) + + // Initialize the Gemini AI logic (Always needed for embedding and summarization) ga := gemini.ConfigPlugin(geminiKey) - g := genkit.Init(context.Background(), genkit.WithPlugins(ga)) + plugins = append(plugins, ga) - model, err := gemini.ConfigModel(g, ga, os.Getenv("GEMINI_MODEL")) + var oai openai.AgentConfigurator + if aiProvider == "openai" { + oaiKey := os.Getenv("OPENAI_API_KEY") + if oaiKey == "" { + logger.Error("OPENAI_API_KEY is not set but AI_PROVIDER is openai") + return + } + oaiBaseURL := os.Getenv("OPENAI_BASE_URL") + oai = openai.ConfigPlugin(oaiKey, oaiBaseURL) + plugins = append(plugins, oai) + } + + g := genkit.Init(context.Background(), genkit.WithPlugins(plugins...)) + + geminiModel, err := gemini.ConfigModel(g, ga, os.Getenv("GEMINI_MODEL")) if err != nil { - logger.Error("failed to define model", slog.String("err", err.Error())) + logger.Error("failed to define Gemini model", slog.String("err", err.Error())) return } - customModelConfig := gemini.CustomConfig(searchEnable) + + var activeModel ai.Model + var customModelConfig any + + if aiProvider == "openai" { + activeModel = openai.ConfigModel(oai, os.Getenv("OPENAI_MODEL")) + customModelConfig = nil // No custom config for OpenAI for now + } else { + activeModel = geminiModel + customModelConfig = gemini.CustomConfig(searchEnable) + } // TODO: Make a better, more separated embedder config embedder, err := ga.DefineEmbedder(g, "gemini-embedding-001", &ai.EmbedderOptions{}) @@ -134,11 +167,11 @@ func main() { HistorySummary: historySummary, Summarizer: &genkitSummarizer{ g: g, - model: model, + model: geminiModel, }, }) - aiLogic, err := agent.New(logger, g, model, hist, mcpClientAddrs, ragL, customModelConfig) + aiLogic, err := agent.New(logger, g, activeModel, hist, mcpClientAddrs, ragL, customModelConfig) if err != nil { logger.Error("failed to create AI logic", slog.String("err", err.Error())) diff --git a/go.mod b/go.mod index c743b34..8863ca1 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,13 @@ go 1.24.3 require ( github.com/briandowns/spinner v1.23.2 - github.com/firebase/genkit/go v1.4.0 + github.com/firebase/genkit/go v1.5.0 github.com/go-chi/chi/v5 v5.2.1 github.com/go-telegram/bot v1.17.0 github.com/mark3labs/mcp-go v0.29.1-0.20250521213157-f99e5472f312 + github.com/openai/openai-go v1.8.2 github.com/philippgille/chromem-go v0.7.0 + google.golang.org/genai v1.51.0 ) require ( @@ -17,7 +19,8 @@ require ( cloud.google.com/go/compute/metadata v0.7.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/fatih/color v1.7.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/fatih/color v1.17.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -31,10 +34,15 @@ require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect - github.com/mattn/go-colorable v0.1.2 // indirect - github.com/mattn/go-isatty v0.0.8 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/spf13/cast v1.7.1 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect @@ -51,7 +59,6 @@ require ( golang.org/x/sys v0.34.0 // indirect golang.org/x/term v0.33.0 // indirect golang.org/x/text v0.27.0 // indirect - google.golang.org/genai v1.51.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/grpc v1.73.0 // indirect google.golang.org/protobuf v1.36.6 // indirect diff --git a/go.sum b/go.sum index 10489c2..44718b8 100644 --- a/go.sum +++ b/go.sum @@ -13,12 +13,12 @@ github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx2 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= +github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/firebase/genkit/go v1.4.0 h1:CP1hNWk7z0hosyY53zMH6MFKFO1fMLtj58jGPllQo6I= -github.com/firebase/genkit/go v1.4.0/go.mod h1:HX6m7QOaGc3MDNr/DrpQZrzPLzxeuLxrkTvfFtCYlGw= +github.com/firebase/genkit/go v1.5.0 h1:GovQzZy11bwsNBWDfAy9vPpxGBIsRGIk5bnU5N8eDxk= +github.com/firebase/genkit/go v1.5.0/go.mod h1:vu8ZAqNU6MU5qDza66bvqTtzJoUrqhO/+z5/6dtouJQ= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= @@ -58,12 +58,15 @@ github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4 github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.29.1-0.20250521213157-f99e5472f312 h1:0N4N+5c2sgIIcxjaEWUCCAhNCR3LvHQF3VvhadFniuk= github.com/mark3labs/mcp-go v0.29.1-0.20250521213157-f99e5472f312/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= -github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a/go.mod h1:Y6ghKH+ZijXn5d9E7qGGZBmjitx7iitZdQiIW97EpTU= +github.com/openai/openai-go v1.8.2 h1:UqSkJ1vCOPUpz9Ka5tS0324EJFEuOvMc+lA/EarJWP8= +github.com/openai/openai-go v1.8.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -77,6 +80,16 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -110,15 +123,14 @@ golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= -google.golang.org/genai v1.41.0 h1:ayXl75LjTmqTu0y94yr96d17gIb4zF8gWVzX2TgioEY= -google.golang.org/genai v1.41.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg= google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= diff --git a/internal/ai/openai/config.go b/internal/ai/openai/config.go new file mode 100644 index 0000000..6c13719 --- /dev/null +++ b/internal/ai/openai/config.go @@ -0,0 +1,46 @@ +package openai + +import ( + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/plugins/compat_oai" + "github.com/openai/openai-go/option" +) + +// AgentConfigurator . +type AgentConfigurator interface { + api.Plugin + modelDefiner +} + +type modelDefiner interface { + DefineModel(provider string, name string, opts ai.ModelOptions) ai.Model +} + +// ConfigPlugin . +func ConfigPlugin(apiKey string, baseURL string) AgentConfigurator { + opts := []option.RequestOption{ + option.WithAPIKey(apiKey), + } + if baseURL != "" { + opts = append(opts, option.WithBaseURL(baseURL)) + } + + return &compat_oai.OpenAICompatible{ + Provider: "openai", + Opts: opts, + } +} + +// ConfigModel . +func ConfigModel(ga modelDefiner, modelName string) ai.Model { + if modelName == "" { + modelName = "gpt-4o-mini" // Default to gpt-4o-mini + } + + return ga.DefineModel("openai", modelName, ai.ModelOptions{ + Label: "OpenAI Model", + Versions: []string{}, + Stage: ai.ModelStageUnstable, + }) +}