diff --git a/ols/app/endpoints/ols.py b/ols/app/endpoints/ols.py index bdb5f3c6d..e314830c2 100644 --- a/ols/app/endpoints/ols.py +++ b/ols/app/endpoints/ols.py @@ -562,6 +562,7 @@ def generate_response( system_prompt=llm_request.system_prompt, user_token=user_token, client_headers=client_headers, + streaming=streaming, ) history = CacheEntry.cache_entries_to_history(previous_input) if streaming: diff --git a/ols/app/endpoints/streaming_ols.py b/ols/app/endpoints/streaming_ols.py index 0ba4545ab..18b5eb305 100644 --- a/ols/app/endpoints/streaming_ols.py +++ b/ols/app/endpoints/streaming_ols.py @@ -11,7 +11,6 @@ from fastapi import APIRouter, Depends, status from fastapi.responses import StreamingResponse -from langchain_core.messages import ToolMessage from ols import config, constants from ols.app.endpoints.ols import ( @@ -27,6 +26,7 @@ ) from ols.app.models.models import ( Attachment, + ChunkType, ErrorResponse, ForbiddenResponse, LLMRequest, @@ -48,9 +48,13 @@ auth_dependency = get_auth_dependency(config.ols_config, virtual_path="/ols-access") -LLM_TOKEN_EVENT = "token" # noqa: S105 -LLM_TOOL_CALL_EVENT = "tool_call" -LLM_TOOL_RESULT_EVENT = "tool_result" +STREAM_KEY_EVENT = "event" +STREAM_KEY_DATA = "data" +TOKEN_KEY_ID = "id" # noqa: S105 +TOKEN_KEY_TOKEN = "token" # noqa: S105 +END_KEY_RAG_CHUNKS = "rag_chunks" +END_KEY_TRUNCATED = "truncated" +END_KEY_TOKEN_COUNTER = "token_counter" # noqa: S105 query_responses: dict[int | str, dict[str, Any]] = { @@ -123,7 +127,7 @@ def conversation_request( ) -def format_stream_data(d: dict) -> str: +def format_stream_data(d: dict[str, object]) -> str: """Format outbound data in the Event Stream Format.""" data = json.dumps(d) return f"data: {data}\n\n" @@ -145,7 +149,7 @@ def stream_start_event(conversation_id: str) -> str: ) -def stream_event(data: dict, event_type: str, media_type: str) -> str: +def stream_event(data: dict[str, object], event_type: str, media_type: str) -> str: """Build an item to yield based on media type. Args: @@ -157,18 +161,20 @@ def stream_event(data: dict, event_type: str, media_type: str) -> str: str: The formatted string or JSON to yield. """ if media_type == MEDIA_TYPE_TEXT: - if event_type == LLM_TOKEN_EVENT: - return data["token"] - if event_type == LLM_TOOL_CALL_EVENT: - return f"\nTool call: {json.dumps(data)}\n" - if event_type == LLM_TOOL_RESULT_EVENT: - return f"\nTool result: {json.dumps(data)}\n" - logger.error("Unknown event type: %s", event_type) - return "" + match event_type: + case _ if event_type == TOKEN_KEY_TOKEN: + return data[TOKEN_KEY_TOKEN] + case ChunkType.TOOL_CALL.value: + return f"\nTool call: {json.dumps(data)}\n" + case ChunkType.TOOL_RESULT.value: + return f"\nTool result: {json.dumps(data)}\n" + case _: + logger.error("Unknown event type: %s", event_type) + return "" return format_stream_data( { - "event": event_type, - "data": data, + STREAM_KEY_EVENT: event_type, + STREAM_KEY_DATA: data, } ) @@ -288,8 +294,8 @@ def store_data( conversation_id: str, llm_request: LLMRequest, response: str, - tool_calls: list[dict], - tool_results: list[ToolMessage], + tool_calls: list[dict[str, object]], + tool_results: list[dict[str, object]], attachments: list[Attachment], query_without_attachments: str, rag_chunks: list[RagChunk], @@ -342,7 +348,7 @@ def store_data( async def response_processing_wrapper( - generator: AsyncGenerator[Any, None], + generator: AsyncGenerator[StreamedChunk, None], user_id: str, conversation_id: str, llm_request: LLMRequest, @@ -372,9 +378,9 @@ async def response_processing_wrapper( yield stream_start_event(conversation_id) response: str = "" - rag_chunks: list = [] - tool_calls: list = [] - tool_results: list = [] + rag_chunks: list[RagChunk] = [] + tool_calls: list[dict[str, object]] = [] + tool_results: list[dict[str, object]] = [] history_truncated: bool = False idx: int = 0 token_counter: Optional[TokenCounter] = None @@ -385,39 +391,40 @@ async def response_processing_wrapper( msg = f"Expecting StreamedChunk, but got {type(item)}: {item}" logger.error(msg) raise ValueError(msg) - if item.type == "tool_call": - tool_calls.append(item.data) - yield stream_event( - data=item.data, - event_type=LLM_TOOL_CALL_EVENT, - media_type=media_type, - ) - elif item.type == "tool_result": - tool_results.append(item.data) - yield stream_event( - data=item.data, - event_type=LLM_TOOL_RESULT_EVENT, - media_type=media_type, - ) - elif item.type == "text": - response += item.text - yield stream_event( - data={"id": idx, "token": item.text}, - event_type=LLM_TOKEN_EVENT, - media_type=media_type, - ) - idx += 1 - elif item.type == "end": - rag_chunks = item.data["rag_chunks"] - history_truncated = item.data["truncated"] - token_counter = item.data["token_counter"] - else: - msg = ( - "Yielded unknown item type from streaming generator, " - f"item: {item}" - ) - logger.error(msg) - raise ValueError(msg) + match item.type: + case ChunkType.TOOL_CALL: + tool_calls.append(item.data) + yield stream_event( + data=item.data, + event_type=ChunkType.TOOL_CALL.value, + media_type=media_type, + ) + case ChunkType.TOOL_RESULT: + tool_results.append(item.data) + yield stream_event( + data=item.data, + event_type=ChunkType.TOOL_RESULT.value, + media_type=media_type, + ) + case ChunkType.TEXT: + response += item.text + yield stream_event( + data={TOKEN_KEY_ID: idx, TOKEN_KEY_TOKEN: item.text}, + event_type=TOKEN_KEY_TOKEN, + media_type=media_type, + ) + idx += 1 + case ChunkType.END: + rag_chunks = item.data[END_KEY_RAG_CHUNKS] + history_truncated = item.data[END_KEY_TRUNCATED] + token_counter = item.data[END_KEY_TOKEN_COUNTER] + case _: + msg = ( + "Yielded unknown item type from streaming generator, " + f"item: {item}" + ) + logger.error(msg) + raise ValueError(msg) except PromptTooLongError as summarizer_error: yield prompt_too_long_error(summarizer_error, media_type) return # stop execution after error diff --git a/ols/app/models/models.py b/ols/app/models/models.py index d6241b1aa..c3923007f 100644 --- a/ols/app/models/models.py +++ b/ols/app/models/models.py @@ -3,6 +3,7 @@ import json from collections import OrderedDict from dataclasses import field +from enum import Enum from typing import Any, Literal, Optional, Self, Union from langchain_core.language_models.llms import LLM @@ -982,6 +983,15 @@ class ProcessedRequest(BaseModel): user_token: str +class ChunkType(str, Enum): + """Supported streamed chunk types.""" + + TEXT = "text" + TOOL_CALL = "tool_call" + TOOL_RESULT = "tool_result" + END = "end" + + @dataclass class StreamedChunk: """Represents a chunk of streamed data from the LLM. @@ -992,7 +1002,7 @@ class StreamedChunk: data: Additional data associated with the chunk (for non-text chunks) """ - type: Literal["text", "tool_call", "tool_result", "end"] + type: ChunkType text: str = "" data: dict[str, Any] = field(default_factory=dict) diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py index 439284520..59bc34b36 100644 --- a/ols/src/query_helpers/docs_summarizer.py +++ b/ols/src/query_helpers/docs_summarizer.py @@ -5,8 +5,8 @@ import json import logging import time -from collections.abc import Coroutine -from typing import Any, AsyncGenerator, Optional, TypeAlias +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Coroutine, Optional, TypeAlias from langchain_core.globals import set_debug from langchain_core.messages import AIMessage, BaseMessage, ToolMessage @@ -18,19 +18,25 @@ from ols import config, constants from ols.app.metrics import TokenMetricUpdater from ols.app.metrics.token_counter import GenericTokenCounter -from ols.app.models.models import RagChunk, StreamedChunk, SummarizerResponse +from ols.app.models.models import ChunkType, RagChunk, StreamedChunk, SummarizerResponse from ols.constants import GenericLLMParameters from ols.src.prompts.prompt_generator import GeneratePrompt from ols.src.query_helpers.query_helper import QueryHelper from ols.src.tools.tools import execute_tool_calls -from ols.utils.mcp_utils import ClientHeaders, get_mcp_tools +from ols.utils.mcp_utils import ClientHeaders, build_mcp_config, get_mcp_tools from ols.utils.token_handler import TokenHandler logger = logging.getLogger(__name__) -# Type aliases for clarity and reusability -MessageHistory: TypeAlias = list[BaseMessage] -ToolsList: TypeAlias = list[StructuredTool] +MIN_TOOL_EXECUTION_TOKENS = 100 +ToolCallDefinition: TypeAlias = tuple[str, dict[str, object], StructuredTool] + + +@dataclass(slots=True) +class ToolTokenUsage: + """Mutable holder for cumulative tool-token usage across helper boundaries.""" + + used: int def skip_special_chunk( @@ -58,7 +64,7 @@ def skip_special_chunk( def tool_calls_from_tool_calls_chunks( tool_calls_chunks: list[AIMessageChunk], -) -> list[dict[str, Any]]: +) -> list[dict[str, object]]: """Extract complete tool calls from a series of tool call chunks. The LLM streams tool calls in partial chunks that need to be combined to form @@ -89,138 +95,43 @@ def run_async_safely(coro: Coroutine[Any, Any, Any]) -> Any: raise -def _enrich_tool_call( - tool_call: dict[str, Any], - all_mcp_tools: ToolsList, -) -> dict[str, Any]: - """Enrich a tool_call dict with metadata from the StructuredTool. - - Adds tool_meta and server_name so the UI can start preloading - UI resources (e.g. MCP Apps iframes) before the tool result arrives. - - Args: - tool_call: LLM-generated tool call dict (name, args, id). - all_mcp_tools: Available tools (carry metadata from MCP servers). - - Returns: - Enriched tool call dict. Original keys are preserved. - """ - enriched: dict[str, Any] = {**tool_call} - tool_obj = next((t for t in all_mcp_tools if t.name == tool_call.get("name")), None) - if not tool_obj: - return enriched - - tool_metadata = tool_obj.metadata or {} - - server_name = tool_metadata.get("mcp_server") - if server_name: - enriched["server_name"] = server_name - - tool_meta = tool_metadata.get("_meta") - if tool_meta: - enriched["tool_meta"] = tool_meta - - return enriched - - -def _build_tool_result_chunks( - tool_calls: list[dict[str, Any]], - tool_calls_messages: list[ToolMessage], - all_mcp_tools: ToolsList, - round_number: int, -) -> list[StreamedChunk]: - """Build StreamedChunk objects for tool results with metadata. - - Args: - tool_calls: LLM-generated tool call dicts (name, args, id). - tool_calls_messages: Executed ToolMessage results. - all_mcp_tools: Available tools (carry metadata from MCP servers). - round_number: Current tool-calling round index. - - Returns: - List of StreamedChunk objects with type "tool_result". - """ - tool_id_to_name = {tc.get("id"): tc.get("name") for tc in tool_calls} - tools_by_name = {t.name: t for t in all_mcp_tools} - chunks: list[StreamedChunk] = [] - - for tool_call_message in tool_calls_messages: - was_truncated = tool_call_message.additional_kwargs.get("truncated", False) - tool_status = "truncated" if was_truncated else tool_call_message.status - - tool_name = tool_id_to_name.get(tool_call_message.tool_call_id, "unknown") - tool_obj = tools_by_name.get(tool_name) - tool_metadata = (tool_obj.metadata or {}) if tool_obj else {} - - logger.debug( - json.dumps( - { - "event": "tool_result", - "tool_id": tool_call_message.tool_call_id, - "tool_name": tool_name, - "status": tool_call_message.status, - "truncated": was_truncated, - "has_meta": "_meta" in tool_metadata, - "output_snippet": str(tool_call_message.content)[:1000], - }, - ensure_ascii=False, - indent=2, - ) - ) - - tool_result_data: dict[str, Any] = { - "id": tool_call_message.tool_call_id, - "name": tool_name, - "status": tool_status, - "content": tool_call_message.content, - "type": "tool_result", - "round": round_number, - } - - server_name = tool_metadata.get("mcp_server") - if server_name: - tool_result_data["server_name"] = server_name - - tool_meta = tool_metadata.get("_meta") - if tool_meta: - tool_result_data["tool_meta"] = tool_meta - - structured_content = tool_call_message.additional_kwargs.get( - "structured_content" - ) - if structured_content: - tool_result_data["structured_content"] = structured_content - - chunks.append(StreamedChunk(type="tool_result", data=tool_result_data)) - - return chunks - - class DocsSummarizer(QueryHelper): """A class for summarizing documentation context.""" def __init__( self, - *args: Any, + *args: object, user_token: Optional[str] = None, client_headers: ClientHeaders | None = None, - **kwargs: Any, + streaming: bool = False, + **kwargs: object, ) -> None: """Initialize the DocsSummarizer. Args: user_token: Optional user authentication token for tool access client_headers: Optional client-provided MCP headers for authentication + streaming: Whether this summarizer is used for the streaming endpoint *args: Additional positional arguments passed to the parent class **kwargs: Additional keyword arguments passed to the parent class """ super().__init__(*args, **kwargs) self._prepare_llm() self.verbose = config.ols_config.logging_config.app_log_level == logging.DEBUG + self.streaming = streaming # tools part self.client_headers = client_headers or {} self.user_token = user_token + self.mcp_servers = build_mcp_config( + config.mcp_servers.servers, self.user_token, self.client_headers + ) + if self.mcp_servers: + logger.info("MCP servers provided: %s", list(self.mcp_servers.keys())) + self._tool_calling_enabled = True + else: + logger.debug("No MCP servers provided, tool calling is disabled") + self._tool_calling_enabled = False set_debug(self.verbose) @@ -237,16 +148,11 @@ def _prepare_llm(self) -> None: self.generic_llm_params, ) - @property - def _has_mcp_tools(self) -> bool: - """Check if MCP servers are configured.""" - return config.mcp_servers is not None and len(config.mcp_servers.servers) > 0 - def _prepare_prompt( self, query: str, rag_retriever: Optional[BaseRetriever] = None, - history: Optional[MessageHistory] = None, + history: Optional[list[BaseMessage]] = None, ) -> tuple[ChatPromptTemplate, dict[str, str], list[RagChunk], bool]: """Summarize the given query based on the provided conversation context. @@ -274,12 +180,10 @@ def _prepare_prompt( ["sample"], [AIMessage("sample")], self._system_prompt, - self._has_mcp_tools, + self._tool_calling_enabled, ).generate_prompt(self.model) max_tokens_for_tools = ( - self.model_config.parameters.max_tokens_for_tools - if self._has_mcp_tools - else 0 + self.model_config.parameters.max_tokens_for_tools if self.mcp_servers else 0 ) available_tokens = token_handler.calculate_and_check_available_tokens( temp_prompt.format(**temp_prompt_input), @@ -324,7 +228,7 @@ def _prepare_prompt( rag_context, history, self._system_prompt, - self._has_mcp_tools, + self._tool_calling_enabled, ).generate_prompt(self.model) # Tokens-check: We trigger the computation of the token count @@ -343,7 +247,7 @@ async def _invoke_llm( self, messages: ChatPromptTemplate, llm_input_values: dict[str, str], - tools_map: ToolsList, + tools_map: list[StructuredTool], is_final_round: bool, token_counter: GenericTokenCounter, ) -> AsyncGenerator[AIMessageChunk, None]: @@ -391,76 +295,117 @@ async def _invoke_llm( time.monotonic() - llm_start_time, ) - async def iterate_with_tools( # noqa: C901 # pylint: disable=R0912,R0915 + def _resolve_tool_call_definitions( self, - messages: ChatPromptTemplate, - max_rounds: int, - llm_input_values: dict[str, str], - token_counter: GenericTokenCounter, - all_mcp_tools: ToolsList, - ) -> AsyncGenerator[StreamedChunk, None]: - """Iterate through multiple rounds of LLM invocation with tool calling. - - Args: - messages: The initial messages - max_rounds: Maximum number of tool calling rounds - llm_input_values: Input values for the LLM - token_counter: Counter for tracking token usage - all_mcp_tools: List of MCP tools to use for tool calling - - Yields: - StreamedChunk objects representing parts of the response - """ - async with asyncio.timeout(constants.TOOL_CALL_ROUND_TIMEOUT * max_rounds): - # Track cumulative token usage for tool outputs - tool_tokens_used = 0 - max_tokens_for_tools = self.model_config.parameters.max_tokens_for_tools - max_tokens_per_tool = ( - self.model_config.parameters.max_tokens_per_tool_output - ) - token_handler = TokenHandler() - rounds_used = 0 - stop_reason = "loop_completion" - - # Account for tool definitions tokens (schemas sent to LLM) - if all_mcp_tools: - tool_definitions_text = json.dumps( - [ - { - "name": t.name, - "description": t.description, - "schema": ( - t.args_schema - if isinstance(t.args_schema, dict) - else ( - t.args_schema.model_json_schema() - if t.args_schema is not None - else {} - ) - ), - } - for t in all_mcp_tools - ] + tool_calls: list[dict[str, object]], + all_tools_dict: dict[str, StructuredTool], + duplicate_tool_names: set[str], + ) -> tuple[ + list[ToolCallDefinition], + list[ToolMessage], + ]: + """Resolve LLM tool calls into executable definitions and skipped outcomes.""" + tool_call_definitions: list[ToolCallDefinition] = [] + skipped_tool_messages: list[ToolMessage] = [] + + # Resolve each LLM-emitted tool call to the execution triple: + # (tool_call_id, parsed_args, resolved StructuredTool). + for tool_call in tool_calls: + tool_name = tool_call.get("name") + tool_id = str(tool_call.get("id", "unknown")) + if not isinstance(tool_name, str): + skipped_tool_messages.append( + ToolMessage( + content=( + "Tool call skipped: missing or invalid tool name. " + "Do not retry this exact tool call." + ), + status="error", + tool_call_id=tool_id, + ) + ) + continue + if tool_name in duplicate_tool_names: + logger.error( + "Tool '%s' is ambiguous (duplicate name across servers)", + tool_name, ) - tool_definitions_tokens = TokenHandler._get_token_count( - token_handler.text_to_tokens(tool_definitions_text) + skipped_tool_messages.append( + ToolMessage( + content=( + f"Tool '{tool_name}' call skipped: ambiguous tool name " + "across servers. Do not retry this exact tool call." + ), + status="error", + tool_call_id=tool_id, + ) + ) + continue + resolved_tool = all_tools_dict.get(tool_name) + if resolved_tool is None: + logger.error("Tool '%s' was requested but is unavailable", tool_name) + skipped_tool_messages.append( + ToolMessage( + content=( + f"Tool '{tool_name}' call skipped: tool is unavailable. " + "Do not retry this exact tool call." + ), + status="error", + tool_call_id=tool_id, + ) ) - tool_tokens_used += tool_definitions_tokens - logger.debug( - "Tool definitions consume %d tokens", tool_definitions_tokens + continue + raw_args = tool_call.get("args", {}) + if raw_args is None: + tool_args: dict[str, object] = {} + elif isinstance(raw_args, dict): + tool_args = {str(key): value for key, value in raw_args.items()} + else: + logger.error( + "Tool '%s' requested with invalid args type '%s'; skipping call", + tool_name, + type(raw_args).__name__, ) + skipped_tool_messages.append( + ToolMessage( + content=( + f"Tool '{tool_name}' call skipped: invalid args type " + f"'{type(raw_args).__name__}'. Do not retry this exact tool call." + ), + status="error", + tool_call_id=tool_id, + ) + ) + continue + tool_call_definitions.append( + ( + tool_id, + tool_args, + resolved_tool, + ) + ) - # Tool calling in a loop - for i in range(1, max_rounds + 1): - rounds_used = i + return tool_call_definitions, skipped_tool_messages - is_final_round = (not all_mcp_tools) or (i == max_rounds) - logger.debug("Tool calling round %s (final: %s)", i, is_final_round) + async def _collect_round_llm_chunks( + self, + messages: ChatPromptTemplate, + llm_input_values: dict[str, str], + all_mcp_tools: list[StructuredTool], + is_final_round: bool, + token_counter: GenericTokenCounter, + round_index: int, + ) -> tuple[list[AIMessageChunk], list[StreamedChunk], bool]: + """Collect one round of LLM chunks and streamed text output. - tool_call_chunks = [] - chunk_counter = 0 - stop_generation = False - # invoke LLM and process response chunks + Returns: + Tuple of (tool_call_chunks, streamed_chunks, should_stop_iteration). + """ + tool_call_chunks: list[AIMessageChunk] = [] + streamed_chunks: list[StreamedChunk] = [] + chunk_counter = 0 + try: + async with asyncio.timeout(constants.TOOL_CALL_ROUND_TIMEOUT): async for chunk in self._invoke_llm( messages, llm_input_values, @@ -468,9 +413,6 @@ async def iterate_with_tools( # noqa: C901 # pylint: disable=R0912,R0915 is_final_round=is_final_round, token_counter=token_counter, ): - if stop_generation: - continue - # TODO: Temporary fix for fake-llm (load test) which gives # output as string. Currently every method that we use gives us # proper output, except fake-llm. We need to move to a different @@ -480,125 +422,393 @@ async def iterate_with_tools( # noqa: C901 # pylint: disable=R0912,R0915 # (load test can be run with tool calling set to False till we # have a permanent fix) if isinstance(chunk, str): - yield StreamedChunk(type="text", text=chunk) + streamed_chunks.append( + StreamedChunk(type=ChunkType.TEXT, text=chunk) + ) break if chunk.response_metadata.get("finish_reason") == "stop": # type: ignore [attr-defined] - stop_generation = True - continue + return tool_call_chunks, streamed_chunks, True - # collect tool chunk or yield text if getattr(chunk, "tool_call_chunks", None): tool_call_chunks.append(chunk) else: if not skip_special_chunk( chunk.content, chunk_counter, self.model, is_final_round ): - # stream text chunks directly - yield StreamedChunk(type="text", text=chunk.content) + streamed_chunks.append( + StreamedChunk(type=ChunkType.TEXT, text=chunk.content) + ) chunk_counter += 1 + except TimeoutError: + logger.error( + "Timed out waiting for LLM chunks in round %s after %s seconds", + round_index, + constants.TOOL_CALL_ROUND_TIMEOUT, + ) + streamed_chunks.append( + StreamedChunk( + type=ChunkType.TEXT, + text=( + "I could not complete this request in time. " + "Please try again." + ), + ) + ) + return tool_call_chunks, streamed_chunks, True - if stop_generation: - stop_reason = "finish_reason_stop" - logger.info( - "Tool loop completed: rounds_used=%d max_rounds=%d stop_reason=%s", - rounds_used, - max_rounds, - stop_reason, - ) - return + return tool_call_chunks, streamed_chunks, False - # exit if this was the final round - if is_final_round: - stop_reason = "final_round" - break - - # tool calling part - if tool_call_chunks: - # assess tool calls and add to messages - tool_calls = tool_calls_from_tool_calls_chunks(tool_call_chunks) - ai_tool_call_message = AIMessage( - content="", type="ai", tool_calls=tool_calls - ) - messages.append(ai_tool_call_message) + @staticmethod + def _enrich_with_tool_metadata( + data: dict[str, Any], + tool: Optional[StructuredTool], + ) -> None: + """Add MCP server metadata to a tool_call or tool_result dict in-place. - # Count tokens used by the AIMessage with tool calls - ai_message_tokens = TokenHandler._get_token_count( - token_handler.text_to_tokens(json.dumps(tool_calls)) - ) - tool_tokens_used += ai_message_tokens - - for tool_call in tool_calls: - enriched = _enrich_tool_call(tool_call, all_mcp_tools) - - logger.debug( - json.dumps( - { - "event": "tool_call", - "tool_name": enriched.get("name", "unknown"), - "arguments": enriched.get("args", {}), - "tool_id": enriched.get("id", "unknown"), - }, - ensure_ascii=False, - indent=2, - ) - ) + Adds ``server_name`` and ``tool_meta`` when available so the UI can + associate events with their originating MCP server and preload + resources (e.g. MCP Apps iframes). + """ + if tool is None: + return + tool_metadata = tool.metadata or {} + server_name = tool_metadata.get("mcp_server") + if server_name: + data["server_name"] = server_name + tool_meta = tool_metadata.get("_meta") + if tool_meta: + data["tool_meta"] = tool_meta - yield StreamedChunk(type="tool_call", data=enriched) + def _tool_result_chunk_for_message( + self, + *, + tool_call_message: ToolMessage, + tool_name: str, + tool: Optional[StructuredTool], + token_handler: TokenHandler, + round_index: int, + ) -> tuple[int, StreamedChunk]: + """Convert a ToolMessage into a streamed tool_result chunk. - # Calculate remaining budget for tools - remaining_tool_budget = max_tokens_for_tools - tool_tokens_used - # Use the smaller of per-tool limit or remaining budget - effective_per_tool_limit = min( - max_tokens_per_tool, remaining_tool_budget - ) + Returns: + A tuple of (token_count_for_tool_content, streamed_tool_result_chunk). + """ + content_tokens = token_handler.text_to_tokens(str(tool_call_message.content)) + content_token_count = len(content_tokens) - logger.debug( - "Tool budget: used=%d, remaining=%d, per_tool_limit=%d", - tool_tokens_used, - remaining_tool_budget, - effective_per_tool_limit, - ) + was_truncated = tool_call_message.additional_kwargs.get("truncated", False) + base_status = tool_call_message.status + tool_status = "truncated" if was_truncated else base_status + has_meta = bool( + isinstance(tool.metadata, dict) and tool.metadata.get("_meta") + if tool is not None + else False + ) - # execute tools and add to messages - tool_calls_messages = await execute_tool_calls( - tool_calls, - all_mcp_tools, - max( - effective_per_tool_limit, 100 - ), # Minimum 100 tokens per tool - ) - messages.extend(tool_calls_messages) + logger.info( + json.dumps( + { + "event": "tool_result", + "tool_id": tool_call_message.tool_call_id, + "tool_name": tool_name, + "status": tool_status, + "truncated": was_truncated, + "has_meta": has_meta, + "output_snippet": str(tool_call_message.content)[:1000], + }, + ensure_ascii=False, + indent=2, + ) + ) - # Track tokens used by tool outputs - for tool_call_message in tool_calls_messages: - content_tokens = token_handler.text_to_tokens( - str(tool_call_message.content) - ) - tool_tokens_used += len(content_tokens) + tool_result_data: dict[str, Any] = { + "id": tool_call_message.tool_call_id, + "name": tool_name, + "status": tool_status, + "content": tool_call_message.content, + "type": ChunkType.TOOL_RESULT.value, + "round": round_index, + } + structured_content = tool_call_message.additional_kwargs.get( + "structured_content" + ) + if structured_content: + tool_result_data["structured_content"] = structured_content + self._enrich_with_tool_metadata(tool_result_data, tool) - for result_chunk in _build_tool_result_chunks( - tool_calls, tool_calls_messages, all_mcp_tools, i - ): - yield result_chunk + return content_token_count, StreamedChunk( + type=ChunkType.TOOL_RESULT, data=tool_result_data + ) + + async def _process_tool_calls_for_round( + self, + *, + round_index: int, + tool_call_chunks: list[AIMessageChunk], + all_tools_dict: dict[str, StructuredTool], + duplicate_tool_names: set[str], + messages: ChatPromptTemplate, + token_handler: TokenHandler, + tool_token_usage: ToolTokenUsage, + max_tokens_for_tools: int, + max_tokens_per_tool: int, + ) -> AsyncGenerator[StreamedChunk, None]: + """Resolve, execute, and stream one round of tool calls.""" + tool_tokens_used = tool_token_usage.used + + # Finalize streamed chunks into complete tool calls. + tool_calls = tool_calls_from_tool_calls_chunks(tool_call_chunks) + tool_call_definitions, skipped_tool_messages = ( + self._resolve_tool_call_definitions( + tool_calls, + all_tools_dict, + duplicate_tool_names, + ) + ) + if not tool_call_definitions and not skipped_tool_messages: + logger.warning( + "No executable tools resolved from tool calls in round %s", round_index + ) + tool_token_usage.used = tool_tokens_used + return + + # Persist the AI tool-call message for the next LLM turn. + ai_tool_call_message = AIMessage(content="", type="ai", tool_calls=tool_calls) + messages.append(ai_tool_call_message) + + # Charge token budget for the assistant tool-call message itself, so + # subsequent per-tool limits are computed from the remaining budget. + ai_message_tokens = TokenHandler._get_token_count( + token_handler.text_to_tokens(json.dumps(tool_calls)) + ) + tool_tokens_used += ai_message_tokens + + # Build a mapping from tool_call_id -> tool_name for result enrichment. + tool_id_to_name: dict[str, str] = { + str(tc.get("id", "")): str(tc.get("name", "unknown")) for tc in tool_calls + } + # Log and emit tool-call intents enriched with MCP metadata so the UI + # can associate calls with their server and preload resources. + for tool_call in tool_calls: + enriched: dict[str, Any] = {**tool_call} + tool_name = str(tool_call.get("name", "unknown")) + self._enrich_with_tool_metadata(enriched, all_tools_dict.get(tool_name)) logger.info( - "Tool loop completed: rounds_used=%d max_rounds=%d stop_reason=%s", - rounds_used, - max_rounds, - stop_reason, + json.dumps( + { + "event": "tool_call", + "tool_name": tool_name, + "arguments": tool_call.get("args", {}), + "tool_id": tool_call.get("id", "unknown"), + }, + ensure_ascii=False, + indent=2, + ) ) + yield StreamedChunk(type=ChunkType.TOOL_CALL, data=enriched) - def _get_max_iterations(self) -> int: - """Return configured max rounds for tool-calling loop.""" - return config.ols_config.max_iterations + # Derive per-tool execution cap from remaining global tool budget so we + # do not exceed max_tokens_for_tools across this request. + remaining_tool_budget = max_tokens_for_tools - tool_tokens_used + effective_per_tool_limit = min(max_tokens_per_tool, remaining_tool_budget) + logger.debug( + "Tool budget: used=%d, remaining=%d, per_tool_limit=%d", + tool_tokens_used, + remaining_tool_budget, + effective_per_tool_limit, + ) + + tool_calls_messages: list[ToolMessage] = [] + # Execute resolved tool calls and consume streamed execution events + # (approval prompts + final tool results). + if tool_call_definitions: + # Enforce strict global tool budget. If model config uses a lower + # max per-tool cap, lower the minimum accordingly so tools are not + # permanently skipped due to configuration mismatch. + minimum_required_tokens = min( + MIN_TOOL_EXECUTION_TOKENS, max_tokens_per_tool + ) + if effective_per_tool_limit < minimum_required_tokens: + logger.warning( + "Skipping %d tool call(s) in round %s due to low remaining tool budget " + "(remaining=%d, minimum_required=%d)", + len(tool_call_definitions), + round_index, + remaining_tool_budget, + minimum_required_tokens, + ) + # Emit synthetic tool results for skipped executions so client/UI + # and conversation state remain consistent (one call -> one outcome). + for tool_id, _tool_args, tool in tool_call_definitions: + tool_calls_messages.append( + ToolMessage( + content=( + f"Tool '{tool.name}' call skipped: remaining tool token budget " + f"({remaining_tool_budget}) is below minimum required " + f"({minimum_required_tokens}). " + "Do not retry this exact tool call." + ), + status="error", + tool_call_id=tool_id, + ) + ) + else: + tool_calls_dicts = [ + {"name": tool.name, "args": tool_args, "id": tool_id} + for tool_id, tool_args, tool in tool_call_definitions + ] + tool_calls_messages = await execute_tool_calls( + tool_calls_dicts, + list(all_tools_dict.values()), + effective_per_tool_limit, + ) + + # Merge synthetic skipped outcomes with real execution outcomes and + # append all of them to conversation state for the next LLM turn. + all_tool_messages = skipped_tool_messages + tool_calls_messages + messages.extend(all_tool_messages) + + for tool_call_message in all_tool_messages: + tool_name = tool_id_to_name.get(tool_call_message.tool_call_id, "unknown") + content_token_count, tool_result_chunk = ( + self._tool_result_chunk_for_message( + tool_call_message=tool_call_message, + tool_name=tool_name, + tool=all_tools_dict.get(tool_name), + token_handler=token_handler, + round_index=round_index, + ) + ) + tool_tokens_used += content_token_count + yield tool_result_chunk + + tool_token_usage.used = tool_tokens_used + + async def iterate_with_tools( # noqa: C901 + self, + messages: ChatPromptTemplate, + max_rounds: int, + llm_input_values: dict[str, str], + token_counter: GenericTokenCounter, + all_mcp_tools: list[StructuredTool], + ) -> AsyncGenerator[StreamedChunk, None]: + """Iterate through multiple rounds of LLM invocation with tool calling. + + Args: + messages: The initial messages + max_rounds: Maximum number of tool calling rounds + llm_input_values: Input values for the LLM + token_counter: Counter for tracking token usage + all_mcp_tools: All resolved MCP tools available for the request. + + Yields: + StreamedChunk objects representing parts of the response + """ + all_tools_dict: dict[str, StructuredTool] = {} + duplicate_tool_names: set[str] = set() + # Build a stable name->tool map once per request and disable ambiguous + # duplicates so a tool name resolves to at most one executable tool. + for tool in all_mcp_tools: + if tool.name in all_tools_dict: + duplicate_tool_names.add(tool.name) + else: + all_tools_dict[tool.name] = tool + for tool_name in duplicate_tool_names: + all_tools_dict.pop(tool_name, None) + if duplicate_tool_names: + logger.error( + "Duplicate MCP tool names detected and disabled: %s", + sorted(duplicate_tool_names), + ) + + # Track cumulative token usage for tool outputs + tool_tokens_used = 0 + max_tokens_for_tools = self.model_config.parameters.max_tokens_for_tools + max_tokens_per_tool = self.model_config.parameters.max_tokens_per_tool_output + token_handler = TokenHandler() + + # Account for tool definitions tokens (schemas sent to LLM) + if all_mcp_tools: + tool_definitions_text = json.dumps( + [ + {"name": t.name, "description": t.description, "schema": t.args} + for t in all_mcp_tools + ] + ) + tool_definitions_tokens = TokenHandler._get_token_count( + token_handler.text_to_tokens(tool_definitions_text) + ) + tool_tokens_used += tool_definitions_tokens + logger.debug("Tool definitions consume %d tokens", tool_definitions_tokens) + + # Tool calling in a loop + for i in range(1, max_rounds + 1): + # Final round must produce only the assistant answer (no more tool calls), + # either because tools are disabled or we reached the max tool-call rounds. + is_final_round = (not all_mcp_tools) or (i == max_rounds) + logger.debug("Tool calling round %s (final: %s)", i, is_final_round) + + # Phase 1: collect one LLM round (text chunks + potential tool-call chunks). + tool_call_chunks, round_streamed_chunks, should_stop_iteration = ( + await self._collect_round_llm_chunks( + messages=messages, + llm_input_values=llm_input_values, + all_mcp_tools=all_mcp_tools, + is_final_round=is_final_round, + token_counter=token_counter, + round_index=i, + ) + ) + # Emit all text chunks produced during this LLM round. + for streamed_chunk in round_streamed_chunks: + yield streamed_chunk + # Stop immediately when helper indicates terminal condition + # (final answer reached or timeout fallback emitted). + if should_stop_iteration: + return + + # exit if this was the final round + if is_final_round: + break + + # tool calling part + if tool_call_chunks: + # Phase 2: resolve and execute tool calls for this round. + tool_token_usage = ToolTokenUsage(used=tool_tokens_used) + # No outer timeout here — each MCP server enforces its own + # configured timeout at the transport layer. + try: + async for streamed_chunk in self._process_tool_calls_for_round( + round_index=i, + tool_call_chunks=tool_call_chunks, + all_tools_dict=all_tools_dict, + duplicate_tool_names=duplicate_tool_names, + messages=messages, + token_handler=token_handler, + tool_token_usage=tool_token_usage, + max_tokens_for_tools=max_tokens_for_tools, + max_tokens_per_tool=max_tokens_per_tool, + ): + yield streamed_chunk + except Exception: + logger.exception("Error executing tool calls in round %s", i) + yield StreamedChunk( + type=ChunkType.TEXT, + text=( + "I could not complete this request. " "Please try again." + ), + ) + return + tool_tokens_used = tool_token_usage.used async def generate_response( self, query: str, rag_retriever: Optional[BaseRetriever] = None, - history: Optional[MessageHistory] = None, + history: Optional[list[BaseMessage]] = None, ) -> AsyncGenerator[StreamedChunk, None]: """Generate a response for the given query. @@ -614,10 +824,7 @@ async def generate_response( query, rag_retriever, history ) messages = final_prompt.model_copy() - - # Get all MCP tools (will handle tools_rag population and filtering) all_mcp_tools = await get_mcp_tools(query, self.user_token, self.client_headers) - with TokenMetricUpdater( llm=self.bare_llm, provider=self.provider_config.type, @@ -633,7 +840,7 @@ async def generate_response( yield response yield StreamedChunk( - type="end", + type=ChunkType.END, data={ "rag_chunks": rag_chunks, "truncated": truncated, @@ -641,48 +848,43 @@ async def generate_response( }, ) + def _get_max_iterations(self) -> int: + """Return configured max rounds for tool-calling loop.""" + return config.ols_config.max_iterations + def create_response( self, query: str, rag_retriever: Optional[BaseRetriever] = None, - history: Optional[MessageHistory] = None, + history: Optional[list[BaseMessage]] = None, ) -> SummarizerResponse: """Create a synchronous response for the given query. - This method wraps the asynchronous generate_response method to provide - a synchronous interface. - - Args: - query: The query to be answered - rag_retriever: Retriever for RAG context - history: Optional conversation history - - Returns: - A SummarizerResponse object containing the complete response + This method drains the async response stream and aggregates it into + a SummarizerResponse for non-streaming callers. """ async def drain_generate_response() -> SummarizerResponse: - """Inner async function to collect all response chunks.""" - chunks = [] - response_end: dict[str, Any] = {} - tool_calls = [] - tool_results = [] + """Collect all generated chunks into a single response object.""" + chunks: list[str] = [] + response_end: dict[str, object] = {} + tool_calls: list[dict[str, object]] = [] + tool_results: list[dict[str, object]] = [] async for chunk in self.generate_response(query, rag_retriever, history): - if chunk.type == "end": - response_end = chunk.data - break - if chunk.type == "tool_call": - tool_calls.append(chunk.data) - elif chunk.type == "tool_result": - tool_results.append(chunk.data) - elif chunk.type == "text": - chunks.append(chunk.text) - else: - # this "can't" happen as we control what chunk types - # are yielded in the generator directly - msg = f"Unknown chunk type: {chunk.type}" - logger.warning(msg) - raise ValueError(msg) + match chunk.type: + case ChunkType.END: + response_end = chunk.data + break + case ChunkType.TOOL_CALL: + tool_calls.append(chunk.data) + case ChunkType.TOOL_RESULT: + tool_results.append(chunk.data) + case ChunkType.TEXT: + chunks.append(chunk.text) + case _: + msg = f"Unknown chunk type: {chunk.type}" + logger.warning(msg) + raise ValueError(msg) return SummarizerResponse( response="".join(chunks), @@ -693,6 +895,4 @@ async def drain_generate_response() -> SummarizerResponse: tool_results=tool_results, ) - # TODO: if we define the non-streaming endpoint as async, we don't - # need to handle any of this, we would just await it return run_async_safely(drain_generate_response()) diff --git a/tests/unit/app/endpoints/test_streaming_ols.py b/tests/unit/app/endpoints/test_streaming_ols.py index 070f25c82..ba4a2fba1 100644 --- a/tests/unit/app/endpoints/test_streaming_ols.py +++ b/tests/unit/app/endpoints/test_streaming_ols.py @@ -5,14 +5,13 @@ import pytest from ols import config, constants +from ols.app.models.models import ChunkType # needs to be setup there before is_user_authorized is imported config.ols_config.authentication_config.module = "k8s" from ols.app.endpoints.streaming_ols import ( # noqa:E402 - LLM_TOKEN_EVENT, - LLM_TOOL_CALL_EVENT, - LLM_TOOL_RESULT_EVENT, + TOKEN_KEY_TOKEN, build_referenced_docs, format_stream_data, generic_llm_error, @@ -41,9 +40,9 @@ def _load_config(): def test_event_type_are_not_changed(): """Test that event types are not changed.""" - assert LLM_TOKEN_EVENT == "token" # noqa: S105 - assert LLM_TOOL_CALL_EVENT == "tool_call" - assert LLM_TOOL_RESULT_EVENT == "tool_result" + assert TOKEN_KEY_TOKEN == "token" # noqa: S105 + assert ChunkType.TOOL_CALL.value == "tool_call" + assert ChunkType.TOOL_RESULT.value == "tool_result" def test_format_stream_data(): @@ -59,27 +58,27 @@ def test_stream_event(): data = {"token": "hi", "idx": 1} # text output - assert stream_event(data, LLM_TOKEN_EVENT, constants.MEDIA_TYPE_TEXT) == "hi" + assert stream_event(data, TOKEN_KEY_TOKEN, constants.MEDIA_TYPE_TEXT) == "hi" assert ( - stream_event(data, LLM_TOOL_CALL_EVENT, constants.MEDIA_TYPE_TEXT) + stream_event(data, ChunkType.TOOL_CALL.value, constants.MEDIA_TYPE_TEXT) == '\nTool call: {"token": "hi", "idx": 1}\n' ) assert ( - stream_event(data, LLM_TOOL_RESULT_EVENT, constants.MEDIA_TYPE_TEXT) + stream_event(data, ChunkType.TOOL_RESULT.value, constants.MEDIA_TYPE_TEXT) == '\nTool result: {"token": "hi", "idx": 1}\n' ) # json output assert ( - stream_event(data, LLM_TOKEN_EVENT, constants.MEDIA_TYPE_JSON) + stream_event(data, TOKEN_KEY_TOKEN, constants.MEDIA_TYPE_JSON) == 'data: {"event": "token", "data": {"token": "hi", "idx": 1}}\n\n' ) assert ( - stream_event(data, LLM_TOOL_CALL_EVENT, constants.MEDIA_TYPE_JSON) + stream_event(data, ChunkType.TOOL_CALL.value, constants.MEDIA_TYPE_JSON) == 'data: {"event": "tool_call", "data": {"token": "hi", "idx": 1}}\n\n' ) assert ( - stream_event(data, LLM_TOOL_RESULT_EVENT, constants.MEDIA_TYPE_JSON) + stream_event(data, ChunkType.TOOL_RESULT.value, constants.MEDIA_TYPE_JSON) == 'data: {"event": "tool_result", "data": {"token": "hi", "idx": 1}}\n\n' ) diff --git a/tests/unit/query_helpers/test_docs_summarizer.py b/tests/unit/query_helpers/test_docs_summarizer.py index 69f5952a3..75245e7b3 100644 --- a/tests/unit/query_helpers/test_docs_summarizer.py +++ b/tests/unit/query_helpers/test_docs_summarizer.py @@ -1,54 +1,62 @@ -"""Unit tests for DocsSummarizer class.""" +"""Unit tests for DocsSummarizer PR2 class.""" -import json -import re -from math import ceil -from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch +import asyncio +import logging +from typing import ClassVar +from unittest.mock import ANY, AsyncMock, patch import pytest -from langchain_core.messages import HumanMessage +from langchain_core.messages import HumanMessage, ToolMessage from langchain_core.messages.ai import AIMessageChunk +from langchain_core.tools.structured import StructuredTool +from pydantic import BaseModel from ols import config -from ols.app.models.config import MCPServerConfig -from ols.constants import TOKEN_BUFFER_WEIGHT -from ols.utils.mcp_utils import _normalize_tool_schema, gather_mcp_tools -from ols.utils.token_handler import TokenHandler -from tests.mock_classes.mock_tools import ( - MOCK_TOOL_META, - NAMESPACES_OUTPUT, - POD_STRUCTURED_CONTENT, - mock_tools_map, - mock_tools_with_meta, - mock_tools_with_structured_content, -) +from ols.app.models.config import LoggingConfig, MCPServerConfig +from ols.app.models.models import ChunkType -# needs to be setup there before is_user_authorized is imported +# needs to be setup before importing docs_summarizer config.ols_config.authentication_config.module = "k8s" - - -from ols.src.query_helpers.docs_summarizer import ( # noqa:E402 +from ols.src.query_helpers.docs_summarizer import ( # noqa: E402 DocsSummarizer, QueryHelper, + ToolTokenUsage, ) -from ols.utils import suid # noqa:E402 -from tests import constants # noqa:E402 -from tests.mock_classes.mock_langchain_interface import ( # noqa:E402 +from ols.utils.logging_configurator import configure_logging # noqa: E402 +from ols.utils.mcp_utils import build_mcp_config, gather_mcp_tools # noqa: E402 +from ols.utils.token_handler import TokenHandler # noqa: E402 +from tests import constants # noqa: E402 +from tests.mock_classes.mock_langchain_interface import ( # noqa: E402 mock_langchain_interface, ) -from tests.mock_classes.mock_llm_loader import mock_llm_loader # noqa:E402 -from tests.mock_classes.mock_retrievers import MockRetriever # noqa:E402 +from tests.mock_classes.mock_llm_loader import mock_llm_loader # noqa: E402 +from tests.mock_classes.mock_retrievers import MockRetriever # noqa: E402 +from tests.mock_classes.mock_tools import mock_tools_map # noqa: E402 -conversation_id = suid.get_suid() +class SampleTool(StructuredTool): + """Simple structured tool used for targeted docs_summarizer tests.""" -def test_is_query_helper_subclass(): - """Test that DocsSummarizer is a subclass of QueryHelper.""" - assert issubclass(DocsSummarizer, QueryHelper) + def __init__(self, name: str, description: str = "sample tool") -> None: + """Initialize simple fake structured tool.""" + + class _Schema(BaseModel): + pass + async def _coro(**kwargs): + return "ok" -def check_summary_result(summary, question): - """Check result produced by DocsSummarizer.summary method.""" + super().__init__( + name=name, + description=description, + func=lambda **kwargs: "ok", + coroutine=_coro, + args_schema=_Schema, + ) + + +def check_summary_result(summary, question: str) -> None: + """Check result produced by DocsSummarizer.create_response method.""" assert question in summary.response assert isinstance(summary.rag_chunks, list) assert len(summary.rag_chunks) == 1 @@ -67,44 +75,43 @@ def _setup(): config.reload_from_yaml_file("tests/config/valid_config_without_mcp.yaml") +def test_is_query_helper_subclass(): + """Test that DocsSummarizer is a subclass of QueryHelper.""" + assert issubclass(DocsSummarizer, QueryHelper) + + def test_if_system_prompt_was_updated(): - """Test if system prompt was overided from the configuration.""" + """Test if system prompt was overridden from the configuration.""" summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - # expected prompt was loaded during configuration phase - expected_prompt = config.ols_config.system_prompt - assert summarizer._system_prompt == expected_prompt + assert summarizer._system_prompt == config.ols_config.system_prompt def test_summarize_empty_history(): - """Basic test for DocsSummarizer using mocked index and query engine.""" + """Basic test for DocsSummarizer using mocked retriever and empty history.""" with ( patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4), patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 1), ): summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) question = "What's the ultimate question with answer 42?" - rag_retriever = MockRetriever() - history = [] # empty history - summary = summarizer.create_response(question, rag_retriever, history) + summary = summarizer.create_response(question, MockRetriever(), []) check_summary_result(summary, question) def test_summarize_no_history(): - """Basic test for DocsSummarizer using mocked index and query engine, no history is provided.""" + """Basic test for DocsSummarizer without explicit history argument.""" with ( patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4), patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 3), ): summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) question = "What's the ultimate question with answer 42?" - rag_retriever = MockRetriever() - # no history is passed into summarize() method - summary = summarizer.create_response(question, rag_retriever) + summary = summarizer.create_response(question, MockRetriever()) check_summary_result(summary, question) def test_summarize_history_provided(): - """Basic test for DocsSummarizer using mocked index and query engine, history is provided.""" + """Basic test with explicit history vs default history paths.""" with ( patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4), patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 3), @@ -114,7 +121,6 @@ def test_summarize_history_provided(): history = ["human: What is Kubernetes?"] rag_retriever = MockRetriever() - # first call with history provided with patch( "ols.src.query_helpers.docs_summarizer.TokenHandler.limit_conversation_history", return_value=([], False), @@ -123,7 +129,6 @@ def test_summarize_history_provided(): token_handler.assert_called_once_with(history, ANY) check_summary_result(summary1, question) - # second call without history provided with patch( "ols.src.query_helpers.docs_summarizer.TokenHandler.limit_conversation_history", return_value=([], False), @@ -134,22 +139,19 @@ def test_summarize_history_provided(): def test_summarize_truncation(): - """Basic test for DocsSummarizer to check if truncation is done.""" + """Basic truncation check for very long history.""" with patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4): summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - question = "What's the ultimate question with answer 42?" - rag_retriever = MockRetriever() - - # too long history - history = [HumanMessage("What is Kubernetes?")] * 10000 - summary = summarizer.create_response(question, rag_retriever, history) - - # truncation should be done + summary = summarizer.create_response( + "What's the ultimate question with answer 42?", + MockRetriever(), + [HumanMessage("What is Kubernetes?")] * 10000, + ) assert summary.history_truncated def test_summarize_no_reference_content(): - """Basic test for DocsSummarizer using mocked index and query engine.""" + """Basic test when no retriever is provided.""" summarizer = DocsSummarizer( llm_loader=mock_llm_loader(mock_langchain_interface("test response")()) ) @@ -160,6 +162,24 @@ def test_summarize_no_reference_content(): assert not summary.history_truncated +def test_summarize_retrieval_logging(caplog): + """Basic test to ensure retrieval path is visible in logs.""" + logging_config = LoggingConfig(app_log_level="debug") + configure_logging(logging_config) + logger = logging.getLogger("ols") + logger.handlers = [caplog.handler] + + with ( + patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4), + patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 3), + ): + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + question = "What's the ultimate question with answer 42?" + summary = summarizer.create_response(question, MockRetriever()) + check_summary_result(summary, question) + assert "Retrieved 1 documents from indexes" in caplog.text + + @pytest.mark.asyncio async def test_response_generator(): """Test response generator method.""" @@ -167,10 +187,9 @@ async def test_response_generator(): llm_loader=mock_llm_loader(mock_langchain_interface("test response")()) ) question = "What's the ultimate question with answer 42?" - summary_gen = summarizer.generate_response(question) generated_content = "" - async for item in summary_gen: + async for item in summarizer.generate_response(question): generated_content += item.text assert generated_content == question @@ -184,8 +203,6 @@ async def async_mock_invoke(yield_values): def test_tool_calling_one_iteration(): """Test tool calling - stops after one iteration.""" - question = "How many namespaces are there in my cluster?" - with patch( "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" ) as mock_invoke: @@ -193,7 +210,8 @@ def test_tool_calling_one_iteration(): [AIMessageChunk(content="XYZ", response_metadata={"finish_reason": "stop"})] ) summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.create_response(question) + summarizer._tool_calling_enabled = True + summarizer.create_response("How many namespaces are there in my cluster?") assert mock_invoke.call_count == 1 @@ -221,59 +239,39 @@ def test_tool_calling_drains_chunks_after_stop(): async def fake_invoke_llm(*args, **kwargs): - """Fake invoke_llm function to simulate LLM behavior. - - Yields depends on the number of calls - """ - # use an attribute on the function to track calls + """Fake invoke_llm function to simulate two-turn LLM behavior.""" if not hasattr(fake_invoke_llm, "call_count"): fake_invoke_llm.call_count = 0 fake_invoke_llm.call_count += 1 if fake_invoke_llm.call_count == 1: - # first call yields a message that requests tool calls yield AIMessageChunk( content="", response_metadata={"finish_reason": "tool_calls"} ) elif fake_invoke_llm.call_count == 2: - # second call yields the final message. yield AIMessageChunk(content="XYZ", response_metadata={"finish_reason": "stop"}) - else: - # extra - yield AIMessageChunk( - content="Extra", response_metadata={"finish_reason": "extra"} - ) def test_tool_calling_two_iteration(): """Test tool calling - stops after two iterations.""" - question = "How many namespaces are there in my cluster?" - with ( patch( "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm", new=fake_invoke_llm, ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, + patch( + "ols.src.query_helpers.docs_summarizer.get_mcp_tools", + new=AsyncMock(return_value=mock_tools_map), + ), ): - # Mock config for get_mcp_tools - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [MagicMock()] # Non-empty list - - # Mock _gather_and_populate_tools to return tools - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=({"test": {}}, mock_tools_map)), - ): - summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.create_response(question) - assert mock_invoke.call_count == 2 + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + summarizer._tool_calling_enabled = True + summarizer.create_response("How many namespaces are there in my cluster?") + assert mock_invoke.call_count == 2 def test_tool_calling_force_stop(): - """Test tool calling - force stop.""" - question = "How many namespaces are there in my cluster?" - + """Test tool calling - force stop by max rounds.""" with ( patch( "ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations", @@ -282,105 +280,27 @@ def test_tool_calling_force_stop(): patch( "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, - ): - # Mock config for get_mcp_tools - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [MagicMock()] # Non-empty list - - # Mock _gather_and_populate_tools to return tools - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=({"test": {}}, mock_tools_map)), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="XYZ", response_metadata={"finish_reason": "tool_calls"} - ) - ] - ) - summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.create_response(question) - assert mock_invoke.call_count == 3 - - -def test_tool_calling_tool_execution(caplog): - """Test tool calling - tool execution.""" - caplog.set_level(10) # Set debug level - - question = "How many namespaces are there in my cluster?" - - mcp_servers_config = { - "test_server": { - "transport": "streamable_http", - "url": "http://test-server:8080/mcp", - }, - } - - with ( patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations", - return_value=2, + "ols.src.query_helpers.docs_summarizer.get_mcp_tools", + new=AsyncMock(return_value=mock_tools_map), ), - patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls, - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" - ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, ): - # Mock config for get_mcp_tools - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [MagicMock()] # Non-empty list - - # Mock _gather_and_populate_tools to return tools - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=(mcp_servers_config, mock_tools_map)), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="", - response_metadata={"finish_reason": "tool_calls"}, - tool_calls=[ - { - "name": "get_namespaces_mock", - "args": {}, - "id": "call_id1", - }, - { - "name": "invalid_function_name", - "args": {}, - "id": "call_id2", - }, - ], - ) - ] - ) - - # Create mock MCP client - now get_tools is called with server_name parameter - mock_mcp_client_instance = AsyncMock() - mock_mcp_client_instance.get_tools.return_value = mock_tools_map - mock_mcp_client_cls.return_value = mock_mcp_client_instance - + mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( + [ + AIMessageChunk( + content="XYZ", response_metadata={"finish_reason": "tool_calls"} + ) + ] + ) summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - # Disable token reservation for tools in this test (test config has small context window) - summarizer.model_config.parameters.max_tokens_for_tools = 0 - summarizer.create_response(question) - - assert "Tool: get_namespaces_mock" in caplog.text - assert f"Output: {NAMESPACES_OUTPUT}" in caplog.text - - assert "Error: Tool 'invalid_function_name' not found." in caplog.text - - assert mock_invoke.call_count == 2 - + summarizer._tool_calling_enabled = True + summarizer.create_response("How many namespaces are there in my cluster?") + assert mock_invoke.call_count == 3 -def test_tool_result_includes_structured_content(): - """Test that tool_result chunks include structured_content from artifact.""" - question = "What are pod metrics?" +def test_tool_calling_tool_execution(caplog): + """Test tool execution path with one valid and one invalid tool call.""" + caplog.set_level(10) mcp_servers_config = { "test_server": { "transport": "streamable_http", @@ -395,213 +315,51 @@ def test_tool_result_includes_structured_content(): ), patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls, patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" - ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, - ): - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [MagicMock()] - - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock( - return_value=(mcp_servers_config, mock_tools_with_structured_content) - ), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="", - response_metadata={"finish_reason": "tool_calls"}, - tool_calls=[ - { - "name": "get_pod_metrics_mock", - "args": {}, - "id": "call_pod1", - }, - ], - ) - ] - ) - - mock_mcp_client_instance = AsyncMock() - mock_mcp_client_instance.get_tools.return_value = ( - mock_tools_with_structured_content - ) - mock_mcp_client_cls.return_value = mock_mcp_client_instance - - summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.model_config.parameters.max_tokens_for_tools = 0 - result = summarizer.create_response(question) - - assert len(result.tool_results) == 1 - tool_result = result.tool_results[0] - - assert tool_result["id"] == "call_pod1" - assert tool_result["structured_content"] == POD_STRUCTURED_CONTENT - - -def test_tool_result_without_structured_content_has_no_key(): - """Test that tool_result omits structured_content when not present.""" - question = "How many namespaces are there in my cluster?" - - mcp_servers_config = { - "test_server": { - "transport": "streamable_http", - "url": "http://test-server:8080/mcp", - }, - } - - with ( - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations", - return_value=2, + "ols.src.query_helpers.docs_summarizer.TokenHandler.calculate_and_check_available_tokens", + return_value=1000, ), - patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls, patch( "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, - ): - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [MagicMock()] - - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=(mcp_servers_config, mock_tools_map)), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="", - response_metadata={"finish_reason": "tool_calls"}, - tool_calls=[ - { - "name": "get_namespaces_mock", - "args": {}, - "id": "call_ns1", - }, - ], - ) - ] - ) - - mock_mcp_client_instance = AsyncMock() - mock_mcp_client_instance.get_tools.return_value = mock_tools_map - mock_mcp_client_cls.return_value = mock_mcp_client_instance - - summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.model_config.parameters.max_tokens_for_tools = 0 - result = summarizer.create_response(question) - - assert len(result.tool_results) == 1 - tool_result = result.tool_results[0] - - assert tool_result["id"] == "call_ns1" - assert "structured_content" not in tool_result - - -def test_tool_token_tracking(caplog): - """Test that tool definitions and AIMessage tokens are tracked with buffer weight.""" - caplog.set_level(10) # Set debug level - - question = "How many namespaces are there in my cluster?" - - mcp_servers_config = { - "test_server": { - "transport": "streamable_http", - "url": "http://test-server:8080/mcp", - }, - } - - with ( patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations", - return_value=2, + "ols.src.query_helpers.docs_summarizer.build_mcp_config", + return_value=mcp_servers_config, ), - patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls, patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" - ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, + "ols.src.query_helpers.docs_summarizer.get_mcp_tools", + new=AsyncMock(return_value=mock_tools_map), + ), ): - # Mock config for get_mcp_tools - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [MagicMock()] # Non-empty list - - # Mock _gather_and_populate_tools to return tools - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=(mcp_servers_config, mock_tools_map)), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="", - response_metadata={"finish_reason": "tool_calls"}, - tool_calls=[ - { - "name": "get_namespaces_mock", - "args": {}, - "id": "call_id1", - }, - ], - ) - ] - ) + mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( + [ + AIMessageChunk( + content="", + response_metadata={"finish_reason": "tool_calls"}, + tool_calls=[ + {"name": "get_namespaces_mock", "args": {}, "id": "call_id1"}, + {"name": "invalid_function_name", "args": {}, "id": "call_id2"}, + ], + ) + ] + ) mock_mcp_client_instance = AsyncMock() mock_mcp_client_instance.get_tools.return_value = mock_tools_map mock_mcp_client_cls.return_value = mock_mcp_client_instance summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - # Disable token reservation for tools (test config has small context window) - summarizer.model_config.parameters.max_tokens_for_tools = 0 - summarizer.create_response(question) - - # Verify tool definitions token counting is logged - assert "Tool definitions consume" in caplog.text + summarizer.model_config.parameters.max_tokens_for_tools = 100 + summarizer.create_response("How many namespaces are there in my cluster?") - # Calculate expected token count with buffer weight applied - token_handler = TokenHandler() - tool_definitions_text = json.dumps( - [ - { - "name": t.name, - "description": t.description, - "schema": ( - t.args_schema - if isinstance(t.args_schema, dict) - else t.args_schema.model_json_schema() - ), - } - for t in mock_tools_map - ] - ) - raw_tokens = len(token_handler.text_to_tokens(tool_definitions_text)) - expected_buffered_tokens = ceil(raw_tokens * TOKEN_BUFFER_WEIGHT) - - # Extract logged token count and verify buffer weight was applied - match = re.search(r"Tool definitions consume (\d+) tokens", caplog.text) - assert match is not None, "Token count not found in logs" - logged_tokens = int(match.group(1)) - assert logged_tokens == expected_buffered_tokens, ( - f"Expected {expected_buffered_tokens} (raw={raw_tokens} * {TOKEN_BUFFER_WEIGHT}), " - f"got {logged_tokens}" - ) + assert "get_namespaces_mock" in caplog.text + assert "invalid_function_name" in caplog.text + assert mock_invoke.call_count == 2 @pytest.mark.asyncio async def test_gather_mcp_tools_failure_isolation(caplog): - """Test gather_mcp_tools isolates failures from individual MCP servers. - - When multiple MCP servers are configured and one is unreachable, - tools from the working servers should still be returned. - """ - from ols.utils.mcp_utils import gather_mcp_tools - + """Test gather_mcp_tools isolates failures from individual MCP servers.""" caplog.set_level(10) - mcp_servers = { "working_server": { "transport": "streamable_http", @@ -613,653 +371,400 @@ async def test_gather_mcp_tools_failure_isolation(caplog): }, } - # Mock MultiServerMCPClient.get_tools to simulate per-server behavior async def mock_get_tools(server_name=None): if server_name == "working_server": return mock_tools_map - elif server_name == "broken_server": - raise ConnectionError("Failed to connect to http://non-exist:8888/mcp") - return [] + raise ConnectionError("Failed to connect to http://non-exist:8888/mcp") with patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_client_cls: mock_client_instance = AsyncMock() mock_client_instance.get_tools.side_effect = mock_get_tools mock_client_cls.return_value = mock_client_instance - # Call gather_mcp_tools - should return tools from working server - # even though broken_server fails tools = await gather_mcp_tools(mcp_servers) - - # Verify we got tools from the working server assert len(tools) == 1 assert tools[0].name == "get_namespaces_mock" - - # Verify logging shows partial success assert "Loaded 1 tools from MCP server 'working_server'" in caplog.text assert "Failed to get tools from MCP server 'broken_server'" in caplog.text -@pytest.mark.asyncio -async def test_gather_mcp_tools_all_servers_working(caplog): - """Test gather_mcp_tools aggregates tools from all working servers.""" - from ols.utils.mcp_utils import gather_mcp_tools - - caplog.set_level(10) - - mcp_servers = { - "server_a": {"transport": "streamable_http", "url": "http://server-a:8080/mcp"}, - "server_b": {"transport": "streamable_http", "url": "http://server-b:8080/mcp"}, - } - - async def mock_get_tools(server_name=None): - # Both servers return tools successfully - return mock_tools_map - - with patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_client_cls: - mock_client_instance = AsyncMock() - mock_client_instance.get_tools.side_effect = mock_get_tools - mock_client_cls.return_value = mock_client_instance - - tools = await gather_mcp_tools(mcp_servers) - - # Should have tools from both servers (2 x 1 = 2 tools) - assert len(tools) == 2 - assert "Loaded 1 tools from MCP server 'server_a'" in caplog.text - assert "Loaded 1 tools from MCP server 'server_b'" in caplog.text - - -@pytest.mark.asyncio -async def test_gather_mcp_tools_all_servers_failing(caplog): - """Test gather_mcp_tools handles all servers failing gracefully.""" - from ols.utils.mcp_utils import gather_mcp_tools - - caplog.set_level(10) - - mcp_servers = { - "broken_a": {"transport": "streamable_http", "url": "http://broken-a:8888/mcp"}, - "broken_b": {"transport": "streamable_http", "url": "http://broken-b:8888/mcp"}, - } - - async def mock_get_tools(server_name=None): - raise ConnectionError(f"Failed to connect to {server_name}") - - with patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_client_cls: - mock_client_instance = AsyncMock() - mock_client_instance.get_tools.side_effect = mock_get_tools - mock_client_cls.return_value = mock_client_instance - - tools = await gather_mcp_tools(mcp_servers) - - # Should return empty list, not raise exception - assert tools == [] - assert "Failed to get tools from MCP server 'broken_a'" in caplog.text - assert "Failed to get tools from MCP server 'broken_b'" in caplog.text - - -@pytest.mark.asyncio -async def test_gather_mcp_tools_empty_config(): - """Test gather_mcp_tools with no servers configured.""" - from ols.utils.mcp_utils import gather_mcp_tools - - with patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_client_cls: - mock_client_instance = AsyncMock() - mock_client_cls.return_value = mock_client_instance - - tools = await gather_mcp_tools({}) - - # Should return empty list - assert tools == [] - # get_tools should never be called - mock_client_instance.get_tools.assert_not_called() - - -def test_normalize_tool_schema_adds_missing_properties_and_required(): - """Test _normalize_tool_schema patches a bare object schema.""" - tool = Mock() - tool.args_schema = {"type": "object"} - - _normalize_tool_schema(tool) - - assert tool.args_schema == {"type": "object", "properties": {}, "required": []} - - -def test_normalize_tool_schema_preserves_existing_properties(): - """Test _normalize_tool_schema does not overwrite existing properties.""" - tool = Mock() - tool.args_schema = { - "type": "object", - "properties": {"name": {"type": "string"}}, - "required": ["name"], - } - - _normalize_tool_schema(tool) - - assert tool.args_schema == { - "type": "object", - "properties": {"name": {"type": "string"}}, - "required": ["name"], - } - - -def test_normalize_tool_schema_skips_non_dict_schema(): - """Test _normalize_tool_schema is a no-op when args_schema is not a dict.""" - - class MySchema: - pass - - tool = Mock() - tool.args_schema = MySchema - - _normalize_tool_schema(tool) - - assert tool.args_schema is MySchema - - -def test_normalize_tool_schema_skips_non_object_type(): - """Test _normalize_tool_schema is a no-op for non-object schemas.""" - tool = Mock() - tool.args_schema = {"type": "string"} - - _normalize_tool_schema(tool) - - assert tool.args_schema == {"type": "string"} - - -@pytest.mark.asyncio -async def test_gather_mcp_tools_fixes_schemas_without_properties(): - """Test gather_mcp_tools patches tool schemas that lack 'properties'. - - MCP tools with no arguments produce schemas like {"type": "object"} - without a "properties" key. This causes KeyError in LangChain and - 400 errors from OpenAI. gather_mcp_tools must fix these schemas. - """ - mcp_servers = { - "server": {"transport": "streamable_http", "url": "http://server:8080/mcp"}, - } - - no_args_tool = Mock() - no_args_tool.args_schema = {"type": "object"} - no_args_tool.name = "no_args_tool" - no_args_tool.metadata = {} - - with_args_tool = Mock() - with_args_tool.args_schema = { - "type": "object", - "properties": {"query": {"type": "string"}}, - } - with_args_tool.name = "with_args_tool" - with_args_tool.metadata = {} - - async def mock_get_tools(server_name=None): - return [no_args_tool, with_args_tool] - - with patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_client_cls: - mock_client_instance = AsyncMock() - mock_client_instance.get_tools.side_effect = mock_get_tools - mock_client_cls.return_value = mock_client_instance - - tools = await gather_mcp_tools(mcp_servers) - - assert len(tools) == 2 - - assert tools[0].args_schema == { - "type": "object", - "properties": {}, - "required": [], - } - - assert tools[1].args_schema == { - "type": "object", - "properties": {"query": {"type": "string"}}, - "required": [], - } - - def test_build_mcp_config_transport_is_streamable_http(): """Test build_mcp_config sets transport to streamable_http for all servers.""" - from ols.utils.mcp_utils import build_mcp_config - server1 = MCPServerConfig(name="server1", url="http://server1:8080/mcp") server1._resolved_headers = {} - server2 = MCPServerConfig(name="server2", url="http://server2:9090/mcp", timeout=30) server2._resolved_headers = {} - mcp_config = build_mcp_config( - [server1, server2], user_token=None, client_headers=None - ) - - assert "server1" in mcp_config - assert "server2" in mcp_config + mcp_config = build_mcp_config([server1, server2], None, None) assert mcp_config["server1"]["transport"] == "streamable_http" - assert mcp_config["server1"]["url"] == "http://server1:8080/mcp" - assert mcp_config["server2"]["transport"] == "streamable_http" - assert mcp_config["server2"]["url"] == "http://server2:9090/mcp" - assert mcp_config["server2"]["timeout"] == 30 - - -def test_resolve_server_headers_with_client_placeholder(): - """Test resolve_server_headers replaces client placeholder with client headers.""" - from ols.constants import MCP_CLIENT_PLACEHOLDER - from ols.utils.mcp_utils import resolve_server_headers - - server = MCPServerConfig( - name="test-server", - url="http://test:8080/mcp", - headers={"Authorization": "_client_"}, - ) - server._resolved_headers = {"Authorization": MCP_CLIENT_PLACEHOLDER} - - client_headers = {"test-server": {"Authorization": "Bearer client-token"}} - - headers = resolve_server_headers( - server, user_token=None, client_headers=client_headers - ) - - assert headers is not None - assert headers == {"Authorization": "Bearer client-token"} - -def test_resolve_server_headers_with_kubernetes_placeholder(): - """Test resolve_server_headers replaces kubernetes placeholder with user token.""" - from ols.constants import MCP_KUBERNETES_PLACEHOLDER - from ols.utils.mcp_utils import resolve_server_headers - server = MCPServerConfig( - name="test-server", - url="http://test:8080/mcp", - headers={"Authorization": "kubernetes"}, - ) - server._resolved_headers = {"Authorization": MCP_KUBERNETES_PLACEHOLDER} - - headers = resolve_server_headers( - server, user_token="user-k8s-token", client_headers=None # noqa: S106 # nosec - ) - - assert headers is not None - assert headers == {"Authorization": "Bearer user-k8s-token"} - - -def test_resolve_server_headers_missing_client_headers(): - """Test resolve_server_headers returns None when client headers missing.""" - from ols.constants import MCP_CLIENT_PLACEHOLDER - from ols.utils.mcp_utils import resolve_server_headers - - server = MCPServerConfig( - name="test-server", - url="http://test:8080/mcp", - headers={"Authorization": "_client_"}, +def test_resolve_tool_call_definitions_targeted_paths(): + """Test targeted paths in _resolve_tool_call_definitions helper.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + all_tools_dict = {"get_namespaces_mock": mock_tools_map[0]} + duplicate_tool_names = {"dup_tool"} + tool_calls = [ + {"name": None, "args": {}, "id": "missing_name"}, + {"name": "dup_tool", "args": {}, "id": "duplicate"}, + {"name": "not_found", "args": {}, "id": "unavailable"}, + {"name": "get_namespaces_mock", "args": "bad", "id": "bad_args"}, + {"name": "get_namespaces_mock", "args": {"ok": True}, "id": "valid"}, + ] + + definitions, skipped = summarizer._resolve_tool_call_definitions( + tool_calls, all_tools_dict, duplicate_tool_names ) - server._resolved_headers = {"Authorization": MCP_CLIENT_PLACEHOLDER} - # No client headers provided - headers = resolve_server_headers(server, user_token=None, client_headers=None) + assert len(definitions) == 1 + assert definitions[0][0] == "valid" + assert definitions[0][1] == {"ok": True} + assert definitions[0][2] is mock_tools_map[0] - assert headers is None + assert len(skipped) == 4 + skipped_ids = {msg.tool_call_id for msg in skipped} + assert skipped_ids == {"missing_name", "duplicate", "unavailable", "bad_args"} -def test_resolve_server_headers_missing_kubernetes_token(): - """Test resolve_server_headers returns None when kubernetes token missing.""" - from ols.constants import MCP_KUBERNETES_PLACEHOLDER - from ols.utils.mcp_utils import resolve_server_headers +@pytest.mark.asyncio +async def test_collect_round_llm_chunks_targeted_paths(): + """Test _collect_round_llm_chunks returns chunks/text/stop as expected.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - server = MCPServerConfig( - name="test-server", - url="http://test:8080/mcp", - headers={"Authorization": "kubernetes"}, - ) - server._resolved_headers = {"Authorization": MCP_KUBERNETES_PLACEHOLDER} + async def _fake_invoke(*args, **kwargs): + yield AIMessageChunk(content="hello", response_metadata={}) + yield AIMessageChunk( + content="", + response_metadata={"finish_reason": "tool_calls"}, + tool_call_chunks=[ + {"name": "get_namespaces_mock", "args": "{}", "id": "call_1"} + ], + tool_calls=[{"name": "get_namespaces_mock", "args": {}, "id": "call_1"}], + ) - # No user token provided - headers = resolve_server_headers(server, user_token=None, client_headers=None) + with patch.object(summarizer, "_invoke_llm", side_effect=_fake_invoke): + tool_call_chunks, streamed_chunks, should_stop = ( + await summarizer._collect_round_llm_chunks( + messages=[], + llm_input_values={}, + all_mcp_tools=mock_tools_map, + is_final_round=False, + token_counter=AsyncMock(), + round_index=1, + ) + ) - assert headers is None + assert should_stop is False + assert len(streamed_chunks) == 1 + assert streamed_chunks[0].type == "text" + assert streamed_chunks[0].text == "hello" + assert len(tool_call_chunks) == 1 -def test_resolve_server_headers_with_multiple_client_header_dicts(): - """Test resolve_server_headers handles multiple headers in dict.""" - from ols.constants import MCP_CLIENT_PLACEHOLDER - from ols.utils.mcp_utils import resolve_server_headers +@pytest.mark.asyncio +async def test_collect_round_llm_chunks_timeout_without_any_chunks(): + """Test round timeout path when LLM yields nothing before timeout.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - server = MCPServerConfig( - name="test-server", - url="http://test:8080/mcp", - headers={"Authorization": "_client_", "X-Custom": "_client_"}, - ) - server._resolved_headers = { - "Authorization": MCP_CLIENT_PLACEHOLDER, - "X-Custom": MCP_CLIENT_PLACEHOLDER, - } + async def _slow_invoke(*args, **kwargs): + await asyncio.sleep(0.05) + if kwargs.get("_never_yield", False): + yield AIMessageChunk(content="", response_metadata={}) - client_headers = { - "test-server": { - "Authorization": "Bearer token", - "X-Custom": "custom-value", - } - } + with ( + patch( + "ols.src.query_helpers.docs_summarizer.constants.TOOL_CALL_ROUND_TIMEOUT", + 0.001, + ), + patch.object(summarizer, "_invoke_llm", side_effect=_slow_invoke), + ): + tool_call_chunks, streamed_chunks, should_stop = ( + await summarizer._collect_round_llm_chunks( + messages=[], + llm_input_values={}, + all_mcp_tools=mock_tools_map, + is_final_round=False, + token_counter=AsyncMock(), + round_index=1, + ) + ) - headers = resolve_server_headers( - server, user_token=None, client_headers=client_headers - ) + assert should_stop is True + assert tool_call_chunks == [] + assert len(streamed_chunks) == 1 + assert streamed_chunks[0].type == ChunkType.TEXT + assert "I could not complete this request in time." in streamed_chunks[0].text - assert headers is not None - assert headers == {"Authorization": "Bearer token", "X-Custom": "custom-value"} +@pytest.mark.asyncio +async def test_collect_round_llm_chunks_timeout_after_partial_text(): + """Test timeout still preserves already-streamed text before fallback.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) -def test_resolve_server_headers_client_does_not_override_static_config(): - """Test client headers don't override static server-configured headers.""" - from ols.utils.mcp_utils import resolve_server_headers + async def _partial_then_slow(*args, **kwargs): + yield AIMessageChunk(content="partial", response_metadata={}) + await asyncio.sleep(0.05) + if kwargs.get("_never_yield", False): + yield AIMessageChunk(content="", response_metadata={}) - server = MCPServerConfig( - name="test-server", - url="http://test:8080/mcp", - headers={"Authorization": "Bearer config-token"}, - ) - server._resolved_headers = {"Authorization": "Bearer config-token"} + with ( + patch( + "ols.src.query_helpers.docs_summarizer.constants.TOOL_CALL_ROUND_TIMEOUT", + 0.001, + ), + patch.object(summarizer, "_invoke_llm", side_effect=_partial_then_slow), + ): + tool_call_chunks, streamed_chunks, should_stop = ( + await summarizer._collect_round_llm_chunks( + messages=[], + llm_input_values={}, + all_mcp_tools=mock_tools_map, + is_final_round=False, + token_counter=AsyncMock(), + round_index=1, + ) + ) - # Client provides different authorization (should be ignored for non-placeholder) - client_headers = {"test-server": {"Authorization": "Bearer client-token"}} + assert should_stop is True + assert tool_call_chunks == [] + assert [chunk.type for chunk in streamed_chunks] == [ChunkType.TEXT, ChunkType.TEXT] + assert streamed_chunks[0].text == "partial" + assert "I could not complete this request in time." in streamed_chunks[1].text - headers = resolve_server_headers( - server, user_token=None, client_headers=client_headers - ) - assert headers is not None - # Config header should be used (not client) - assert headers == {"Authorization": "Bearer config-token"} +@pytest.mark.asyncio +async def test_collect_round_llm_chunks_stop_short_circuits_before_timeout(): + """Test finish_reason=stop returns immediately and does not emit timeout fallback.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + async def _stop_immediately(*args, **kwargs): + yield AIMessageChunk(content="", response_metadata={"finish_reason": "stop"}) -def test_resolve_server_headers_mixed_placeholders(): - """Test resolve_server_headers with mix of kubernetes and client placeholders.""" - from ols.constants import MCP_CLIENT_PLACEHOLDER, MCP_KUBERNETES_PLACEHOLDER - from ols.utils.mcp_utils import resolve_server_headers + with ( + patch( + "ols.src.query_helpers.docs_summarizer.constants.TOOL_CALL_ROUND_TIMEOUT", + 0.001, + ), + patch.object(summarizer, "_invoke_llm", side_effect=_stop_immediately), + ): + tool_call_chunks, streamed_chunks, should_stop = ( + await summarizer._collect_round_llm_chunks( + messages=[], + llm_input_values={}, + all_mcp_tools=mock_tools_map, + is_final_round=False, + token_counter=AsyncMock(), + round_index=1, + ) + ) - server = MCPServerConfig( - name="test-server", - url="http://test:8080/mcp", - headers={"Authorization": "kubernetes", "X-API-Key": "_client_"}, - ) - server._resolved_headers = { - "Authorization": MCP_KUBERNETES_PLACEHOLDER, - "X-API-Key": MCP_CLIENT_PLACEHOLDER, - } + assert should_stop is True + assert tool_call_chunks == [] + assert streamed_chunks == [] - client_headers = {"test-server": {"X-API-Key": "api-key-123"}} - headers = resolve_server_headers( - server, - user_token="k8s-token", # noqa: S106 # nosec - client_headers=client_headers, - ) +@pytest.mark.asyncio +async def test_collect_round_llm_chunks_handles_string_chunk(): + """Test fake-LLM compatibility path where chunk is plain string.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - assert headers is not None - assert headers == {"Authorization": "Bearer k8s-token", "X-API-Key": "api-key-123"} + async def _string_invoke(*args, **kwargs): + yield "plain-string-chunk" + + with patch.object(summarizer, "_invoke_llm", side_effect=_string_invoke): + tool_call_chunks, streamed_chunks, should_stop = ( + await summarizer._collect_round_llm_chunks( + messages=[], + llm_input_values={}, + all_mcp_tools=[], + is_final_round=False, + token_counter=AsyncMock(), + round_index=1, + ) + ) + assert should_stop is False + assert tool_call_chunks == [] + assert len(streamed_chunks) == 1 + assert streamed_chunks[0].type == "text" + assert streamed_chunks[0].text == "plain-string-chunk" -def test_resolve_server_headers_no_placeholders(): - """Test resolve_server_headers with direct header values (no placeholders).""" - from ols.utils.mcp_utils import resolve_server_headers - server = MCPServerConfig( - name="test-server", - url="http://test:8080/mcp", - headers={"Authorization": "Bearer static-token"}, +def test_resolve_tool_call_definitions_none_args_normalized_to_empty_dict(): + """Test that None tool args are normalized to {}.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + tool = mock_tools_map[0] + definitions, skipped = summarizer._resolve_tool_call_definitions( + [{"name": tool.name, "args": None, "id": "call_none"}], + {tool.name: tool}, + set(), ) - server._resolved_headers = {"Authorization": "Bearer static-token"} - - headers = resolve_server_headers(server, user_token=None, client_headers=None) - assert headers is not None - assert headers == {"Authorization": "Bearer static-token"} + assert skipped == [] + assert len(definitions) == 1 + assert definitions[0][0] == "call_none" + assert definitions[0][1] == {} + assert definitions[0][2] is tool -def test_tool_result_includes_tool_meta(): - """Test that tool_result includes tool_meta and server_name from tool metadata.""" - question = "How many namespaces are there in my cluster?" - - mcp_servers_config = { - "test-server": { - "transport": "streamable_http", - "url": "http://test-server:8080/mcp", - }, - } - - mock_server = MagicMock() - mock_server.name = "test-server" +@pytest.mark.asyncio +async def test_process_tool_calls_for_round_skipped_only_without_execution(): + """Test skipped-only path emits tool_result without calling executor.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + token_usage = ToolTokenUsage(used=0) + messages: list = [] + tool = mock_tools_map[0] + tool_call_chunks = [ + AIMessageChunk( + content="", + response_metadata={"finish_reason": "tool_calls"}, + tool_calls=[{"name": "missing_tool", "args": {}, "id": "skip_1"}], + ) + ] - with ( - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations", - return_value=2, - ), - patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls, - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" - ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, + with patch( + "ols.src.query_helpers.docs_summarizer.execute_tool_calls", + new=AsyncMock(side_effect=AssertionError("executor should not be called")), ): - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [mock_server] - - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=(mcp_servers_config, mock_tools_with_meta)), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="", - response_metadata={"finish_reason": "tool_calls"}, - tool_calls=[ - { - "name": "get_namespaces_with_meta_mock", - "args": {}, - "id": "call_meta1", - }, - ], - ) - ] + streamed = [ + chunk + async for chunk in summarizer._process_tool_calls_for_round( + round_index=1, + tool_call_chunks=tool_call_chunks, + all_tools_dict={tool.name: tool}, + duplicate_tool_names=set(), + messages=messages, + token_handler=TokenHandler(), + tool_token_usage=token_usage, + max_tokens_for_tools=1000, + max_tokens_per_tool=200, ) + ] - mock_mcp_client_instance = AsyncMock() - mock_mcp_client_instance.get_tools.return_value = mock_tools_with_meta - mock_mcp_client_cls.return_value = mock_mcp_client_instance - - summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.model_config.parameters.max_tokens_for_tools = 0 - result = summarizer.create_response(question) - - assert len(result.tool_results) == 1 - tool_result = result.tool_results[0] - - assert tool_result["id"] == "call_meta1" - assert tool_result["name"] == "get_namespaces_with_meta_mock" - assert tool_result["server_name"] == "test-server" - assert tool_result["tool_meta"] == MOCK_TOOL_META + assert [chunk.type for chunk in streamed] == [ + ChunkType.TOOL_CALL, + ChunkType.TOOL_RESULT, + ] + assert streamed[1].data["type"] == "tool_result" + assert "tool is unavailable" in streamed[1].data["content"] + assert len(messages) == 2 -def test_tool_result_without_meta_has_no_tool_meta_key(): - """Test that tool_result omits tool_meta when tool has no _meta in metadata.""" - question = "How many namespaces are there in my cluster?" - - mcp_servers_config = { - "test_server": { - "transport": "streamable_http", - "url": "http://test-server:8080/mcp", - }, - } +@pytest.mark.asyncio +async def test_iterate_with_tools_deduplicates_tool_names(caplog): + """Test duplicate MCP tool names are disabled and logged.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + summarizer._tool_calling_enabled = False + caplog.set_level(logging.ERROR) + tools = [SampleTool("dup"), SampleTool("dup")] with ( - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations", - return_value=2, + patch.object( + summarizer, + "_collect_round_llm_chunks", + new=AsyncMock(return_value=([], [], True)), ), - patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls, - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" - ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, ): - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [MagicMock()] - - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=(mcp_servers_config, mock_tools_map)), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="", - response_metadata={"finish_reason": "tool_calls"}, - tool_calls=[ - { - "name": "get_namespaces_mock", - "args": {}, - "id": "call_ns1", - }, - ], - ) - ] + chunks = [ + chunk + async for chunk in summarizer.iterate_with_tools( + messages=[], + max_rounds=1, + llm_input_values={}, + token_counter=AsyncMock(), + all_mcp_tools=tools, ) + ] - mock_mcp_client_instance = AsyncMock() - mock_mcp_client_instance.get_tools.return_value = mock_tools_map - mock_mcp_client_cls.return_value = mock_mcp_client_instance + assert chunks == [] + assert "Duplicate MCP tool names detected and disabled" in caplog.text - summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.model_config.parameters.max_tokens_for_tools = 0 - result = summarizer.create_response(question) - assert len(result.tool_results) == 1 - tool_result = result.tool_results[0] +def test_tool_result_chunk_for_message_preserves_metadata_and_logs_has_meta(caplog): + """Test tool result chunk contains metadata enrichment and has_meta logging.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + caplog.set_level(logging.INFO) + tool = mock_tools_map[0] + tool.metadata = {"mcp_server": "server-a", "_meta": {"app": "ui"}} + message = ToolMessage( + content="ok", + status="success", + tool_call_id="call_meta", + additional_kwargs={"truncated": False}, + ) - assert tool_result["id"] == "call_ns1" - assert tool_result["name"] == "get_namespaces_mock" - assert "tool_meta" not in tool_result + _, chunk = summarizer._tool_result_chunk_for_message( + tool_call_message=message, + tool_name=tool.name, + tool=tool, + token_handler=TokenHandler(), + round_index=1, + ) + assert chunk.type == ChunkType.TOOL_RESULT + assert chunk.data["server_name"] == "server-a" + assert chunk.data["tool_meta"] == {"app": "ui"} + assert '"has_meta": true' in caplog.text -def test_tool_call_includes_tool_meta(): - """Test that tool_call events include tool_meta and server_name from metadata.""" - question = "How many namespaces are there in my cluster?" - mcp_servers_config = { - "test-server": { - "transport": "streamable_http", - "url": "http://test-server:8080/mcp", - }, - } +@pytest.mark.asyncio +async def test_iterate_with_tools_handles_tool_execution_error(): + """Test iterate_with_tools emits fallback when tool execution raises.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + tool = mock_tools_map[0] + tool_call_chunks = [ + AIMessageChunk( + content="", + response_metadata={"finish_reason": "tool_calls"}, + tool_calls=[{"name": tool.name, "args": {}, "id": "call_error"}], + ) + ] - mock_server = MagicMock() - mock_server.name = "test-server" + async def _failing_process(*args, **kwargs): + if False: + yield + raise RuntimeError("MCP server unreachable") with ( - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations", - return_value=2, + patch.object( + summarizer, + "_collect_round_llm_chunks", + new=AsyncMock(return_value=(tool_call_chunks, [], False)), + ), + patch.object( + summarizer, "_process_tool_calls_for_round", side_effect=_failing_process ), - patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls, - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" - ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, ): - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [mock_server] - - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=(mcp_servers_config, mock_tools_with_meta)), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="", - response_metadata={"finish_reason": "tool_calls"}, - tool_calls=[ - { - "name": "get_namespaces_with_meta_mock", - "args": {}, - "id": "call_meta1", - }, - ], - ) - ] + chunks = [ + chunk + async for chunk in summarizer.iterate_with_tools( + messages=[], + max_rounds=2, + llm_input_values={}, + token_counter=AsyncMock(), + all_mcp_tools=[tool], ) + ] - mock_mcp_client_instance = AsyncMock() - mock_mcp_client_instance.get_tools.return_value = mock_tools_with_meta - mock_mcp_client_cls.return_value = mock_mcp_client_instance - - summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.model_config.parameters.max_tokens_for_tools = 0 - result = summarizer.create_response(question) - - assert len(result.tool_calls) == 1 - tool_call = result.tool_calls[0] - - assert tool_call["name"] == "get_namespaces_with_meta_mock" - assert tool_call["server_name"] == "test-server" - assert tool_call["tool_meta"] == MOCK_TOOL_META - + assert len(chunks) == 1 + assert chunks[0].type == ChunkType.TEXT + assert "I could not complete this request." in chunks[0].text -def test_tool_call_without_meta_has_no_tool_meta_key(): - """Test that tool_call events omit tool_meta when tool has no _meta.""" - question = "How many namespaces are there in my cluster?" - - mcp_servers_config = { - "test_server": { - "transport": "streamable_http", - "url": "http://test-server:8080/mcp", - }, - } - with ( - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._get_max_iterations", - return_value=2, - ), - patch("ols.utils.mcp_utils.MultiServerMCPClient") as mock_mcp_client_cls, - patch( - "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" - ) as mock_invoke, - patch("ols.utils.mcp_utils.config") as mock_config, - ): - mock_config.tools_rag = None - mock_config.mcp_servers.servers = [MagicMock()] - - with patch( - "ols.utils.mcp_utils._gather_and_populate_tools", - new=AsyncMock(return_value=(mcp_servers_config, mock_tools_map)), - ): - mock_invoke.side_effect = lambda *args, **kwargs: async_mock_invoke( - [ - AIMessageChunk( - content="", - response_metadata={"finish_reason": "tool_calls"}, - tool_calls=[ - { - "name": "get_namespaces_mock", - "args": {}, - "id": "call_ns1", - }, - ], - ) - ] - ) - - mock_mcp_client_instance = AsyncMock() - mock_mcp_client_instance.get_tools.return_value = mock_tools_map - mock_mcp_client_cls.return_value = mock_mcp_client_instance +def test_create_response_raises_on_unknown_chunk_type(): + """Test create_response raises ValueError on unsupported chunk type.""" + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) - summarizer.model_config.parameters.max_tokens_for_tools = 0 - result = summarizer.create_response(question) + class UnknownChunk: + type = "unsupported" + text = "" + data: ClassVar[dict[str, str]] = {} - assert len(result.tool_calls) == 1 - tool_call = result.tool_calls[0] + async def _fake_generate(self, *args, **kwargs): + yield UnknownChunk() - assert tool_call["name"] == "get_namespaces_mock" - assert "tool_meta" not in tool_call + with patch.object(DocsSummarizer, "generate_response", _fake_generate): + with pytest.raises(ValueError, match="Unknown chunk type"): + summarizer.create_response("q")