Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion ols/app/endpoints/streaming_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def stream_start_event(conversation_id: str) -> str:
)


def stream_event(data: dict[str, object], event_type: str, media_type: str) -> str:
def stream_event( # pylint: disable=R0911
data: dict[str, object], event_type: str, media_type: str
) -> str:
"""Build an item to yield based on media type.

Args:
Expand All @@ -168,6 +170,8 @@ def stream_event(data: dict[str, object], event_type: str, media_type: str) -> s
text_output = str(data["reasoning"])
case "tool_call":
text_output = f"\nTool call: {json.dumps(data)}\n"
case "approval_required":
text_output = f"\nApproval request: {json.dumps(data)}\n"
case "tool_result":
text_output = f"\nTool result: {json.dumps(data)}\n"
case "history_compression_start":
Expand Down Expand Up @@ -402,6 +406,12 @@ async def response_processing_wrapper( # noqa: C901 # pylint: disable=R0912,R0
event_type=LLM_TOOL_CALL_EVENT,
media_type=media_type,
)
case StreamChunkType.APPROVAL_REQUIRED:
yield stream_event(
data=item.data,
event_type=StreamChunkType.APPROVAL_REQUIRED.value,
media_type=media_type,
)
case StreamChunkType.TOOL_RESULT:
tool_results.append(item.data)
yield stream_event(
Expand Down
1 change: 1 addition & 0 deletions ols/app/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ class StreamChunkType(StrEnum):

TEXT = "text"
TOOL_CALL = "tool_call"
APPROVAL_REQUIRED = "approval_required"
TOOL_RESULT = "tool_result"
HISTORY_COMPRESSION_START = "history_compression_start"
HISTORY_COMPRESSION_END = "history_compression_end"
Expand Down
29 changes: 19 additions & 10 deletions ols/src/query_helpers/docs_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ols.src.prompts.prompt_generator import GeneratePrompt
from ols.src.query_helpers.history_support import prepare_history
from ols.src.query_helpers.query_helper import QueryHelper
from ols.src.tools.tools import enforce_tool_token_budget, execute_tool_calls
from ols.src.tools.tools import enforce_tool_token_budget, execute_tool_calls_stream
from ols.utils.mcp_utils import ClientHeaders, build_mcp_config, get_mcp_tools
from ols.utils.token_handler import TokenHandler

Expand Down Expand Up @@ -634,7 +634,7 @@ def _tool_result_chunk_for_message(
type=StreamChunkType.TOOL_RESULT, data=tool_result_data
)

async def _process_tool_calls_for_round(
async def _process_tool_calls_for_round( # noqa: C901 # pylint: disable=R0912
self,
*,
round_index: int,
Expand Down Expand Up @@ -757,15 +757,24 @@ async def _process_tool_calls_for_round(
)
)
else:
tool_calls_dicts = [
{"name": tool.name, "args": tool_args, "id": tool_id}
for tool_id, tool_args, tool in tool_call_definitions
]
tool_calls_messages = await execute_tool_calls(
tool_calls_dicts,
list(all_tools_dict.values()),
async for execution_event in execute_tool_calls_stream(
tool_call_definitions,
remaining_tool_budget,
)
streaming=self.streaming,
):
match execution_event.event:
case StreamChunkType.APPROVAL_REQUIRED:
yield StreamedChunk(
type=StreamChunkType.APPROVAL_REQUIRED,
data=execution_event.data,
)
case StreamChunkType.TOOL_RESULT:
tool_calls_messages.append(execution_event.data)
case _:
logger.warning(
"Ignoring unexpected tool execution event: %s",
execution_event,
)

# Merge synthetic skipped outcomes with real execution outcomes and
# append all of them to conversation state for the next LLM turn.
Expand Down
Loading