Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ols/src/query_helpers/docs_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def _tool_result_chunk_for_message(
A tuple of (token_count_for_tool_content, streamed_tool_result_chunk).
"""
content_tokens = token_handler.text_to_tokens(str(tool_call_message.content))
content_token_count = len(content_tokens)
content_token_count = TokenHandler._get_token_count(content_tokens)

was_truncated = tool_call_message.additional_kwargs.get("truncated", False)
base_status = tool_call_message.status
Expand Down
82 changes: 81 additions & 1 deletion tests/unit/query_helpers/test_docs_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging
from typing import ClassVar
from unittest.mock import ANY, AsyncMock, patch
from unittest.mock import ANY, AsyncMock, MagicMock, patch

import pytest
from langchain_core.messages import HumanMessage, ToolMessage
Expand Down Expand Up @@ -356,6 +356,86 @@ def test_tool_calling_tool_execution(caplog):
assert mock_invoke.call_count == 2


def test_tool_output_token_tracking_uses_buffer_weight(caplog):
"""Test that tool output tokens are counted with TOKEN_BUFFER_WEIGHT like other budget items.

Before this fix, raw len(tokens) was used for tool outputs while tool definitions
and AIMessage tokens used _get_token_count() (which applies a 1.1x buffer).
This test asserts _get_token_count() is called for tool output tokens by spying on
it: with one tool call in one round it must be called at least 3 times
(tool definitions, AIMessage, tool output).
"""
mcp_servers_config = {
"test_server": {
"transport": "streamable_http",
"url": "http://test-server:8080/mcp",
},
}

original_get_token_count = TokenHandler._get_token_count
call_count = 0

def counting_get_token_count(tokens: list) -> int:
nonlocal call_count
call_count += 1
return original_get_token_count(tokens)

with (
patch(
"ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations",
return_value=2,
),
patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls,
patch(
"ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm"
) as mock_invoke,
patch("ols.utils.mcp_utils.config") as mock_config,
patch.object(
TokenHandler, "_get_token_count", staticmethod(counting_get_token_count)
),
):
mock_config.tools_rag = None
mock_config.mcp_servers.servers = [MagicMock()]

with patch(
"ols.utils.mcp_utils._gather_and_populate_tools",
new=AsyncMock(return_value=(mcp_servers_config, mock_tools_map)),
):
mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke(
[
AIMessageChunk(
content="",
response_metadata={"finish_reason": "tool_calls"},
tool_calls=[
{
"name": "get_namespaces_mock",
"args": {},
"id": "call_id1",
},
],
)
]
)

mock_mcp_client_instance = AsyncMock()
mock_mcp_client_instance.get_tools.return_value = mock_tools_map
mock_mcp_client_cls.return_value = mock_mcp_client_instance

summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None))
summarizer.model_config.parameters.max_tokens_for_tools = 50000
summarizer.model_config.parameters.max_tokens_per_tool_output = 8000
summarizer.create_response("How many namespaces?")

# _get_token_count must be called for:
# 1. tool definitions (once at the start of the loop)
# 2. AIMessage with tool_calls
# 3. tool output (the change introduced by this fix)
assert call_count >= 3, (
f"Expected _get_token_count to be called at least 3 times "
f"(definitions + AIMessage + tool output), got {call_count}"
)


@pytest.mark.asyncio
async def test_gather_mcp_tools_failure_isolation(caplog):
"""Test gather_mcp_tools isolates failures from individual MCP servers."""
Expand Down