Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ Also see [architecture](docs/ARCHITECTURE.md).
| Provider | Status | Provider | Status |
|----------|--------|----------|--------|
| OpenAI | ✅ | Azure OpenAI | ✅ |
| Anthropic Claude | ✅ | Google Gemini | ✅ |
| AWS Bedrock | ✅ | Mistral AI | ✅ |
| Ollama (local) | ✅ | Anyscale | ✅ |
| OpenAI Compatible | ✅ | Anthropic Claude | ✅ |
| AWS Bedrock | ✅ | Google Gemini | ✅ |
| Ollama (local) | ✅ | Mistral AI | ✅ |
| Anyscale | ✅ | | |

### Vector Databases

Expand Down
21 changes: 21 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/base1.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,27 @@ def validate_model(adapter_metadata: dict[str, "Any"]) -> str:
return f"openai/{model}"


class OpenAICompatibleLLMParameters(BaseChatCompletionParameters):
    """Chat-completion parameters for any OpenAI-compatible endpoint.

    See https://docs.litellm.ai/docs/providers/openai_compatible/.
    """

    # Optional: some self-hosted gateways accept unauthenticated requests.
    api_key: str | None = None
    api_base: str

    @staticmethod
    def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]:
        """Normalize the model name in-place, then validate and dump the metadata."""
        normalized = OpenAICompatibleLLMParameters.validate_model(adapter_metadata)
        adapter_metadata["model"] = normalized
        params = OpenAICompatibleLLMParameters(**adapter_metadata)
        return params.model_dump()

    @staticmethod
    def validate_model(adapter_metadata: dict[str, "Any"]) -> str:
        """Return the model name carrying the litellm ``custom_openai/`` prefix exactly once."""
        prefix = "custom_openai/"
        model = adapter_metadata.get("model", "")
        return model if model.startswith(prefix) else f"{prefix}{model}"


class AzureOpenAILLMParameters(BaseChatCompletionParameters):
"""See https://docs.litellm.ai/docs/providers/azure/#completion---using-azure_ad_token-api_base-api_version."""

Expand Down
2 changes: 2 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unstract.sdk1.adapters.llm1.bedrock import AWSBedrockLLMAdapter
from unstract.sdk1.adapters.llm1.ollama import OllamaLLMAdapter
from unstract.sdk1.adapters.llm1.openai import OpenAILLMAdapter
from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter
from unstract.sdk1.adapters.llm1.vertexai import VertexAILLMAdapter

adapters: dict[str, dict[str, Any]] = {}
Expand All @@ -22,5 +23,6 @@
"AzureOpenAILLMAdapter",
"OllamaLLMAdapter",
"OpenAILLMAdapter",
"OpenAICompatibleLLMAdapter",
"VertexAILLMAdapter",
]
40 changes: 40 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any

from unstract.sdk1.adapters.base1 import BaseAdapter, OpenAICompatibleLLMParameters
from unstract.sdk1.adapters.enums import AdapterTypes


class OpenAICompatibleLLMAdapter(OpenAICompatibleLLMParameters, BaseAdapter):
    """LLM adapter for any OpenAI-compatible endpoint (litellm ``custom_openai`` provider)."""

    @staticmethod
    def get_id() -> str:
        """Stable unique identifier used for adapter registry lookups."""
        return "openaicompatible|b6d10f33-2c41-49fc-a8c2-58d2b247fc09"

    @staticmethod
    def get_metadata() -> dict[str, Any]:
        """Registry metadata for this adapter.

        Name and description are sourced from ``get_name()`` / ``get_description()``
        rather than duplicated as literals, so the values cannot drift apart.
        """
        return {
            "name": OpenAICompatibleLLMAdapter.get_name(),
            "version": "1.0.0",
            "adapter": OpenAICompatibleLLMAdapter,
            "description": OpenAICompatibleLLMAdapter.get_description(),
            "is_active": True,
        }

    @staticmethod
    def get_name() -> str:
        """Human-readable adapter name shown in the UI."""
        return "OpenAI Compatible"

    @staticmethod
    def get_description() -> str:
        """Short description shown alongside the adapter."""
        return "OpenAI-compatible LLM adapter"

    @staticmethod
    def get_provider() -> str:
        """litellm provider key routing requests to OpenAI-compatible endpoints."""
        return "custom_openai"

    @staticmethod
    def get_icon() -> str:
        """Icon path; reuses the OpenAI icon since the wire protocol is the same."""
        return "/icons/adapter-icons/OpenAI.png"

    @staticmethod
    def get_adapter_type() -> AdapterTypes:
        """This adapter provides an LLM (chat-completion) capability."""
        return AdapterTypes.LLM
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
{
"title": "OpenAI Compatible LLM",
"type": "object",
"required": [
"adapter_name",
"api_base"
],
  "properties": {
"adapter_name": {
"type": "string",
"title": "Name",
"default": "",
"description": "Provide a unique name for this adapter instance. Example: compatible-gateway-1"
},
"api_key": {
"type": [
"string",
"null"
],
"title": "API Key",
"format": "password",
"description": "API key for your OpenAI-compatible endpoint. Leave empty if the endpoint does not require one."
},
    "model": {
"type": "string",
"title": "Model",
"default": "gpt-4o-mini",
"description": "The model name expected by your OpenAI-compatible endpoint. Examples: gpt-4o-mini, ERNIE-4.0-8K (Baidu Qianfan), qwen-max, openai/gpt-4o"
},
"api_base": {
"type": "string",
"format": "url",
"title": "API Base",
"default": "https://your-endpoint.example.com/v1",
"description": "Base URL for the OpenAI-compatible endpoint. Examples: https://your-endpoint.example.com/v1, https://qianfan.baidubce.com/v2"
},
"max_tokens": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Maximum Output Tokens",
"default": 4096,
"description": "Maximum number of output tokens to limit LLM replies. Leave it empty to use the provider default."
},
"max_retries": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Max Retries",
"default": 5,
"description": "The maximum number of times to retry a request if it fails."
},
"timeout": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Timeout",
"default": 900,
"description": "Timeout in seconds."
}
}
}
15 changes: 13 additions & 2 deletions unstract/sdk1/src/unstract/sdk1/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,10 +539,21 @@ def _record_usage(
usage: Mapping[str, int] | None,
llm_api: str,
) -> None:
prompt_tokens = token_counter(model=model, messages=messages)
usage_data: Mapping[str, int] = usage or {}
prompt_tokens = usage_data.get("prompt_tokens")
if prompt_tokens is None:
try:
prompt_tokens = token_counter(model=model, messages=messages)
except Exception as e:
prompt_tokens = 0
logger.warning(
"[sdk1][LLM][%s][%s] Failed to estimate prompt tokens: %s",
model,
llm_api,
e,
)
all_tokens = TokenCounterCompat(
prompt_tokens=usage_data.get("prompt_tokens", 0),
prompt_tokens=usage_data.get("prompt_tokens", prompt_tokens or 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
Expand Down
117 changes: 117 additions & 0 deletions unstract/sdk1/tests/test_openai_compatible_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import json
from functools import lru_cache
from importlib import import_module
from unittest.mock import MagicMock, patch

from unstract.sdk1.adapters.base1 import OpenAICompatibleLLMParameters
from unstract.sdk1.adapters.constants import Common
from unstract.sdk1.adapters.llm1 import adapters
from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter


@lru_cache(maxsize=1)
def _load_llm_module() -> object:
    """Import ``unstract.sdk1.llm`` once, stubbing python-magic first.

    The stub keeps the import from requiring libmagic to be installed in
    the test environment; lru_cache ensures the import runs only once.
    """
    import sys
    from types import ModuleType

    magic_stub = ModuleType("magic")
    with patch.dict(sys.modules, {"magic": magic_stub}):
        module = import_module("unstract.sdk1.llm")
    return module


def _load_llm_class() -> type:
    """Return the LLM class from the lazily-imported llm module."""
    module = _load_llm_module()
    return module.LLM


def test_openai_compatible_adapter_is_registered() -> None:
    """The adapter must be discoverable in the llm1 registry under its own id."""
    registered_id = OpenAICompatibleLLMAdapter.get_id()

    assert registered_id in adapters
    registered_module = adapters[registered_id][Common.MODULE]
    assert registered_module is OpenAICompatibleLLMAdapter


def test_openai_compatible_validate_prefixes_model() -> None:
    """Bare model names get the litellm ``custom_openai/`` prefix on validation."""
    metadata = {
        "api_base": "https://gateway.example.com/v1",
        "api_key": "test-key",
        "model": "ERNIE-4.0-8K",
    }

    validated = OpenAICompatibleLLMParameters.validate(metadata)

    assert validated["model"] == "custom_openai/ERNIE-4.0-8K"


def test_openai_compatible_validate_preserves_prefixed_model() -> None:
    """An already-prefixed model is passed through untouched; api_key defaults to None."""
    metadata = {
        "api_base": "https://gateway.example.com/v1",
        "model": "custom_openai/openai/gpt-4o",
    }

    validated = OpenAICompatibleLLMParameters.validate(metadata)

    assert validated["model"] == "custom_openai/openai/gpt-4o"
    assert validated["api_key"] is None


def test_openai_compatible_schema_is_loadable() -> None:
    """The bundled JSON schema parses and exposes the expected fields."""
    raw_schema = OpenAICompatibleLLMAdapter.get_json_schema()
    schema = json.loads(raw_schema)

    assert schema["title"] == "OpenAI Compatible LLM"
    api_key_spec = schema["properties"]["api_key"]
    assert api_key_spec["type"] == ["string", "null"]
    model_description = schema["properties"]["model"]["description"]
    assert "ERNIE-4.0-8K" in model_description


def test_record_usage_uses_reported_prompt_tokens_without_estimating() -> None:
    """When the provider reports prompt_tokens, no local token estimation runs.

    NOTE: the original paste contained diff-viewer artifact lines ("Comment
    thread" / review-bot text) inside this function body; they are not valid
    Python and have been removed.
    """
    llm_module = _load_llm_module()
    llm_cls = llm_module.LLM

    # Build a bare instance without running __init__ (which needs real config).
    llm = llm_cls.__new__(llm_cls)
    llm._platform_api_key = "platform-key"
    llm.platform_kwargs = {"run_id": "run-1"}
    llm.adapter = MagicMock()
    llm.adapter.get_provider.return_value = "custom_openai"

    with (
        patch.object(llm_module, "token_counter") as mock_token_counter,
        patch.object(llm_module, "Audit") as mock_audit,
    ):
        llm._record_usage(
            model="custom_openai/ERNIE-4.0-8K",
            messages=[{"role": "user", "content": "hello"}],
            usage={"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7},
            llm_api="complete",
        )

    mock_token_counter.assert_not_called()
    mock_audit.return_value.push_usage_data.assert_called_once()


def test_record_usage_tolerates_unmapped_models_without_prompt_tokens() -> None:
    """If token_counter raises for an unmapped model, usage is still recorded."""
    llm_module = _load_llm_module()
    llm_cls = llm_module.LLM

    # Bare instance, bypassing __init__ and its configuration requirements.
    llm = llm_cls.__new__(llm_cls)
    llm._platform_api_key = "platform-key"
    llm.platform_kwargs = {"run_id": "run-1"}
    llm.adapter = MagicMock()
    llm.adapter.get_provider.return_value = "custom_openai"

    counter_patch = patch.object(
        llm_module, "token_counter", side_effect=Exception("unmapped")
    )
    audit_patch = patch.object(llm_module, "Audit")
    with counter_patch, audit_patch as mock_audit:
        llm._record_usage(
            model="custom_openai/ERNIE-4.0-8K",
            messages=[{"role": "user", "content": "hello"}],
            usage={"completion_tokens": 4, "total_tokens": 7},
            llm_api="complete",
        )

    mock_audit.return_value.push_usage_data.assert_called_once()