diff --git a/examples/01_standalone_sdk/49_switch_llm_tool.py b/examples/01_standalone_sdk/49_switch_llm_tool.py new file mode 100644 index 0000000000..28f414cabe --- /dev/null +++ b/examples/01_standalone_sdk/49_switch_llm_tool.py @@ -0,0 +1,81 @@ +"""Switch LLM profiles with the built-in switch_llm tool. + +This example creates two temporary LLM profiles, starts the conversation on a +GPT profile, asks the agent to call the switch_llm tool, and then verifies that +future model calls use the Claude profile. + +Usage: + LLM_API_KEY=... LLM_BASE_URL=https://llm-proxy.app.all-hands.dev \ + uv run python examples/01_standalone_sdk/49_switch_llm_tool.py +""" + +import os + +from pydantic import SecretStr + +from openhands.sdk import LLM, Agent, LocalConversation +from openhands.sdk.llm.llm_profile_store import LLMProfileStore + + +GPT_PROFILE = "example-gpt55" +CLAUDE_PROFILE = "example-claude" +DEFAULT_BASE_URL = "https://llm-proxy.app.all-hands.dev" +GPT_MODEL = "openai/gpt-5.5" +CLAUDE_MODEL = "openai/prod/claude-sonnet-4-5-20250929" + +api_key = os.getenv("LLM_API_KEY") +assert api_key is not None, "LLM_API_KEY environment variable is not set." +base_url = os.getenv("LLM_BASE_URL", DEFAULT_BASE_URL) + +store = LLMProfileStore() +store.save( + GPT_PROFILE, + LLM( + model=GPT_MODEL, + api_key=SecretStr(api_key), + base_url=base_url, + usage_id="gpt55", + ), + include_secrets=True, +) +store.save( + CLAUDE_PROFILE, + LLM( + model=CLAUDE_MODEL, + api_key=SecretStr(api_key), + base_url=base_url, + usage_id="claude", + ), + include_secrets=True, +) + +try: + initial_llm = store.load(GPT_PROFILE) + agent = Agent( + llm=initial_llm, + tools=[], + include_default_tools=["FinishTool", "SwitchLLMTool"], + ) + conversation = LocalConversation(agent=agent, workspace=os.getcwd()) + + print(f"Starting model: {conversation.agent.llm.model}") + conversation.send_message( + f"Call the switch_llm tool now with profile_name={CLAUDE_PROFILE!r}. " + "After the tool succeeds, answer in one short sentence naming the " + "active model value from the tool observation exactly." + ) + conversation.run() + + active_model = conversation.agent.llm.model + print(f"Active model after tool switch: {active_model}") + assert active_model == CLAUDE_MODEL + + for usage_id, metrics in conversation.state.stats.usage_to_metrics.items(): + print(f" [{usage_id}] cost=${metrics.accumulated_cost:.6f}") + + combined = conversation.state.stats.get_combined_metrics() + print(f"Total cost: ${combined.accumulated_cost:.6f}") + print(f"EXAMPLE_COST: {combined.accumulated_cost}") +finally: + store.delete(GPT_PROFILE) + store.delete(CLAUDE_PROFILE) diff --git a/openhands-sdk/openhands/sdk/tool/builtins/__init__.py b/openhands-sdk/openhands/sdk/tool/builtins/__init__.py index 8c1dca5eaf..15dbf75e67 100644 --- a/openhands-sdk/openhands/sdk/tool/builtins/__init__.py +++ b/openhands-sdk/openhands/sdk/tool/builtins/__init__.py @@ -17,6 +17,12 @@ InvokeSkillObservation, InvokeSkillTool, ) +from openhands.sdk.tool.builtins.switch_llm import ( + SwitchLLMAction, + SwitchLLMExecutor, + SwitchLLMObservation, + SwitchLLMTool, +) from openhands.sdk.tool.builtins.think import ( ThinkAction, ThinkExecutor, @@ -30,12 +36,13 @@ # AgentSkills-format skill is loaded (see BUILT_IN_TOOL_CLASSES below). BUILT_IN_TOOLS = [FinishTool, ThinkTool] -# Map of built-in tool class names to their classes. Includes -# `InvokeSkillTool` so it can be resolved by name from `include_default_tools` -# and the conditional wiring in `Agent._initialize`. +# Map of built-in tool class names to their classes. Includes optional built-ins +# so they can be resolved by name from `include_default_tools` and the +# conditional wiring in `Agent._initialize`. BUILT_IN_TOOL_CLASSES = { **{tool.__name__: tool for tool in BUILT_IN_TOOLS}, InvokeSkillTool.__name__: InvokeSkillTool, + SwitchLLMTool.__name__: SwitchLLMTool, } __all__ = [ @@ -49,6 +56,10 @@ "InvokeSkillAction", "InvokeSkillObservation", "InvokeSkillExecutor", + "SwitchLLMTool", + "SwitchLLMAction", + "SwitchLLMObservation", + "SwitchLLMExecutor", "ThinkTool", "ThinkAction", "ThinkObservation", diff --git a/openhands-sdk/openhands/sdk/tool/builtins/switch_llm.py b/openhands-sdk/openhands/sdk/tool/builtins/switch_llm.py new file mode 100644 index 0000000000..5338f4baa5 --- /dev/null +++ b/openhands-sdk/openhands/sdk/tool/builtins/switch_llm.py @@ -0,0 +1,176 @@ +from collections.abc import Sequence +from typing import TYPE_CHECKING, Self + +from pydantic import Field +from rich.text import Text + +from openhands.sdk.llm.llm_profile_store import LLMProfileStore +from openhands.sdk.tool.tool import ( + Action, + Observation, + ToolAnnotations, + ToolDefinition, + ToolExecutor, +) + + +if TYPE_CHECKING: + from openhands.sdk.conversation.impl.local_conversation import LocalConversation + from openhands.sdk.conversation.state import ConversationState + + +class SwitchLLMAction(Action): + """Action for switching this conversation to a saved LLM profile.""" + + profile_name: str = Field( + description="Name of the saved LLM profile to use for future agent steps." + ) + reason: str = Field( + description="Brief reason why this profile is a better fit for the next step." + ) + + @property + def visualize(self) -> Text: + content = Text() + content.append("Switch LLM profile: ", style="bold magenta") + content.append(self.profile_name) + if self.reason: + content.append("\nReason: ", style="bold") + content.append(self.reason) + return content + + +class SwitchLLMObservation(Observation): + """Observation returned after switching this conversation's LLM profile.""" + + profile_name: str = Field( + description="Name of the profile that the tool attempted to activate." + ) + reason: str | None = Field( + default=None, + description="Reason the agent gave for attempting this LLM profile switch.", + ) + active_model: str | None = Field( + default=None, + description="Model configured by the activated profile, when available.", + ) + + @property + def visualize(self) -> Text: + content = Text() + if self.is_error: + content.append("Failed to switch LLM profile", style="bold red") + else: + content.append("Switched LLM profile", style="bold green") + content.append(f": {self.profile_name}") + if self.active_model: + content.append(f" ({self.active_model})") + if self.reason: + content.append("\nReason: ", style="bold") + content.append(self.reason) + return content + + +_DESCRIPTION_TEMPLATE = ( + "Switch this conversation to a saved LLM profile.\n\n" + "Use this when another available profile is better suited for the next step. " + "The current tool call is still executed by the current model; the switch " + "takes effect on the next LLM call.\n\n" + "Available LLM profiles:\n" + "{profiles}\n\n" + "Provide the profile_name exactly as listed and include a concise reason " + "for the switch." +) + + +def _format_profiles(profile_names: Sequence[str]) -> str: + if not profile_names: + return "- No saved LLM profiles are currently available." + return "\n".join(f"- {name}" for name in sorted(profile_names)) + + +class SwitchLLMExecutor(ToolExecutor): + def __call__( + self, + action: SwitchLLMAction, + conversation: "LocalConversation | None" = None, + ) -> SwitchLLMObservation: + if conversation is None: + return SwitchLLMObservation.from_text( + text="Cannot switch LLM profile without an active conversation.", + is_error=True, + profile_name=action.profile_name, + reason=action.reason, + ) + + try: + conversation.switch_profile(action.profile_name) + except FileNotFoundError: + return SwitchLLMObservation.from_text( + text=f"LLM profile '{action.profile_name}' was not found.", + is_error=True, + profile_name=action.profile_name, + reason=action.reason, + ) + except ValueError as exc: + return SwitchLLMObservation.from_text( + text=str(exc), + is_error=True, + profile_name=action.profile_name, + reason=action.reason, + ) + except Exception as exc: + return SwitchLLMObservation.from_text( + text=( + f"Failed to switch LLM profile '{action.profile_name}': " + f"{type(exc).__name__}: {exc}" + ), + is_error=True, + profile_name=action.profile_name, + reason=action.reason, + ) + + active_model = conversation.agent.llm.model + return SwitchLLMObservation.from_text( + text=( + f"Switched LLM profile to '{action.profile_name}' " + f"with active model '{active_model}'. Reason: {action.reason} " + "Future agent steps will use this profile." + ), + profile_name=action.profile_name, + reason=action.reason, + active_model=active_model, + ) + + +class SwitchLLMTool(ToolDefinition[SwitchLLMAction, SwitchLLMObservation]): + """Tool for switching a conversation to a saved LLM profile.""" + + @classmethod + def create( + cls, + conv_state: "ConversationState | None" = None, # noqa: ARG003 + **params, + ) -> Sequence[Self]: + if params: + raise ValueError("SwitchLLMTool doesn't accept parameters") + + profile_names = [ + name.removesuffix(".json") for name in LLMProfileStore().list() + ] + return [ + cls( + description=_DESCRIPTION_TEMPLATE.format( + profiles=_format_profiles(profile_names) + ), + action_type=SwitchLLMAction, + observation_type=SwitchLLMObservation, + executor=SwitchLLMExecutor(), + annotations=ToolAnnotations( + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), + ) + ] diff --git a/tests/sdk/tool/test_switch_llm.py b/tests/sdk/tool/test_switch_llm.py new file mode 100644 index 0000000000..c951eb7101 --- /dev/null +++ b/tests/sdk/tool/test_switch_llm.py @@ -0,0 +1,113 @@ +from pathlib import Path + +import pytest + +from openhands.sdk import LLM, LocalConversation +from openhands.sdk.agent import Agent +from openhands.sdk.llm import llm_profile_store +from openhands.sdk.llm.llm_profile_store import LLMProfileStore +from openhands.sdk.testing import TestLLM +from openhands.sdk.tool.builtins import ( + SwitchLLMAction, + SwitchLLMObservation, + SwitchLLMTool, +) + + +def _make_llm(model: str, usage_id: str) -> LLM: + return TestLLM.from_messages([], model=model, usage_id=usage_id) + + +@pytest.fixture() +def profile_store(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LLMProfileStore: + profile_dir = tmp_path / "profiles" + profile_dir.mkdir() + monkeypatch.setattr(llm_profile_store, "_DEFAULT_PROFILE_DIR", profile_dir) + + store = LLMProfileStore(base_dir=profile_dir) + store.save("fast", _make_llm("fast-model", "fast")) + store.save("slow", _make_llm("slow-model", "slow")) + return store + + +def _make_conversation() -> LocalConversation: + return LocalConversation( + agent=Agent( + llm=_make_llm("default-model", "default"), + tools=[], + include_default_tools=["SwitchLLMTool"], + ), + workspace=Path.cwd(), + ) + + +def test_switch_llm_tool_description_lists_available_profiles(profile_store): + tool = SwitchLLMTool.create()[0] + + assert "Available LLM profiles:" in tool.description + assert "- fast" in tool.description + assert "- slow" in tool.description + + +def test_switch_llm_tool_switches_conversation_profile(profile_store): + conversation = _make_conversation() + + observation = conversation.execute_tool( + "switch_llm", + SwitchLLMAction(profile_name="fast", reason="Need a faster profile."), + ) + + assert isinstance(observation, SwitchLLMObservation) + assert not observation.is_error + assert observation.profile_name == "fast" + assert observation.reason == "Need a faster profile." + assert observation.active_model == "fast-model" + assert "active model 'fast-model'" in observation.text + assert "Reason: Need a faster profile." in observation.text + assert "Need a faster profile." in observation.visualize.plain + assert conversation.agent.llm.model == "fast-model" + assert conversation.state.agent.llm.model == "fast-model" + + +def test_switch_llm_tool_reports_missing_profile(profile_store): + conversation = _make_conversation() + + observation = conversation.execute_tool( + "switch_llm", + SwitchLLMAction(profile_name="missing", reason="Try another model."), + ) + + assert isinstance(observation, SwitchLLMObservation) + assert observation.is_error + assert observation.profile_name == "missing" + assert observation.reason == "Try another model." + assert observation.active_model is None + assert "was not found" in observation.text + assert conversation.agent.llm.model == "default-model" + assert conversation.state.agent.llm.model == "default-model" + + +def test_switch_llm_tool_reports_unexpected_profile_load_error( + profile_store, monkeypatch: pytest.MonkeyPatch +): + conversation = _make_conversation() + + def _raise_permission_error(profile_name: str) -> None: + raise PermissionError(f"Cannot read {profile_name}") + + monkeypatch.setattr(conversation, "switch_profile", _raise_permission_error) + + observation = conversation.execute_tool( + "switch_llm", + SwitchLLMAction(profile_name="fast", reason="Need access to Claude."), + ) + + assert isinstance(observation, SwitchLLMObservation) + assert observation.is_error + assert observation.profile_name == "fast" + assert observation.reason == "Need access to Claude." + assert observation.active_model is None + assert "PermissionError" in observation.text + assert "Cannot read fast" in observation.text + assert conversation.agent.llm.model == "default-model" + assert conversation.state.agent.llm.model == "default-model"