Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions python/packages/kagent-adk/src/kagent/adk/models/_litellm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,59 @@
from __future__ import annotations

from typing import Optional
import logging
import re
from typing import AsyncGenerator, Optional

from google.adk.models.lite_llm import LiteLlm
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse

# Module-level logger; sanitization events are reported at DEBUG (see below).
logger = logging.getLogger(__name__)

# Bedrock requires tool names to match [a-zA-Z0-9_-]+ with length >= 1
_BEDROCK_TOOL_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
# Base used when a tool name is empty (or sanitizes to nothing); callers append
# a positional suffix so fallback names stay unique within a request/response.
_BEDROCK_TOOL_NAME_FALLBACK = "unknown_tool"


def _is_bedrock_model(model: str) -> bool:
return "bedrock" in model.lower()


def _sanitize_tool_name(name: str, idx: "int | str") -> str:
    """Return a Bedrock-safe tool name; replace invalid/empty names with a fallback.

    Valid names pass through untouched. Invalid characters become underscores;
    an empty result falls back to ``unknown_tool_<idx>``.
    """
    if name and _BEDROCK_TOOL_NAME_RE.match(name):
        return name
    sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name) if name else ""
    if not sanitized:
        sanitized = f"{_BEDROCK_TOOL_NAME_FALLBACK}_{idx}"
    logger.debug("Sanitized invalid Bedrock tool name %r -> %r", name, sanitized)
    return sanitized


def _sanitize_llm_request(llm_request: LlmRequest) -> None:
    """Fix tool names in the conversation history (in place) before sending to Bedrock."""
    for content in llm_request.contents:
        parts = content.parts
        if not parts:
            continue
        for position, part in enumerate(parts):
            call = getattr(part, "function_call", None)
            if call is None or not hasattr(call, "name"):
                continue
            call.name = _sanitize_tool_name(call.name or "", position)


def _sanitize_llm_response(response: LlmResponse, idx: int) -> LlmResponse:
    """Fix tool names in a model response before the ADK stores it in history.

    Returns the same ``response`` object, mutated in place when needed.
    """
    content = response.content
    if not (content and content.parts):
        return response
    for part_index, part in enumerate(content.parts):
        call = getattr(part, "function_call", None)
        if call is not None and hasattr(call, "name"):
            # "<response idx>_<part idx>" suffix keeps fallback names unique
            # across streamed responses as well as parts within one response.
            call.name = _sanitize_tool_name(call.name or "", f"{idx}_{part_index}")
    return response


class KAgentLiteLlm(LiteLlm):
    """LiteLlm subclass that supports API key passthrough and Bedrock tool name sanitization."""

    # NOTE(review): presumably toggles per-request api-key forwarding (see
    # set_passthrough_key) — the flag's read site is not visible here; confirm.
    api_key_passthrough: Optional[bool] = None

Expand All @@ -17,3 +64,17 @@ def __init__(self, model: str, **kwargs):

def set_passthrough_key(self, token: str) -> None:
self._additional_args["api_key"] = token

async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
effective_model = llm_request.model or self.model
if _is_bedrock_model(effective_model):
_sanitize_llm_request(llm_request)
idx = 0
async for response in super().generate_content_async(llm_request, stream=stream):
yield _sanitize_llm_response(response, idx)
idx += 1
Comment on lines +68 to +77
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tool-name sanitization logs at WARNING every time an invalid name is seen. In streaming mode, generate_content_async yields many chunks, so the same invalid tool name could generate a large number of warning log lines (one per chunk/part), which can be noisy and increase log volume. Consider de-duplicating warnings per request/response (e.g., track already-sanitized originals) or lowering to INFO/DEBUG if this is expected in normal operation.

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this, I think DEBUG if at all.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed — changed to logger.debug in bc1b190.

else:
async for response in super().generate_content_async(llm_request, stream=stream):
yield response
Comment on lines +68 to +80
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change introduces Bedrock-specific mutation of LlmRequest/LlmResponse tool-call names, but there are no unit tests exercising the new sanitization logic. Adding focused tests (e.g., construct an LlmRequest with function_call parts containing invalid names and assert they are sanitized when model is bedrock/..., and unchanged otherwise) would help prevent regressions.

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, we need unit tests

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added unit tests in tests/unittests/models/test_litellm_bedrock.py (15 tests, all passing locally). Covers sanitize helpers, Bedrock vs non-Bedrock routing, no-collision across parts, and log-level assertion.

Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Unit tests for Bedrock tool-name sanitization in KAgentLiteLlm."""
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai.types import Content, FunctionCall, Part

from kagent.adk.models._litellm import (
_sanitize_llm_request,
_sanitize_llm_response,
_sanitize_tool_name,
)


# ---------------------------------------------------------------------------
# _sanitize_tool_name
# ---------------------------------------------------------------------------


def test_sanitize_tool_name_valid_unchanged():
    """Names already matching Bedrock's [a-zA-Z0-9_-]+ pattern pass through."""
    name = "valid_tool-name1"
    assert _sanitize_tool_name(name, 0) == name


def test_sanitize_tool_name_replaces_dots():
    """Dots are invalid for Bedrock and are replaced with underscores."""
    assert _sanitize_tool_name("tool.name", 0) == "tool_name"


def test_sanitize_tool_name_replaces_spaces():
    """Spaces are invalid for Bedrock and are replaced with underscores."""
    assert _sanitize_tool_name("my tool", 0) == "my_tool"


def test_sanitize_tool_name_empty_uses_fallback():
    """An empty name falls back to unknown_tool_<idx>."""
    assert _sanitize_tool_name("", 3) == "unknown_tool_3"


def test_sanitize_tool_name_composite_suffix():
    """A composite string index is embedded verbatim in the fallback name."""
    assert _sanitize_tool_name("", "2_5") == "unknown_tool_2_5"


def test_sanitize_tool_name_logs_at_debug(caplog):
    """Sanitization must be reported at DEBUG level only, never louder."""
    import logging

    with caplog.at_level(logging.DEBUG, logger="kagent.adk.models._litellm"):
        _sanitize_tool_name("bad.name", 0)
    sanitize_records = [r for r in caplog.records if "Sanitized" in r.message]
    assert any("Sanitized invalid Bedrock tool name" in r.message for r in caplog.records)
    assert all(r.levelname == "DEBUG" for r in sanitize_records)


# ---------------------------------------------------------------------------
# _sanitize_llm_request
# ---------------------------------------------------------------------------


def _make_request_with_function_call(name: str) -> LlmRequest:
    """Build an LlmRequest whose history contains one function call named *name*."""
    request = LlmRequest()
    request.contents = [
        Content(role="model", parts=[Part(function_call=FunctionCall(name=name, args={}))])
    ]
    return request


def test_sanitize_llm_request_fixes_invalid_name():
    """Invalid tool names in request history are rewritten in place."""
    request = _make_request_with_function_call("bad.tool.name")
    _sanitize_llm_request(request)
    sanitized = request.contents[0].parts[0].function_call.name
    assert sanitized == "bad_tool_name"


def test_sanitize_llm_request_leaves_valid_name():
    """Already-valid tool names are left untouched."""
    request = _make_request_with_function_call("good_tool")
    _sanitize_llm_request(request)
    assert request.contents[0].parts[0].function_call.name == "good_tool"


def test_sanitize_llm_request_no_parts_no_error():
    """Content with an empty parts list is skipped without raising."""
    request = LlmRequest()
    request.contents = [Content(parts=[], role="model")]
    _sanitize_llm_request(request)  # should not raise


# ---------------------------------------------------------------------------
# _sanitize_llm_response
# ---------------------------------------------------------------------------


def _make_response_with_function_call(name: str) -> LlmResponse:
    """Build an LlmResponse containing a single function call named *name*."""
    response = LlmResponse()
    response.content = Content(
        role="model", parts=[Part(function_call=FunctionCall(name=name, args={}))]
    )
    return response


def test_sanitize_llm_response_fixes_invalid_name():
    """Invalid tool names in a response are rewritten."""
    response = _sanitize_llm_response(_make_response_with_function_call("bad.tool"), 0)
    assert response.content.parts[0].function_call.name == "bad_tool"


def test_sanitize_llm_response_leaves_valid_name():
    """Already-valid tool names survive sanitization unchanged."""
    response = _sanitize_llm_response(_make_response_with_function_call("valid_tool"), 0)
    assert response.content.parts[0].function_call.name == "valid_tool"


def test_sanitize_llm_response_no_collision_across_parts():
    """Fallback names carry an idx_part suffix, so parts never collide."""
    parts = [Part(function_call=FunctionCall(name="", args={})) for _ in range(2)]
    response = LlmResponse()
    response.content = Content(parts=parts, role="model")
    _sanitize_llm_response(response, 1)
    assert [p.function_call.name for p in response.content.parts] == [
        "unknown_tool_1_0",
        "unknown_tool_1_1",
    ]


def test_sanitize_llm_response_no_content_no_error():
    """A response without content passes through without raising."""
    response = LlmResponse()
    response.content = None
    _sanitize_llm_response(response, 0)  # should not raise


# ---------------------------------------------------------------------------
# KAgentLiteLlm.generate_content_async — Bedrock vs non-Bedrock routing
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_generate_content_async_bedrock_sanitizes():
    """Bedrock models sanitize both the request history and streamed responses.

    The unused ``monkeypatch`` fixture parameter was removed; patching is done
    via ``unittest.mock.patch.object`` on the parent class.
    """
    from kagent.adk.models._litellm import KAgentLiteLlm

    model = KAgentLiteLlm(model="bedrock/anthropic.claude-3-sonnet")

    req = _make_request_with_function_call("bad.name")
    req.model = "bedrock/anthropic.claude-3-sonnet"

    resp = _make_response_with_function_call("bad.response.name")

    async def fake_super(*args, **kwargs):
        yield resp

    with patch.object(
        KAgentLiteLlm.__bases__[0], "generate_content_async", return_value=fake_super()
    ):
        results = [r async for r in model.generate_content_async(req)]

    # Request history is sanitized in place; the yielded response name is fixed too.
    assert req.contents[0].parts[0].function_call.name == "bad_name"
    assert results[0].content.parts[0].function_call.name == "bad_response_name"


@pytest.mark.asyncio
async def test_generate_content_async_non_bedrock_no_sanitization():
    """Non-Bedrock models must pass requests and responses through untouched.

    The unused ``monkeypatch`` fixture parameter was removed; patching is done
    via ``unittest.mock.patch.object`` on the parent class.
    """
    from kagent.adk.models._litellm import KAgentLiteLlm

    model = KAgentLiteLlm(model="openai/gpt-4o")

    req = _make_request_with_function_call("bad.name")
    req.model = "openai/gpt-4o"

    resp = _make_response_with_function_call("bad.name")

    async def fake_super(*args, **kwargs):
        yield resp

    with patch.object(
        KAgentLiteLlm.__bases__[0], "generate_content_async", return_value=fake_super()
    ):
        results = [r async for r in model.generate_content_async(req)]

    # Names must remain unchanged for non-Bedrock models.
    assert req.contents[0].parts[0].function_call.name == "bad.name"
    assert results[0].content.parts[0].function_call.name == "bad.name"
Loading