-
Notifications
You must be signed in to change notification settings - Fork 660
feat: add LiteLLM as unified LLM provider #3182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| import logging | ||
| from typing import Dict, List | ||
|
|
||
| from consts.const import DEFAULT_LLM_MAX_TOKENS | ||
| from services.providers.base import AbstractModelProvider, _classify_provider_error | ||
|
|
||
| logger = logging.getLogger("model_provider") | ||
|
|
||
|
|
||
| class LiteLLMModelProvider(AbstractModelProvider): | ||
| """Provider that discovers models via LiteLLM's /v1/models endpoint. | ||
|
|
||
| LiteLLM supports 100+ LLM providers (OpenAI, Anthropic, Google Gemini, | ||
| Azure, Bedrock, Ollama, etc.) through a unified interface. When pointed | ||
| at a LiteLLM proxy, this provider fetches the available model catalog. | ||
|
|
||
| For direct SDK usage (no proxy), users should add models manually with | ||
| the ``litellm`` provider and use LiteLLM model identifiers like | ||
| ``anthropic/claude-sonnet-4-20250514`` or ``gemini/gemini-2.5-flash``. | ||
| """ | ||
|
|
||
| async def get_models(self, provider_config: Dict) -> List[Dict]: | ||
| """ | ||
| Fetch models from a LiteLLM-compatible /v1/models endpoint. | ||
|
|
||
| Args: | ||
| provider_config: Configuration dict containing model_type, api_key, and base_url | ||
|
|
||
| Returns: | ||
| List of models with canonical fields. | ||
| """ | ||
| import httpx | ||
|
|
||
| try: | ||
| model_type: str = provider_config.get("model_type", "llm") | ||
| api_key: str = provider_config.get("api_key", "") | ||
| base_url: str = provider_config.get("base_url", "").rstrip("/") | ||
|
|
||
| if not base_url: | ||
| return [] | ||
|
|
||
| headers = {} | ||
| if api_key: | ||
| headers["Authorization"] = f"Bearer {api_key}" | ||
|
|
||
| models_url = f"{base_url}/models" | ||
|
|
||
| async with httpx.AsyncClient(verify=False, timeout=15.0) as client: | ||
|
Check failure on line 48 in backend/services/providers/litellm_provider.py
|
||
| response = await client.get(models_url, headers=headers) | ||
| response.raise_for_status() | ||
| data = response.json().get("data", []) | ||
|
|
||
| model_list = [] | ||
| for item in data: | ||
| model_id = item.get("id", "") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 所有通过 LiteLLM 发现的模型都使用同一个 |
||
| if not model_id: | ||
| continue | ||
|
|
||
| model_entry = { | ||
| "id": model_id, | ||
| "model_type": model_type, | ||
| "max_tokens": DEFAULT_LLM_MAX_TOKENS, | ||
| } | ||
|
|
||
| if model_type in ("llm", "vlm"): | ||
| model_entry["model_tag"] = "chat" | ||
| elif model_type in ("embedding", "multi_embedding"): | ||
| model_entry["model_tag"] = "embedding" | ||
| elif model_type == "rerank": | ||
| model_entry["model_tag"] = "rerank" | ||
|
|
||
| model_list.append(model_entry) | ||
|
|
||
| return model_list | ||
|
|
||
| except Exception as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [代码规范] |
||
| return _classify_provider_error("LiteLLM", exception=e) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,201 @@ | ||
| """LiteLLM-backed LLM model for nexent. | ||
|
|
||
| Provides access to 100+ LLM providers (OpenAI, Anthropic, Google Gemini, | ||
| Azure, Bedrock, Ollama, etc.) through ``litellm.completion()`` as an SDK | ||
| dependency. Follows the same interface as ``OpenAIModel``. | ||
| """ | ||
|
|
||
| import logging | ||
| import threading | ||
| import time | ||
| from typing import Any, Dict, List, Optional | ||
|
|
||
| from smolagents import Tool | ||
| from smolagents.models import ChatMessage, MessageRole | ||
|
|
||
| from ..utils.observer import MessageObserver, ProcessType | ||
|
|
||
| logger = logging.getLogger("litellm_llm") | ||
|
|
||
|
|
||
| class LiteLLMModel: | ||
| """LLM model backed by LiteLLM SDK. | ||
|
|
||
| Uses ``litellm.completion()`` directly, supporting any model identifier | ||
| that LiteLLM recognizes (e.g. ``anthropic/claude-sonnet-4-20250514``, | ||
| ``gemini/gemini-2.5-flash``, ``azure/gpt-4o``). | ||
|
|
||
| See https://docs.litellm.ai/docs/providers for the full provider list. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_id: str, | ||
| api_key: Optional[str] = None, | ||
| api_base: Optional[str] = None, | ||
| temperature: float = 0.2, | ||
| top_p: float = 0.95, | ||
| observer: MessageObserver = MessageObserver, | ||
| display_name: Optional[str] = None, | ||
| **kwargs: Any, | ||
| ): | ||
| self.model_id = model_id | ||
| self.api_key = api_key | ||
| self.api_base = api_base | ||
| self.temperature = temperature | ||
| self.top_p = top_p | ||
| self.observer = observer | ||
| self.display_name = display_name | ||
| self.stop_event = threading.Event() | ||
| self.last_input_token_count = 0 | ||
| self.last_output_token_count = 0 | ||
|
|
||
| def __call__( | ||
|
Check failure on line 53 in sdk/nexent/core/models/litellm_llm.py
|
||
| self, | ||
| messages: List[Dict[str, Any]], | ||
| stop_sequences: Optional[List[str]] = None, | ||
| response_format: Optional[Dict[str, str]] = None, | ||
| tools_to_call_from: Optional[List[Tool]] = None, | ||
| **kwargs: Any, | ||
| ) -> ChatMessage: | ||
| try: | ||
| import litellm | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "litellm is required for LiteLLMModel. " | ||
| "Install it with: pip install 'litellm>=1.80,<1.87'" | ||
| ) from e | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| # Normalize messages to dicts | ||
| normalized: List[Dict[str, Any]] = [] | ||
| for msg in messages or []: | ||
| if isinstance(msg, ChatMessage): | ||
| normalized.append({ | ||
| "role": msg.role.value if hasattr(msg.role, "value") else str(msg.role), | ||
| "content": msg.content, | ||
| }) | ||
| elif isinstance(msg, dict): | ||
| normalized.append(msg) | ||
| else: | ||
| raise TypeError("Messages must be ChatMessage or dict objects.") | ||
|
|
||
| completion_kwargs: Dict[str, Any] = { | ||
| "model": self.model_id, | ||
| "messages": normalized, | ||
| "temperature": self.temperature, | ||
| "stream": True, | ||
| "drop_params": True, | ||
| "stream_options": {"include_usage": True}, | ||
| } | ||
|
|
||
| if self.api_key: | ||
| completion_kwargs["api_key"] = self.api_key | ||
| if self.api_base: | ||
| completion_kwargs["api_base"] = self.api_base | ||
| if stop_sequences: | ||
| completion_kwargs["stop"] = stop_sequences | ||
| if response_format: | ||
| completion_kwargs["response_format"] = response_format | ||
|
|
||
| # Handle tool calling | ||
| if tools_to_call_from: | ||
| tool_definitions = [] | ||
| for tool in tools_to_call_from: | ||
| if hasattr(tool, "to_openai_tool"): | ||
| tool_definitions.append(tool.to_openai_tool()) | ||
| if tool_definitions: | ||
| completion_kwargs["tools"] = tool_definitions | ||
|
|
||
| current_request = litellm.completion(**completion_kwargs) | ||
|
|
||
| # Process streaming response | ||
| chunk_list = [] | ||
| token_join = [] | ||
| role = None | ||
|
|
||
| self.observer.current_mode = ProcessType.MODEL_OUTPUT_THINKING | ||
|
|
||
| try: | ||
| for chunk in current_request: | ||
| if not hasattr(chunk, "choices") or not chunk.choices: | ||
| chunk_list.append(chunk) | ||
| continue | ||
|
|
||
| delta = chunk.choices[0].delta | ||
| new_token = getattr(delta, "content", None) | ||
| reasoning_content = getattr(delta, "reasoning_content", None) | ||
|
|
||
| if reasoning_content is not None: | ||
| self.observer.add_model_reasoning_content(reasoning_content) | ||
|
|
||
| if new_token is not None: | ||
| self.observer.add_model_new_token(new_token) | ||
| token_join.append(new_token) | ||
| role = getattr(delta, "role", None) or role | ||
|
|
||
| chunk_list.append(chunk) | ||
| if self.stop_event.is_set(): | ||
| raise RuntimeError("Model is interrupted by stop event") | ||
|
|
||
| self.observer.flush_remaining_tokens() | ||
| model_output = "".join(token_join) | ||
|
|
||
| # Extract token usage from the last chunk | ||
| input_tokens = 0 | ||
| output_tokens = 0 | ||
| if chunk_list and hasattr(chunk_list[-1], "usage") and chunk_list[-1].usage is not None: | ||
| usage = chunk_list[-1].usage | ||
| input_tokens = getattr(usage, "prompt_tokens", 0) or 0 | ||
| output_tokens = getattr(usage, "completion_tokens", 0) or 0 | ||
|
|
||
| self.last_input_token_count = input_tokens | ||
| self.last_output_token_count = output_tokens | ||
|
|
||
| from openai.types.chat.chat_completion_message import ChatCompletionMessage | ||
|
|
||
| message = ChatMessage.from_dict( | ||
| ChatCompletionMessage( | ||
| role=role if role else "assistant", | ||
| content=model_output, | ||
| ).model_dump(include={"role", "content", "tool_calls"}) | ||
| ) | ||
|
|
||
| from smolagents.monitoring import TokenUsage | ||
|
|
||
| if input_tokens > 0 or output_tokens > 0: | ||
| message.token_usage = TokenUsage( | ||
| input_tokens=input_tokens, | ||
| output_tokens=output_tokens, | ||
| ) | ||
|
|
||
| message.raw = current_request | ||
| message.role = MessageRole.ASSISTANT | ||
| return message | ||
|
|
||
| except Exception as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [代码规范] |
||
| if "context_length_exceeded" in str(e): | ||
| raise ValueError(f"Token limit exceeded: {str(e)}") | ||
| raise | ||
|
|
||
| async def check_connectivity(self) -> bool: | ||
| """Test if the LLM provider connection works.""" | ||
| try: | ||
| import litellm | ||
| import asyncio | ||
|
|
||
| kwargs: Dict[str, Any] = { | ||
| "model": self.model_id, | ||
| "messages": [{"role": "user", "content": "Hello"}], | ||
| "max_tokens": 5, | ||
| "drop_params": True, | ||
| } | ||
| if self.api_key: | ||
| kwargs["api_key"] = self.api_key | ||
| if self.api_base: | ||
| kwargs["api_base"] = self.api_base | ||
|
|
||
| await litellm.acompletion(**kwargs) | ||
| return True | ||
| except Exception as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [代码规范] |
||
| logger.error(f"LiteLLM connectivity check failed: {e}") | ||
|
Check failure on line 200 in sdk/nexent/core/models/litellm_llm.py
|
||
| return False | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
安全风险:
verify=False禁用了 SSL 证书验证,生产环境下容易遭受 MITM 攻击。项目已有ssl_verify配置字段,应该从provider_config读取并默认为True。