diff --git a/README.md b/README.md
index a3ab108c6b..9e282b3e10 100644
--- a/README.md
+++ b/README.md
@@ -176,9 +176,10 @@ Also see [architecture](docs/ARCHITECTURE.md).
 | Provider | Status | Provider | Status |
 |----------|--------|----------|--------|
 | OpenAI | ✅ | Azure OpenAI | ✅ |
-| Anthropic Claude | ✅ | Google Gemini | ✅ |
-| AWS Bedrock | ✅ | Mistral AI | ✅ |
-| Ollama (local) | ✅ | Anyscale | ✅ |
+| OpenAI Compatible | ✅ | Anthropic Claude | ✅ |
+| AWS Bedrock | ✅ | Google Gemini | ✅ |
+| Ollama (local) | ✅ | Mistral AI | ✅ |
+| Anyscale | ✅ | | |
 
 ### Vector Databases
 
diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py
index 8ad721c3d4..ed24cc77a9 100644
--- a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py
+++ b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py
@@ -225,6 +225,27 @@ def validate_model(adapter_metadata: dict[str, "Any"]) -> str:
         return f"openai/{model}"
 
 
+class OpenAICompatibleLLMParameters(BaseChatCompletionParameters):
+    """See https://docs.litellm.ai/docs/providers/openai_compatible/."""
+
+    api_key: str | None = None
+    api_base: str
+
+    @staticmethod
+    def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]:
+        adapter_metadata["model"] = OpenAICompatibleLLMParameters.validate_model(
+            adapter_metadata
+        )
+        return OpenAICompatibleLLMParameters(**adapter_metadata).model_dump()
+
+    @staticmethod
+    def validate_model(adapter_metadata: dict[str, "Any"]) -> str:
+        model = adapter_metadata.get("model", "")
+        if model.startswith("custom_openai/"):
+            return model
+        return f"custom_openai/{model}"
+
+
 class AzureOpenAILLMParameters(BaseChatCompletionParameters):
     """See https://docs.litellm.ai/docs/providers/azure/#completion---using-azure_ad_token-api_base-api_version."""
 
diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py
index c23a33390a..1da3590f51 100644
--- a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py
+++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py
@@ -8,6 +8,7 @@
 from unstract.sdk1.adapters.llm1.bedrock import AWSBedrockLLMAdapter
 from unstract.sdk1.adapters.llm1.ollama import OllamaLLMAdapter
 from unstract.sdk1.adapters.llm1.openai import OpenAILLMAdapter
+from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter
 from unstract.sdk1.adapters.llm1.vertexai import VertexAILLMAdapter
 
 adapters: dict[str, dict[str, Any]] = {}
@@ -22,5 +23,6 @@
     "AzureOpenAILLMAdapter",
     "OllamaLLMAdapter",
     "OpenAILLMAdapter",
+    "OpenAICompatibleLLMAdapter",
     "VertexAILLMAdapter",
 ]
diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py
new file mode 100644
index 0000000000..1ed942ba10
--- /dev/null
+++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py
@@ -0,0 +1,40 @@
+from typing import Any
+
+from unstract.sdk1.adapters.base1 import BaseAdapter, OpenAICompatibleLLMParameters
+from unstract.sdk1.adapters.enums import AdapterTypes
+
+
+class OpenAICompatibleLLMAdapter(OpenAICompatibleLLMParameters, BaseAdapter):
+    @staticmethod
+    def get_id() -> str:
+        return "openaicompatible|b6d10f33-2c41-49fc-a8c2-58d2b247fc09"
+
+    @staticmethod
+    def get_metadata() -> dict[str, Any]:
+        return {
+            "name": "OpenAI Compatible",
+            "version": "1.0.0",
+            "adapter": OpenAICompatibleLLMAdapter,
+            "description": "OpenAI-compatible LLM adapter",
+            "is_active": True,
+        }
+
+    @staticmethod
+    def get_name() -> str:
+        return "OpenAI Compatible"
+
+    @staticmethod
+    def get_description() -> str:
+        return "OpenAI-compatible LLM adapter"
+
+    @staticmethod
+    def get_provider() -> str:
+        return "custom_openai"
+
+    @staticmethod
+    def get_icon() -> str:
+        return "/icons/adapter-icons/OpenAI.png"
+
+    @staticmethod
+    def get_adapter_type() -> AdapterTypes:
+        return AdapterTypes.LLM
diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json
new file mode 100644
index 0000000000..00f629b41e
--- /dev/null
+++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json
@@ -0,0 +1,62 @@
+{
+  "title": "OpenAI Compatible LLM",
+  "type": "object",
+  "required": [
+    "adapter_name",
+    "api_base"
+  ],
+  "properties": {
+    "adapter_name": {
+      "type": "string",
+      "title": "Name",
+      "default": "",
+      "description": "Provide a unique name for this adapter instance. Example: compatible-gateway-1"
+    },
+    "api_key": {
+      "type": [
+        "string",
+        "null"
+      ],
+      "title": "API Key",
+      "format": "password",
+      "description": "API key for your OpenAI-compatible endpoint. Leave empty if the endpoint does not require one."
+    },
+    "model": {
+      "type": "string",
+      "title": "Model",
+      "default": "gpt-4o-mini",
+      "description": "The model name expected by your OpenAI-compatible endpoint. Examples: gpt-4o-mini, ERNIE-4.0-8K (Baidu Qianfan), qwen-max, openai/gpt-4o"
+    },
+    "api_base": {
+      "type": "string",
+      "format": "url",
+      "title": "API Base",
+      "default": "https://your-endpoint.example.com/v1",
+      "description": "Base URL for the OpenAI-compatible endpoint. Examples: https://your-endpoint.example.com/v1, https://qianfan.baidubce.com/v2"
+    },
+    "max_tokens": {
+      "type": "number",
+      "minimum": 0,
+      "multipleOf": 1,
+      "title": "Maximum Output Tokens",
+      "default": 4096,
+      "description": "Maximum number of output tokens to limit LLM replies. Leave it empty to use the provider default."
+    },
+    "max_retries": {
+      "type": "number",
+      "minimum": 0,
+      "multipleOf": 1,
+      "title": "Max Retries",
+      "default": 5,
+      "description": "The maximum number of times to retry a request if it fails."
+    },
+    "timeout": {
+      "type": "number",
+      "minimum": 0,
+      "multipleOf": 1,
+      "title": "Timeout",
+      "default": 900,
+      "description": "Timeout in seconds."
+    }
+  }
+}
diff --git a/unstract/sdk1/src/unstract/sdk1/llm.py b/unstract/sdk1/src/unstract/sdk1/llm.py
index 8ff29a89d5..e6a49a8bb1 100644
--- a/unstract/sdk1/src/unstract/sdk1/llm.py
+++ b/unstract/sdk1/src/unstract/sdk1/llm.py
@@ -539,10 +539,21 @@ def _record_usage(
         usage: Mapping[str, int] | None,
         llm_api: str,
     ) -> None:
-        prompt_tokens = token_counter(model=model, messages=messages)
         usage_data: Mapping[str, int] = usage or {}
+        prompt_tokens = usage_data.get("prompt_tokens")
+        if prompt_tokens is None:
+            try:
+                prompt_tokens = token_counter(model=model, messages=messages)
+            except Exception as e:
+                prompt_tokens = 0
+                logger.warning(
+                    "[sdk1][LLM][%s][%s] Failed to estimate prompt tokens: %s",
+                    model,
+                    llm_api,
+                    e,
+                )
         all_tokens = TokenCounterCompat(
-            prompt_tokens=usage_data.get("prompt_tokens", 0),
+            prompt_tokens=usage_data.get("prompt_tokens", prompt_tokens or 0),
             completion_tokens=usage_data.get("completion_tokens", 0),
             total_tokens=usage_data.get("total_tokens", 0),
         )
diff --git a/unstract/sdk1/tests/test_openai_compatible_adapter.py b/unstract/sdk1/tests/test_openai_compatible_adapter.py
new file mode 100644
index 0000000000..1f58d636ec
--- /dev/null
+++ b/unstract/sdk1/tests/test_openai_compatible_adapter.py
@@ -0,0 +1,117 @@
+import json
+from functools import lru_cache
+from importlib import import_module
+from unittest.mock import MagicMock, patch
+
+from unstract.sdk1.adapters.base1 import OpenAICompatibleLLMParameters
+from unstract.sdk1.adapters.constants import Common
+from unstract.sdk1.adapters.llm1 import adapters
+from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter
+
+
+@lru_cache(maxsize=1)
+def _load_llm_module() -> object:
+    import sys
+    from types import ModuleType
+
+    with patch.dict(
+        sys.modules,
+        {
+            # Stub python-magic so importing LLM does not depend on libmagic
+            # being available in the test environment.
+            "magic": ModuleType("magic")
+        },
+    ):
+        return import_module("unstract.sdk1.llm")
+
+
+def _load_llm_class() -> type:
+    return _load_llm_module().LLM
+
+
+def test_openai_compatible_adapter_is_registered() -> None:
+    adapter_id = OpenAICompatibleLLMAdapter.get_id()
+
+    assert adapter_id in adapters
+    assert adapters[adapter_id][Common.MODULE] is OpenAICompatibleLLMAdapter
+
+
+def test_openai_compatible_validate_prefixes_model() -> None:
+    validated = OpenAICompatibleLLMParameters.validate(
+        {
+            "api_base": "https://gateway.example.com/v1",
+            "api_key": "test-key",
+            "model": "ERNIE-4.0-8K",
+        }
+    )
+
+    assert validated["model"] == "custom_openai/ERNIE-4.0-8K"
+
+
+def test_openai_compatible_validate_preserves_prefixed_model() -> None:
+    validated = OpenAICompatibleLLMParameters.validate(
+        {
+            "api_base": "https://gateway.example.com/v1",
+            "model": "custom_openai/openai/gpt-4o",
+        }
+    )
+
+    assert validated["model"] == "custom_openai/openai/gpt-4o"
+    assert validated["api_key"] is None
+
+
+def test_openai_compatible_schema_is_loadable() -> None:
+    schema = json.loads(OpenAICompatibleLLMAdapter.get_json_schema())
+
+    assert schema["title"] == "OpenAI Compatible LLM"
+    assert schema["properties"]["api_key"]["type"] == ["string", "null"]
+    assert "ERNIE-4.0-8K" in schema["properties"]["model"]["description"]
+
+
+def test_record_usage_uses_reported_prompt_tokens_without_estimating() -> None:
+    llm_module = _load_llm_module()
+    llm_cls = llm_module.LLM
+
+    llm = llm_cls.__new__(llm_cls)
+    llm._platform_api_key = "platform-key"
+    llm.platform_kwargs = {"run_id": "run-1"}
+    llm.adapter = MagicMock()
+    llm.adapter.get_provider.return_value = "custom_openai"
+
+    with (
+        patch.object(llm_module, "token_counter") as mock_token_counter,
+        patch.object(llm_module, "Audit") as mock_audit,
+    ):
+        llm._record_usage(
+            model="custom_openai/ERNIE-4.0-8K",
+            messages=[{"role": "user", "content": "hello"}],
+            usage={"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7},
+            llm_api="complete",
+        )
+
+    mock_token_counter.assert_not_called()
+    mock_audit.return_value.push_usage_data.assert_called_once()
+
+
+def test_record_usage_tolerates_unmapped_models_without_prompt_tokens() -> None:
+    llm_module = _load_llm_module()
+    llm_cls = llm_module.LLM
+
+    llm = llm_cls.__new__(llm_cls)
+    llm._platform_api_key = "platform-key"
+    llm.platform_kwargs = {"run_id": "run-1"}
+    llm.adapter = MagicMock()
+    llm.adapter.get_provider.return_value = "custom_openai"
+
+    with (
+        patch.object(llm_module, "token_counter", side_effect=Exception("unmapped")),
+        patch.object(llm_module, "Audit") as mock_audit,
+    ):
+        llm._record_usage(
+            model="custom_openai/ERNIE-4.0-8K",
+            messages=[{"role": "user", "content": "hello"}],
+            usage={"completion_tokens": 4, "total_tokens": 7},
+            llm_api="complete",
+        )
+
+    mock_audit.return_value.push_usage_data.assert_called_once()