diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 3bee201e0b..bf11af7ae1 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -992,40 +992,40 @@ async def acompletion( if tools and not use_native_fc: telemetry_ctx["raw_messages"] = original_fncall_msgs - resp: ModelResponse | None = None - async for attempt in self.async_retry( - num_retries=self.num_retries, - retry_exceptions=LLM_RETRY_EXCEPTIONS, - retry_min_wait=self.retry_min_wait, - retry_max_wait=self.retry_max_wait, - retry_multiplier=self.retry_multiplier, - retry_listener=self._retry_listener_fn, - ): - with attempt: - assert self._telemetry is not None - self._telemetry.on_request(telemetry_ctx=telemetry_ctx) - resp = await self._atransport_call( - messages=formatted_messages, - **call_kwargs, - enable_streaming=enable_streaming, - on_token=on_token, - ) - raw_resp: ModelResponse | None = None - if use_mock_tools: - raw_resp = copy.deepcopy(resp) - resp = self.post_response_prompt_mock( - resp, - nonfncall_msgs=formatted_messages, - tools=cc_tools, - include_security_params=add_security_risk_prediction, - ) - self._telemetry.on_response(resp, raw_resp=raw_resp) - if not resp.get("choices") or len(resp["choices"]) < 1: - raise LLMNoResponseError( - "Response choices is less than 1. Response: " + str(resp) + try: + resp: ModelResponse | None = None + async for attempt in self.async_retry( + num_retries=self.num_retries, + retry_exceptions=LLM_RETRY_EXCEPTIONS, + retry_min_wait=self.retry_min_wait, + retry_max_wait=self.retry_max_wait, + retry_multiplier=self.retry_multiplier, + retry_listener=self._retry_listener_fn, + ): + with attempt: + assert self._telemetry is not None + self._telemetry.on_request(telemetry_ctx=telemetry_ctx) + resp = await self._atransport_call( + messages=formatted_messages, + **call_kwargs, + enable_streaming=enable_streaming, + on_token=on_token, ) + raw_resp: ModelResponse | None = None + if use_mock_tools: + raw_resp = copy.deepcopy(resp) + resp = self.post_response_prompt_mock( + resp, + nonfncall_msgs=formatted_messages, + tools=cc_tools, + include_security_params=add_security_risk_prediction, + ) + self._telemetry.on_response(resp, raw_resp=raw_resp) + if not resp.get("choices") or len(resp["choices"]) < 1: + raise LLMNoResponseError( + "Response choices is less than 1. Response: " + str(resp) + ) - try: assert resp is not None first_choice = resp["choices"][0] message = Message.from_llm_chat_message(first_choice["message"]) @@ -1143,106 +1143,104 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: assert self._telemetry is not None self._telemetry.on_request(telemetry_ctx=telemetry_ctx) final_kwargs = {**call_kwargs, **retry_kwargs} - with self._litellm_modify_params_ctx(self.modify_params): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - typed_input: ResponseInputParam | str = ( - cast(ResponseInputParam, input_items) if input_items else "" - ) - api_key_value = self._get_litellm_api_key_value() - - ret = litellm_responses( - model=self.model, - input=typed_input, - instructions=instructions, - tools=resp_tools, - api_key=api_key_value, - api_base=self.base_url, - api_version=self.api_version, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - **{**self._aws_kwargs(), **final_kwargs}, - ) - if isinstance(ret, ResponsesAPIResponse): - if user_enable_streaming: - logger.warning( - "Responses streaming was requested, but the provider " - "returned a non-streaming response; no on_token deltas " - "will be emitted." - ) - self._telemetry.on_response(ret) - return ret - - # When stream=True, LiteLLM returns a streaming iterator rather than - # a single ResponsesAPIResponse. Drain the iterator and use the - # completed response. - if final_kwargs.get("stream", False): - if not isinstance(ret, SyncResponsesAPIStreamingIterator): - raise AssertionError( - f"Expected Responses stream iterator, got {type(ret)}" - ) + with ( + self._litellm_modify_params_ctx(self.modify_params), + warnings.catch_warnings(), + ): + warnings.filterwarnings("ignore", category=DeprecationWarning) + typed_input: ResponseInputParam | str = ( + cast(ResponseInputParam, input_items) if input_items else "" + ) + api_key_value = self._get_litellm_api_key_value() + + ret = litellm_responses( + model=self.model, + input=typed_input, + instructions=instructions, + tools=resp_tools, + api_key=api_key_value, + api_base=self.base_url, + api_version=self.api_version, + timeout=self.timeout, + drop_params=self.drop_params, + seed=self.seed, + **{**self._aws_kwargs(), **final_kwargs}, + ) + if isinstance(ret, ResponsesAPIResponse): + if user_enable_streaming: + logger.warning( + "Responses streaming was requested, but the provider " + "returned a non-streaming response; no on_token deltas " + "will be emitted." + ) + self._telemetry.on_response(ret) + return ret + + # When stream=True, LiteLLM returns a streaming iterator rather than + # a single ResponsesAPIResponse. Drain the iterator and use the + # completed response. + if final_kwargs.get("stream", False): + if not isinstance(ret, SyncResponsesAPIStreamingIterator): + raise AssertionError( + f"Expected Responses stream iterator, got {type(ret)}" + ) - stream_callback = on_token if user_enable_streaming else None - # Collect output items from streaming events. - # Some endpoints (e.g., Codex subscription) send output - # items as separate events but the final response.completed - # event has output=[]. We accumulate them here and patch - # the completed response if needed. - collected_output_items: list[Any] = [] - for event in ret: - if event is None: - continue - # Collect finished output items - evt_type = getattr(event, "type", None) - if evt_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: - item = getattr(event, "item", None) - if item is not None: - collected_output_items.append(item) - if stream_callback is None: - continue - if isinstance( - event, - ( - OutputTextDeltaEvent, - RefusalDeltaEvent, - ReasoningSummaryTextDeltaEvent, - ), - ): - delta = event.delta - if delta: - stream_callback( - ModelResponseStream( - choices=[ - StreamingChoices( - delta=Delta(content=delta) - ) - ] - ) + stream_callback = on_token if user_enable_streaming else None + # Collect output items from streaming events. + # Some endpoints (e.g., Codex subscription) send output + # items as separate events but the final response.completed + # event has output=[]. We accumulate them here and patch + # the completed response if needed. + collected_output_items: list[Any] = [] + for event in ret: + if event is None: + continue + # Collect finished output items + evt_type = getattr(event, "type", None) + if evt_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: + item = getattr(event, "item", None) + if item is not None: + collected_output_items.append(item) + if stream_callback is None: + continue + if isinstance( + event, + ( + OutputTextDeltaEvent, + RefusalDeltaEvent, + ReasoningSummaryTextDeltaEvent, + ), + ): + delta = event.delta + if delta: + stream_callback( + ModelResponseStream( + choices=[ + StreamingChoices(delta=Delta(content=delta)) + ] ) + ) - completed_event = ret.completed_response - if completed_event is None: - raise LLMNoResponseError( - "Responses stream finished without a completed response" - ) - if not isinstance(completed_event, ResponseCompletedEvent): - raise LLMNoResponseError( - f"Unexpected completed event: {type(completed_event)}" - ) + completed_event = ret.completed_response + if completed_event is None: + raise LLMNoResponseError( + "Responses stream finished without a completed response" + ) + if not isinstance(completed_event, ResponseCompletedEvent): + raise LLMNoResponseError( + f"Unexpected completed event: {type(completed_event)}" + ) - completed_resp = completed_event.response + completed_resp = completed_event.response - # Patch empty output with items collected from stream - if not completed_resp.output and collected_output_items: - completed_resp.output = collected_output_items + # Patch empty output with items collected from stream + if not completed_resp.output and collected_output_items: + completed_resp.output = collected_output_items - self._telemetry.on_response(completed_resp) - return completed_resp + self._telemetry.on_response(completed_resp) + return completed_resp - raise AssertionError( - f"Expected ResponsesAPIResponse, got {type(ret)}" - ) + raise AssertionError(f"Expected ResponsesAPIResponse, got {type(ret)}") try: resp: ResponsesAPIResponse = _one_attempt() @@ -1334,21 +1332,24 @@ async def aresponses( } ) - completed: ResponsesAPIResponse | None = None - async for attempt in self.async_retry( - num_retries=self.num_retries, - retry_exceptions=LLM_RETRY_EXCEPTIONS, - retry_min_wait=self.retry_min_wait, - retry_max_wait=self.retry_max_wait, - retry_multiplier=self.retry_multiplier, - retry_listener=self._retry_listener_fn, - ): - with attempt: - assert self._telemetry is not None - self._telemetry.on_request(telemetry_ctx=telemetry_ctx) - final_kwargs = {**call_kwargs} - with self._litellm_modify_params_ctx(self.modify_params): - with warnings.catch_warnings(): + try: + completed: ResponsesAPIResponse | None = None + async for attempt in self.async_retry( + num_retries=self.num_retries, + retry_exceptions=LLM_RETRY_EXCEPTIONS, + retry_min_wait=self.retry_min_wait, + retry_max_wait=self.retry_max_wait, + retry_multiplier=self.retry_multiplier, + retry_listener=self._retry_listener_fn, + ): + with attempt: + assert self._telemetry is not None + self._telemetry.on_request(telemetry_ctx=telemetry_ctx) + final_kwargs = {**call_kwargs} + with ( + self._litellm_modify_params_ctx(self.modify_params), + warnings.catch_warnings(), + ): warnings.filterwarnings("ignore", category=DeprecationWarning) typed_input: ResponseInputParam | str = ( cast(ResponseInputParam, input_items) if input_items else "" @@ -1371,9 +1372,10 @@ async def aresponses( if isinstance(ret, ResponsesAPIResponse): if user_enable_streaming: logger.warning( - "Responses streaming was requested, but the " - "provider returned a non-streaming response; " - "no on_token deltas will be emitted." + "Responses streaming was requested, " + "but the provider returned a " + "non-streaming response; no on_token " + "deltas will be emitted." ) self._telemetry.on_response(ret) completed = ret @@ -1399,7 +1401,7 @@ async def aresponses( collected_output_items.append(item) if stream_cb is None: continue - if isinstance( + if not isinstance( event, ( OutputTextDeltaEvent, @@ -1407,24 +1409,25 @@ async def aresponses( ReasoningSummaryTextDeltaEvent, ), ): - delta = event.delta - if delta: - await _invoke_token_callback( - stream_cb, - ModelResponseStream( - choices=[ - StreamingChoices( - delta=Delta(content=delta) - ) - ] - ), - ) + continue + if not event.delta: + continue + await _invoke_token_callback( + stream_cb, + ModelResponseStream( + choices=[ + StreamingChoices( + delta=Delta(content=event.delta) + ) + ] + ), + ) completed_event = ret.completed_response if completed_event is None: raise LLMNoResponseError( - "Responses stream finished without a " - "completed response" + "Responses stream finished without " + "a completed response" ) if not isinstance(completed_event, ResponseCompletedEvent): raise LLMNoResponseError( @@ -1443,7 +1446,6 @@ async def aresponses( f"Expected ResponsesAPIResponse, got {type(ret)}" ) - try: assert completed is not None output_seq = cast(Sequence[Any], completed.output or []) message = Message.from_llm_responses_output(output_seq) @@ -1515,63 +1517,65 @@ def _transport_call( **kwargs, ) -> ModelResponse: # litellm.modify_params is GLOBAL; guard it for thread-safety - with self._litellm_modify_params_ctx(self.modify_params): - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=DeprecationWarning, module="httpx.*" - ) - warnings.filterwarnings( - "ignore", - message=r".*content=.*upload.*", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - message=r"There is no current event loop", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - category=UserWarning, - ) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - message="Accessing the 'model_fields' attribute.*", - ) - api_key_value = self._get_litellm_api_key_value() - - # When streaming, request usage in the final chunk so that - # detailed token breakdowns (prompt_tokens_details with - # cached_tokens, etc.) are not silently discarded by - # litellm's streaming handler. - if enable_streaming: - kwargs.setdefault("stream_options", {"include_usage": True}) - - # Some providers need renames handled in _normalize_call_kwargs. - ret = litellm_completion( - model=self.model, - api_key=api_key_value, - api_base=self.base_url, - api_version=self.api_version, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - messages=messages, - **{**self._aws_kwargs(), **kwargs}, - ) - if enable_streaming and on_token is not None: - assert isinstance(ret, CustomStreamWrapper) - chunks = [] - for chunk in ret: - on_token(chunk) - chunks.append(chunk) - ret = litellm.stream_chunk_builder(chunks, messages=messages) - - assert isinstance(ret, ModelResponse), ( - f"Expected ModelResponse, got {type(ret)}" - ) - return ret + with ( + self._litellm_modify_params_ctx(self.modify_params), + warnings.catch_warnings(), + ): + warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="httpx.*" + ) + warnings.filterwarnings( + "ignore", + message=r".*content=.*upload.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"There is no current event loop", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="Accessing the 'model_fields' attribute.*", + ) + api_key_value = self._get_litellm_api_key_value() + + # When streaming, request usage in the final chunk so that + # detailed token breakdowns (prompt_tokens_details with + # cached_tokens, etc.) are not silently discarded by + # litellm's streaming handler. + if enable_streaming: + kwargs.setdefault("stream_options", {"include_usage": True}) + + # Some providers need renames handled in _normalize_call_kwargs. + ret = litellm_completion( + model=self.model, + api_key=api_key_value, + api_base=self.base_url, + api_version=self.api_version, + timeout=self.timeout, + drop_params=self.drop_params, + seed=self.seed, + messages=messages, + **{**self._aws_kwargs(), **kwargs}, + ) + if enable_streaming and on_token is not None: + assert isinstance(ret, CustomStreamWrapper) + chunks = [] + for chunk in ret: + on_token(chunk) + chunks.append(chunk) + ret = litellm.stream_chunk_builder(chunks, messages=messages) + + assert isinstance(ret, ModelResponse), ( + f"Expected ModelResponse, got {type(ret)}" + ) + return ret async def _atransport_call( self, @@ -1582,55 +1586,57 @@ async def _atransport_call( **kwargs, ) -> ModelResponse: """Async variant of :meth:`_transport_call`.""" - with self._litellm_modify_params_ctx(self.modify_params): - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=DeprecationWarning, module="httpx.*" - ) - warnings.filterwarnings( - "ignore", - message=r".*content=.*upload.*", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - message=r"There is no current event loop", - category=DeprecationWarning, - ) - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - message="Accessing the 'model_fields' attribute.*", - ) - api_key_value = self._get_litellm_api_key_value() - - if enable_streaming: - kwargs.setdefault("stream_options", {"include_usage": True}) - - ret = await litellm_acompletion( - model=self.model, - api_key=api_key_value, - api_base=self.base_url, - api_version=self.api_version, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - messages=messages, - **{**self._aws_kwargs(), **kwargs}, - ) - if enable_streaming and on_token is not None: - assert isinstance(ret, CustomStreamWrapper) - chunks = [] - async for chunk in ret: - await _invoke_token_callback(on_token, chunk) - chunks.append(chunk) - ret = litellm.stream_chunk_builder(chunks, messages=messages) - - assert isinstance(ret, ModelResponse), ( - f"Expected ModelResponse, got {type(ret)}" - ) - return ret + with ( + self._litellm_modify_params_ctx(self.modify_params), + warnings.catch_warnings(), + ): + warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="httpx.*" + ) + warnings.filterwarnings( + "ignore", + message=r".*content=.*upload.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"There is no current event loop", + category=DeprecationWarning, + ) + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="Accessing the 'model_fields' attribute.*", + ) + api_key_value = self._get_litellm_api_key_value() + + if enable_streaming: + kwargs.setdefault("stream_options", {"include_usage": True}) + + ret = await litellm_acompletion( + model=self.model, + api_key=api_key_value, + api_base=self.base_url, + api_version=self.api_version, + timeout=self.timeout, + drop_params=self.drop_params, + seed=self.seed, + messages=messages, + **{**self._aws_kwargs(), **kwargs}, + ) + if enable_streaming and on_token is not None: + assert isinstance(ret, CustomStreamWrapper) + chunks = [] + async for chunk in ret: + await _invoke_token_callback(on_token, chunk) + chunks.append(chunk) + ret = litellm.stream_chunk_builder(chunks, messages=messages) + + assert isinstance(ret, ModelResponse), ( + f"Expected ModelResponse, got {type(ret)}" + ) + return ret @contextmanager def _litellm_modify_params_ctx(self, flag: bool): diff --git a/tests/sdk/llm/test_llm_fallback.py b/tests/sdk/llm/test_llm_fallback.py index 9758d65a02..a2c3f3d1f1 100644 --- a/tests/sdk/llm/test_llm_fallback.py +++ b/tests/sdk/llm/test_llm_fallback.py @@ -1,10 +1,12 @@ -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest from litellm.exceptions import ( APIConnectionError, + ContextWindowExceededError, RateLimitError, ) +from litellm.types.llms.openai import ResponsesAPIResponse from litellm.types.utils import ( Choices, Message as LiteLLMMessage, @@ -14,7 +16,10 @@ from pydantic import SecretStr from openhands.sdk.llm import LLM, FallbackStrategy, Message, TextContent -from openhands.sdk.llm.exceptions import LLMServiceUnavailableError +from openhands.sdk.llm.exceptions import ( + LLMContextWindowExceedError, + LLMServiceUnavailableError, +) def _get_mock_response(content: str = "ok", model: str = "gpt-4o") -> ModelResponse: @@ -305,3 +310,128 @@ def side_effect(**kwargs): content = resp.message.content[0] assert isinstance(content, TextContent) assert content.text == "from store" + + +# ========================================================================= +# Async error-handling parity tests (acompletion / aresponses) +# ========================================================================= + + +@pytest.mark.asyncio +@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm_acompletion", new_callable=AsyncMock) +async def test_acompletion_fallback_on_transport_error(mock_acomp, mock_comp): + """acompletion must invoke fallback when the primary transport raises.""" + primary_error = APIConnectionError( + message="connection reset", llm_provider="openai", model="gpt-4o" + ) + mock_acomp.side_effect = primary_error + + # Fallback uses sync completion path + mock_comp.return_value = _get_mock_response("fallback ok", model="fallback-model") + + fb = _get_llm("fallback-model") + strategy = FallbackStrategy(fallback_llms=["fb-profile"]) + primary = _get_llm("gpt-4o", fallback_strategy=strategy) + _patch_resolve(primary, [fb]) + + resp = await primary.acompletion(_MSGS) + content = resp.message.content[0] + assert isinstance(content, TextContent) + assert content.text == "fallback ok" + + +@pytest.mark.asyncio +@patch("openhands.sdk.llm.llm.litellm_acompletion", new_callable=AsyncMock) +async def test_acompletion_maps_context_window_error(mock_acomp): + """acompletion must map ContextWindowExceededError to SDK type.""" + mock_acomp.side_effect = ContextWindowExceededError( + message="context window exceeded", + llm_provider="openai", + model="gpt-4o", + ) + primary = _get_llm("gpt-4o") + with pytest.raises(LLMContextWindowExceedError): + await primary.acompletion(_MSGS) + + +@pytest.mark.asyncio +@patch("openhands.sdk.llm.llm.litellm_acompletion", new_callable=AsyncMock) +async def test_acompletion_maps_connection_error(mock_acomp): + """acompletion must map APIConnectionError to LLMServiceUnavailableError.""" + mock_acomp.side_effect = APIConnectionError( + message="down", llm_provider="openai", model="gpt-4o" + ) + primary = _get_llm("gpt-4o") + with pytest.raises(LLMServiceUnavailableError): + await primary.acompletion(_MSGS) + + +@pytest.mark.asyncio +@patch("openhands.sdk.llm.llm.litellm_responses") +@patch("openhands.sdk.llm.llm.litellm_aresponses", new_callable=AsyncMock) +async def test_aresponses_fallback_on_transport_error(mock_aresp, mock_resp): + """aresponses must invoke fallback when the primary transport raises.""" + + primary_error = APIConnectionError( + message="down", llm_provider="openai", model="gpt-4o" + ) + mock_aresp.side_effect = primary_error + + fallback_response = ResponsesAPIResponse( + id="resp-fb", + created_at=1, + model="fb", + object="response", + output=[ + { + "type": "message", + "id": "msg-1", + "role": "assistant", + "status": "completed", + "content": [ + {"type": "output_text", "text": "fb ok", "annotations": []} + ], + } + ], + parallel_tool_calls=False, + tool_choice="auto", + tools=[], + ) + mock_resp.return_value = fallback_response + + fb = _get_llm("fb") + strategy = FallbackStrategy(fallback_llms=["fb-profile"]) + primary = _get_llm("gpt-4o", fallback_strategy=strategy) + _patch_resolve(primary, [fb]) + + resp = await primary.aresponses(_MSGS) + content = resp.message.content[0] + assert isinstance(content, TextContent) + assert content.text == "fb ok" + + +@pytest.mark.asyncio +@patch("openhands.sdk.llm.llm.litellm_aresponses", new_callable=AsyncMock) +async def test_aresponses_maps_context_window_error(mock_aresp): + """aresponses must map ContextWindowExceededError to SDK type.""" + mock_aresp.side_effect = ContextWindowExceededError( + message="context window exceeded", + llm_provider="openai", + model="gpt-4o", + ) + primary = _get_llm("gpt-4o") + with pytest.raises(LLMContextWindowExceedError): + await primary.aresponses(_MSGS) + + +@pytest.mark.asyncio +@patch("openhands.sdk.llm.llm.litellm_aresponses", new_callable=AsyncMock) +async def test_aresponses_maps_connection_error(mock_aresp): + """aresponses must map APIConnectionError to LLMServiceUnavailableError.""" + mock_aresp.side_effect = APIConnectionError( + message="down", llm_provider="openai", model="gpt-4o" + ) + primary = _get_llm("gpt-4o") + with pytest.raises(LLMServiceUnavailableError): + await primary.aresponses(_MSGS)