diff --git a/src/cai/sdk/agents/models/openai_chatcompletions.py b/src/cai/sdk/agents/models/openai_chatcompletions.py
index 8931edd63..1681e5db4 100644
--- a/src/cai/sdk/agents/models/openai_chatcompletions.py
+++ b/src/cai/sdk/agents/models/openai_chatcompletions.py
@@ -2706,6 +2706,28 @@ async def _fetch_response(
         elif "/" in model_str:
             # Handle provider/model format
             provider = model_str.split("/")[0]
+            explicit_openai_compatible_base = (
+                os.getenv("OLLAMA_API_BASE")
+                or os.getenv("OPENAI_API_BASE")
+                or os.getenv("OPENAI_BASE_URL")
+            )
+
+            # Route unknown provider/model prefixes through the user-configured
+            # OpenAI-compatible base instead of letting LiteLLM guess a provider
+            # from the left-hand side of the model id.
+            if explicit_openai_compatible_base and provider not in {
+                "ollama_cloud",
+                "deepseek",
+                "claude",
+                "gemini",
+            }:
+                litellm.drop_params = True
+                kwargs["api_base"] = get_ollama_api_base()
+                kwargs["custom_llm_provider"] = "openai"
+                kwargs.pop("parallel_tool_calls", None)
+                kwargs.pop("store", None)
+                if not converted_tools:
+                    kwargs.pop("tool_choice", None)
 
             # Apply provider-specific configurations
             if provider == "ollama_cloud":
diff --git a/src/cai/util.py b/src/cai/util.py
index b5acca68c..8108161c2 100644
--- a/src/cai/util.py
+++ b/src/cai/util.py
@@ -750,21 +750,18 @@ def process_total_cost(
 
 
 def get_ollama_api_base():
-    """Get the Ollama API base URL from environment variable or default to localhost:8000.
+    """Get the OpenAI-compatible base URL for Ollama Cloud, local Ollama, or custom gateways.
 
-    Supports both:
-    - OLLAMA_API_BASE: For local Ollama instances (e.g., http://localhost:8000/v1)
-    - OPENAI_BASE_URL: For Ollama Cloud or other OpenAI-compatible services (e.g., https://ollama.com/api/v1)
+    Priority:
+    - OLLAMA_API_BASE: explicit Ollama/local gateway override
+    - OPENAI_API_BASE: OpenAI-compatible local/custom gateway override
+    - OPENAI_BASE_URL: generic OpenAI-compatible base URL
+    - fallback: local Ollama default
     """
-    # First check OLLAMA_API_BASE for local Ollama
-    ollama_base = os.environ.get("OLLAMA_API_BASE")
-    if ollama_base:
-        return ollama_base
-
-    # Then check OPENAI_BASE_URL for Ollama Cloud or other services
-    openai_base = os.environ.get("OPENAI_BASE_URL")
-    if openai_base and "ollama.com" in openai_base:
-        return openai_base
+    for env_var in ("OLLAMA_API_BASE", "OPENAI_API_BASE", "OPENAI_BASE_URL"):
+        api_base = os.environ.get(env_var)
+        if api_base:
+            return api_base
 
     # Default to local Ollama
     return "http://localhost:8000/v1"
diff --git a/tests/core/test_openai_chatcompletions.py b/tests/core/test_openai_chatcompletions.py
index 7c8ab0394..0b5199b7c 100644
--- a/tests/core/test_openai_chatcompletions.py
+++ b/tests/core/test_openai_chatcompletions.py
@@ -4,6 +4,7 @@
 from typing import Any
 
 import httpx
+import litellm
 import pytest
 from openai import NOT_GIVEN
 from openai.types.chat.chat_completion import ChatCompletion, Choice
@@ -31,6 +32,7 @@
     generation_span,
 )
 from cai.sdk.agents.models.fake_id import FAKE_RESPONSES_ID
+from cai.util import get_ollama_api_base
 import os
 
 cai_model = os.getenv('CAI_MODEL', "qwen2.5:14b")
@@ -360,4 +362,77 @@ async def patched_fetch_response(self, *args, **kwargs):
         )
 
     # Counter should now be 2 (one increment per turn, not per item)
-    assert model.interaction_counter == 2
\ No newline at end of file
+    assert model.interaction_counter == 2
+
+
+def test_get_ollama_api_base_prefers_openai_compatible_env_vars(monkeypatch) -> None:
+    with monkeypatch.context() as m:
+        m.delenv("OLLAMA_API_BASE", raising=False)
+        m.setenv("OPENAI_API_BASE", "http://127.0.0.1:8080/v1")
+        m.setenv("OPENAI_BASE_URL", "https://example.invalid/v1")
+        assert get_ollama_api_base() == "http://127.0.0.1:8080/v1"
+
+    with monkeypatch.context() as m:
+        m.delenv("OLLAMA_API_BASE", raising=False)
+        m.delenv("OPENAI_API_BASE", raising=False)
+        m.setenv("OPENAI_BASE_URL", "https://gateway.example/v1")
+        assert get_ollama_api_base() == "https://gateway.example/v1"
+
+
+@pytest.mark.asyncio
+async def test_fetch_response_routes_unknown_prefixed_models_to_openai_compatible_base(
+    monkeypatch,
+) -> None:
+    class DummyCompletions:
+        def __init__(self) -> None:
+            self.kwargs: dict[str, Any] = {}
+
+        async def create(self, **kwargs: Any) -> Any:
+            self.kwargs = kwargs
+            return chat
+
+    class DummyClient:
+        def __init__(self, completions: DummyCompletions) -> None:
+            self.chat = type("_Chat", (), {"completions": completions})()
+            self.base_url = httpx.URL("http://fake")
+
+    msg = ChatCompletionMessage(role="assistant", content="gateway ok")
+    choice = Choice(index=0, finish_reason="stop", message=msg)
+    chat = ChatCompletion(
+        id="resp-id",
+        created=0,
+        model="acme/custom-1.1",
+        object="chat.completion",
+        choices=[choice],
+    )
+    completions = DummyCompletions()
+    dummy_client = DummyClient(completions)
+    captured: dict[str, Any] = {}
+
+    async def fake_acompletion(**kwargs: Any) -> Any:
+        captured.update(kwargs)
+        return chat
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+    monkeypatch.setenv("OPENAI_API_BASE", "http://127.0.0.1:9999/v1")
+    monkeypatch.delenv("OLLAMA_API_BASE", raising=False)
+    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
+
+    model = OpenAIChatCompletionsModel(model="acme/custom-1.1", openai_client=dummy_client)  # type: ignore[arg-type]
+    with generation_span(disabled=True) as span:
+        result = await model._fetch_response(
+            system_instructions="sys",
+            input="hi",
+            model_settings=ModelSettings(),
+            tools=[],
+            output_schema=None,
+            handoffs=[],
+            span=span,
+            tracing=ModelTracing.DISABLED,
+            stream=False,
+        )
+
+    assert result is chat
+    assert captured["model"] == "acme/custom-1.1"
+    assert captured["api_base"] == "http://127.0.0.1:9999/v1"
+    assert captured["custom_llm_provider"] == "openai"
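
For reviewers, a minimal standalone sketch of the two behaviors this diff introduces: the env-var resolution order from `get_ollama_api_base()` and the predicate that decides whether a `provider/model` id gets rerouted to the configured OpenAI-compatible base. The provider set, env var names, and default URL are copied from the diff; `resolve_route` and `get_api_base` are hypothetical helpers for illustration and do not exist in the CAI codebase.

```python
# Illustrative sketch of the routing added in this PR; not code from the repo.
import os

# Prefixes that keep their provider-specific handling (the diff's set).
KNOWN_PROVIDERS = {"ollama_cloud", "deepseek", "claude", "gemini"}


def get_api_base() -> str:
    # First match wins: OLLAMA_API_BASE, then OPENAI_API_BASE, then
    # OPENAI_BASE_URL, then the local default used by get_ollama_api_base().
    for env_var in ("OLLAMA_API_BASE", "OPENAI_API_BASE", "OPENAI_BASE_URL"):
        api_base = os.environ.get(env_var)
        if api_base:
            return api_base
    return "http://localhost:8000/v1"


def resolve_route(model_str: str) -> dict:
    """Decide whether a provider/model id is rerouted to the configured base."""
    explicit_base = (
        os.getenv("OLLAMA_API_BASE")
        or os.getenv("OPENAI_API_BASE")
        or os.getenv("OPENAI_BASE_URL")
    )
    provider = model_str.split("/")[0] if "/" in model_str else None
    if provider and explicit_base and provider not in KNOWN_PROVIDERS:
        # Unknown prefix plus an explicit base: treat the id as an
        # OpenAI-compatible model served by the configured gateway.
        return {"api_base": get_api_base(), "custom_llm_provider": "openai"}
    return {}


if __name__ == "__main__":
    os.environ.pop("OLLAMA_API_BASE", None)
    os.environ.pop("OPENAI_BASE_URL", None)
    os.environ["OPENAI_API_BASE"] = "http://127.0.0.1:9999/v1"
    print(resolve_route("acme/custom-1.1"))
    # -> {'api_base': 'http://127.0.0.1:9999/v1', 'custom_llm_provider': 'openai'}
    print(resolve_route("deepseek/deepseek-chat"))
    # -> {} (known prefix keeps its provider-specific configuration)
```

The sketch mirrors the new test: with `OPENAI_API_BASE` set and an unrecognized `acme/` prefix, the request is pinned to the user's gateway with `custom_llm_provider="openai"`, while ids whose prefix appears in the known-provider set fall through to their existing configuration branches.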