22 changes: 22 additions & 0 deletions src/cai/sdk/agents/models/openai_chatcompletions.py
@@ -2706,6 +2706,28 @@ async def _fetch_response(
         elif "/" in model_str:
             # Handle provider/model format
             provider = model_str.split("/")[0]
+            explicit_openai_compatible_base = (
+                os.getenv("OLLAMA_API_BASE")
+                or os.getenv("OPENAI_API_BASE")
+                or os.getenv("OPENAI_BASE_URL")
+            )
+
+            # Route unknown provider/model prefixes through the user-configured
+            # OpenAI-compatible base instead of letting LiteLLM guess a provider
+            # from the left-hand side of the model id.
+            if explicit_openai_compatible_base and provider not in {
+                "ollama_cloud",
+                "deepseek",
+                "claude",
+                "gemini",
+            }:
+                litellm.drop_params = True
+                kwargs["api_base"] = get_ollama_api_base()
+                kwargs["custom_llm_provider"] = "openai"
+                kwargs.pop("parallel_tool_calls", None)
+                kwargs.pop("store", None)
+                if not converted_tools:
+                    kwargs.pop("tool_choice", None)
 
             # Apply provider-specific configurations
             if provider == "ollama_cloud":
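Net effect: once the user has pointed CAI at an OpenAI-compatible endpoint via any of the three environment variables, a provider/model id whose prefix is not in the small recognized set is sent to that endpoint instead of being left to LiteLLM's prefix-based provider guessing. A minimal sketch of just the routing predicate, for illustration (the standalone function and the KNOWN_PROVIDERS name are ours, not identifiers from the patch):

import os

# Illustrative mirror of the condition added above; the patch inlines this
# logic inside _fetch_response rather than defining a helper.
KNOWN_PROVIDERS = {"ollama_cloud", "deepseek", "claude", "gemini"}

def routes_to_openai_compatible_base(model_str: str) -> bool:
    if "/" not in model_str:
        return False
    provider = model_str.split("/")[0]
    explicit_base = (
        os.getenv("OLLAMA_API_BASE")
        or os.getenv("OPENAI_API_BASE")
        or os.getenv("OPENAI_BASE_URL")
    )
    return bool(explicit_base) and provider not in KNOWN_PROVIDERS

# With OPENAI_API_BASE=http://127.0.0.1:9999/v1 exported:
#   routes_to_openai_compatible_base("acme/custom-1.1")        -> True  (unknown prefix)
#   routes_to_openai_compatible_base("deepseek/deepseek-chat") -> False (recognized provider)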
23 changes: 10 additions & 13 deletions src/cai/util.py
@@ -750,21 +750,18 @@ def process_total_cost(
 
 
 def get_ollama_api_base():
-    """Get the Ollama API base URL from environment variable or default to localhost:8000.
+    """Get the OpenAI-compatible base URL for Ollama Cloud, local Ollama, or custom gateways.
 
-    Supports both:
-    - OLLAMA_API_BASE: For local Ollama instances (e.g., http://localhost:8000/v1)
-    - OPENAI_BASE_URL: For Ollama Cloud or other OpenAI-compatible services (e.g., https://ollama.com/api/v1)
+    Priority:
+    - OLLAMA_API_BASE: explicit Ollama/local gateway override
+    - OPENAI_API_BASE: OpenAI-compatible local/custom gateway override
+    - OPENAI_BASE_URL: generic OpenAI-compatible base URL
+    - fallback: local Ollama default
     """
-    # First check OLLAMA_API_BASE for local Ollama
-    ollama_base = os.environ.get("OLLAMA_API_BASE")
-    if ollama_base:
-        return ollama_base
-
-    # Then check OPENAI_BASE_URL for Ollama Cloud or other services
-    openai_base = os.environ.get("OPENAI_BASE_URL")
-    if openai_base and "ollama.com" in openai_base:
-        return openai_base
+    for env_var in ("OLLAMA_API_BASE", "OPENAI_API_BASE", "OPENAI_BASE_URL"):
+        api_base = os.environ.get(env_var)
+        if api_base:
+            return api_base
 
     # Default to local Ollama
     return "http://localhost:8000/v1"
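Because the helper is now a straight first-match-wins lookup (and no longer special-cases "ollama.com" URLs), its behavior is easy to pin down in a few lines. A usage sketch, with placeholder URLs:

import os
from cai.util import get_ollama_api_base

# OLLAMA_API_BASE unset: OPENAI_API_BASE wins over OPENAI_BASE_URL.
os.environ.pop("OLLAMA_API_BASE", None)
os.environ["OPENAI_API_BASE"] = "http://127.0.0.1:8080/v1"
os.environ["OPENAI_BASE_URL"] = "https://gateway.example/v1"
assert get_ollama_api_base() == "http://127.0.0.1:8080/v1"

# With all three unset, the helper falls back to the local default.
for var in ("OLLAMA_API_BASE", "OPENAI_API_BASE", "OPENAI_BASE_URL"):
    os.environ.pop(var, None)
assert get_ollama_api_base() == "http://localhost:8000/v1"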
77 changes: 76 additions & 1 deletion tests/core/test_openai_chatcompletions.py
@@ -4,6 +4,7 @@
 from typing import Any
 
 import httpx
+import litellm
 import pytest
 from openai import NOT_GIVEN
 from openai.types.chat.chat_completion import ChatCompletion, Choice
@@ -31,6 +32,7 @@
     generation_span,
 )
 from cai.sdk.agents.models.fake_id import FAKE_RESPONSES_ID
+from cai.util import get_ollama_api_base
 import os
 cai_model = os.getenv('CAI_MODEL', "qwen2.5:14b")
 
@@ -360,4 +362,77 @@ async def patched_fetch_response(self, *args, **kwargs):
     )
 
     # Counter should now be 2 (one increment per turn, not per item)
-    assert model.interaction_counter == 2
\ No newline at end of file
+    assert model.interaction_counter == 2
+
+
+def test_get_ollama_api_base_prefers_openai_compatible_env_vars(monkeypatch) -> None:
+    with monkeypatch.context() as m:
+        m.delenv("OLLAMA_API_BASE", raising=False)
+        m.setenv("OPENAI_API_BASE", "http://127.0.0.1:8080/v1")
+        m.setenv("OPENAI_BASE_URL", "https://example.invalid/v1")
+        assert get_ollama_api_base() == "http://127.0.0.1:8080/v1"
+
+    with monkeypatch.context() as m:
+        m.delenv("OLLAMA_API_BASE", raising=False)
+        m.delenv("OPENAI_API_BASE", raising=False)
+        m.setenv("OPENAI_BASE_URL", "https://gateway.example/v1")
+        assert get_ollama_api_base() == "https://gateway.example/v1"
+
+
+@pytest.mark.asyncio
+async def test_fetch_response_routes_unknown_prefixed_models_to_openai_compatible_base(
+    monkeypatch,
+) -> None:
+    class DummyCompletions:
+        def __init__(self) -> None:
+            self.kwargs: dict[str, Any] = {}
+
+        async def create(self, **kwargs: Any) -> Any:
+            self.kwargs = kwargs
+            return chat
+
+    class DummyClient:
+        def __init__(self, completions: DummyCompletions) -> None:
+            self.chat = type("_Chat", (), {"completions": completions})()
+            self.base_url = httpx.URL("http://fake")
+
+    msg = ChatCompletionMessage(role="assistant", content="gateway ok")
+    choice = Choice(index=0, finish_reason="stop", message=msg)
+    chat = ChatCompletion(
+        id="resp-id",
+        created=0,
+        model="acme/custom-1.1",
+        object="chat.completion",
+        choices=[choice],
+    )
+    completions = DummyCompletions()
+    dummy_client = DummyClient(completions)
+    captured: dict[str, Any] = {}
+
+    async def fake_acompletion(**kwargs: Any) -> Any:
+        captured.update(kwargs)
+        return chat
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+    monkeypatch.setenv("OPENAI_API_BASE", "http://127.0.0.1:9999/v1")
+    monkeypatch.delenv("OLLAMA_API_BASE", raising=False)
+    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
+
+    model = OpenAIChatCompletionsModel(model="acme/custom-1.1", openai_client=dummy_client)  # type: ignore[arg-type]
+    with generation_span(disabled=True) as span:
+        result = await model._fetch_response(
+            system_instructions="sys",
+            input="hi",
+            model_settings=ModelSettings(),
+            tools=[],
+            output_schema=None,
+            handoffs=[],
+            span=span,
+            tracing=ModelTracing.DISABLED,
+            stream=False,
+        )
+
+    assert result is chat
+    assert captured["model"] == "acme/custom-1.1"
+    assert captured["api_base"] == "http://127.0.0.1:9999/v1"
+    assert captured["custom_llm_provider"] == "openai"