-
Notifications
You must be signed in to change notification settings - Fork 424
fix(bedrock): sanitize tool names for Bedrock API constraints #1474
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7555d12
5dacaa3
d68aeb2
bc1b190
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,59 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import Optional | ||
| import logging | ||
| import re | ||
| from typing import AsyncGenerator, Optional | ||
|
|
||
| from google.adk.models.lite_llm import LiteLlm | ||
| from google.adk.models.llm_request import LlmRequest | ||
| from google.adk.models.llm_response import LlmResponse | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # Bedrock requires tool names to match [a-zA-Z0-9_-]+ with length >= 1 | ||
| _BEDROCK_TOOL_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]+$") | ||
| _BEDROCK_TOOL_NAME_FALLBACK = "unknown_tool" | ||
|
|
||
|
|
||
| def _is_bedrock_model(model: str) -> bool: | ||
| return "bedrock" in model.lower() | ||
|
|
||
|
|
||
| def _sanitize_tool_name(name: str, idx: "int | str") -> str: | ||
| """Return a Bedrock-safe tool name; replace invalid/empty names with a fallback.""" | ||
| if not name or not _BEDROCK_TOOL_NAME_RE.match(name): | ||
| safe = re.sub(r"[^a-zA-Z0-9_-]", "_", name) if name else "" | ||
| safe = safe or f"{_BEDROCK_TOOL_NAME_FALLBACK}_{idx}" | ||
| logger.debug("Sanitized invalid Bedrock tool name %r -> %r", name, safe) | ||
| return safe | ||
| return name | ||
|
|
||
|
|
||
| def _sanitize_llm_request(llm_request: LlmRequest) -> None: | ||
| """Fix tool names in the conversation history before sending to Bedrock.""" | ||
| for content in llm_request.contents: | ||
| if not content.parts: | ||
| continue | ||
| for idx, part in enumerate(content.parts): | ||
| fc = getattr(part, "function_call", None) | ||
| if fc is not None and hasattr(fc, "name"): | ||
| fc.name = _sanitize_tool_name(fc.name or "", idx) | ||
|
|
||
|
|
||
| def _sanitize_llm_response(response: LlmResponse, idx: int) -> LlmResponse: | ||
| """Fix tool names in a model response before the ADK stores it in history.""" | ||
| if response.content and response.content.parts: | ||
| for i, part in enumerate(response.content.parts): | ||
| fc = getattr(part, "function_call", None) | ||
| if fc is not None and hasattr(fc, "name"): | ||
| # Use a composite suffix to avoid collisions across responses/parts. | ||
| composite_suffix = f"{idx}_{i}" | ||
| fc.name = _sanitize_tool_name(fc.name or "", composite_suffix) | ||
| return response | ||
|
|
||
|
|
||
| class KAgentLiteLlm(LiteLlm): | ||
| """LiteLlm subclass that supports API key passthrough.""" | ||
| """LiteLlm subclass that supports API key passthrough and Bedrock tool name sanitization.""" | ||
|
|
||
| api_key_passthrough: Optional[bool] = None | ||
|
|
||
|
|
@@ -17,3 +64,17 @@ def __init__(self, model: str, **kwargs): | |
|
|
||
| def set_passthrough_key(self, token: str) -> None: | ||
| self._additional_args["api_key"] = token | ||
|
|
||
| async def generate_content_async( | ||
| self, llm_request: LlmRequest, stream: bool = False | ||
| ) -> AsyncGenerator[LlmResponse, None]: | ||
| effective_model = llm_request.model or self.model | ||
| if _is_bedrock_model(effective_model): | ||
| _sanitize_llm_request(llm_request) | ||
| idx = 0 | ||
| async for response in super().generate_content_async(llm_request, stream=stream): | ||
| yield _sanitize_llm_response(response, idx) | ||
| idx += 1 | ||
| else: | ||
| async for response in super().generate_content_async(llm_request, stream=stream): | ||
| yield response | ||
|
Comment on lines
+68
to
+80
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| """Unit tests for Bedrock tool-name sanitization in KAgentLiteLlm.""" | ||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
|
||
| import pytest | ||
| from google.adk.models.llm_request import LlmRequest | ||
| from google.adk.models.llm_response import LlmResponse | ||
| from google.genai.types import Content, FunctionCall, Part | ||
|
|
||
| from kagent.adk.models._litellm import ( | ||
| _sanitize_llm_request, | ||
| _sanitize_llm_response, | ||
| _sanitize_tool_name, | ||
| ) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # _sanitize_tool_name | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def test_sanitize_tool_name_valid_unchanged(): | ||
| assert _sanitize_tool_name("valid_tool-name1", 0) == "valid_tool-name1" | ||
|
|
||
|
|
||
| def test_sanitize_tool_name_replaces_dots(): | ||
| result = _sanitize_tool_name("tool.name", 0) | ||
| assert result == "tool_name" | ||
|
|
||
|
|
||
| def test_sanitize_tool_name_replaces_spaces(): | ||
| result = _sanitize_tool_name("my tool", 0) | ||
| assert result == "my_tool" | ||
|
|
||
|
|
||
| def test_sanitize_tool_name_empty_uses_fallback(): | ||
| result = _sanitize_tool_name("", 3) | ||
| assert result == "unknown_tool_3" | ||
|
|
||
|
|
||
| def test_sanitize_tool_name_composite_suffix(): | ||
| result = _sanitize_tool_name("", "2_5") | ||
| assert result == "unknown_tool_2_5" | ||
|
|
||
|
|
||
| def test_sanitize_tool_name_logs_at_debug(caplog): | ||
| import logging | ||
|
|
||
| with caplog.at_level(logging.DEBUG, logger="kagent.adk.models._litellm"): | ||
| _sanitize_tool_name("bad.name", 0) | ||
| assert any("Sanitized invalid Bedrock tool name" in r.message for r in caplog.records) | ||
| assert all(r.levelname == "DEBUG" for r in caplog.records if "Sanitized" in r.message) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # _sanitize_llm_request | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def _make_request_with_function_call(name: str) -> LlmRequest: | ||
| fc = FunctionCall(name=name, args={}) | ||
| part = Part(function_call=fc) | ||
| content = Content(parts=[part], role="model") | ||
| req = LlmRequest() | ||
| req.contents = [content] | ||
| return req | ||
|
|
||
|
|
||
| def test_sanitize_llm_request_fixes_invalid_name(): | ||
| req = _make_request_with_function_call("bad.tool.name") | ||
| _sanitize_llm_request(req) | ||
| assert req.contents[0].parts[0].function_call.name == "bad_tool_name" | ||
|
|
||
|
|
||
| def test_sanitize_llm_request_leaves_valid_name(): | ||
| req = _make_request_with_function_call("good_tool") | ||
| _sanitize_llm_request(req) | ||
| assert req.contents[0].parts[0].function_call.name == "good_tool" | ||
|
|
||
|
|
||
| def test_sanitize_llm_request_no_parts_no_error(): | ||
| content = Content(parts=[], role="model") | ||
| req = LlmRequest() | ||
| req.contents = [content] | ||
| _sanitize_llm_request(req) # should not raise | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # _sanitize_llm_response | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def _make_response_with_function_call(name: str) -> LlmResponse: | ||
| fc = FunctionCall(name=name, args={}) | ||
| part = Part(function_call=fc) | ||
| content = Content(parts=[part], role="model") | ||
| resp = LlmResponse() | ||
| resp.content = content | ||
| return resp | ||
|
|
||
|
|
||
| def test_sanitize_llm_response_fixes_invalid_name(): | ||
| resp = _make_response_with_function_call("bad.tool") | ||
| result = _sanitize_llm_response(resp, 0) | ||
| assert result.content.parts[0].function_call.name == "bad_tool" | ||
|
|
||
|
|
||
| def test_sanitize_llm_response_leaves_valid_name(): | ||
| resp = _make_response_with_function_call("valid_tool") | ||
| result = _sanitize_llm_response(resp, 0) | ||
| assert result.content.parts[0].function_call.name == "valid_tool" | ||
|
|
||
|
|
||
| def test_sanitize_llm_response_no_collision_across_parts(): | ||
| fc0 = FunctionCall(name="", args={}) | ||
| fc1 = FunctionCall(name="", args={}) | ||
| content = Content(parts=[Part(function_call=fc0), Part(function_call=fc1)], role="model") | ||
| resp = LlmResponse() | ||
| resp.content = content | ||
| _sanitize_llm_response(resp, 1) | ||
| names = [p.function_call.name for p in resp.content.parts] | ||
| # Each fallback name must be unique (composite idx_i suffix) | ||
| assert names[0] == "unknown_tool_1_0" | ||
| assert names[1] == "unknown_tool_1_1" | ||
|
|
||
|
|
||
| def test_sanitize_llm_response_no_content_no_error(): | ||
| resp = LlmResponse() | ||
| resp.content = None | ||
| _sanitize_llm_response(resp, 0) # should not raise | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # KAgentLiteLlm.generate_content_async — Bedrock vs non-Bedrock routing | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_generate_content_async_bedrock_sanitizes(monkeypatch): | ||
| from kagent.adk.models._litellm import KAgentLiteLlm | ||
|
|
||
| model = KAgentLiteLlm(model="bedrock/anthropic.claude-3-sonnet") | ||
|
|
||
| req = _make_request_with_function_call("bad.name") | ||
| req.model = "bedrock/anthropic.claude-3-sonnet" | ||
|
|
||
| resp = _make_response_with_function_call("bad.response.name") | ||
|
|
||
| async def fake_super(*args, **kwargs): | ||
| yield resp | ||
|
|
||
| with patch.object( | ||
| KAgentLiteLlm.__bases__[0], "generate_content_async", return_value=fake_super() | ||
| ): | ||
| results = [] | ||
| async for r in model.generate_content_async(req): | ||
| results.append(r) | ||
|
|
||
| assert req.contents[0].parts[0].function_call.name == "bad_name" | ||
| assert results[0].content.parts[0].function_call.name == "bad_response_name" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_generate_content_async_non_bedrock_no_sanitization(monkeypatch): | ||
| from kagent.adk.models._litellm import KAgentLiteLlm | ||
|
|
||
| model = KAgentLiteLlm(model="openai/gpt-4o") | ||
|
|
||
| req = _make_request_with_function_call("bad.name") | ||
| req.model = "openai/gpt-4o" | ||
|
|
||
| resp = _make_response_with_function_call("bad.name") | ||
|
|
||
| async def fake_super(*args, **kwargs): | ||
| yield resp | ||
|
|
||
| with patch.object( | ||
| KAgentLiteLlm.__bases__[0], "generate_content_async", return_value=fake_super() | ||
| ): | ||
| results = [] | ||
| async for r in model.generate_content_async(req): | ||
| results.append(r) | ||
|
|
||
| # name must remain unchanged for non-Bedrock models | ||
| assert req.contents[0].parts[0].function_call.name == "bad.name" | ||
| assert results[0].content.parts[0].function_call.name == "bad.name" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tool-name sanitization logs at WARNING every time an invalid name is seen. In streaming mode,
generate_content_asyncyields many chunks, so the same invalid tool name could generate a large number of warning log lines (one per chunk/part), which can be noisy and increase log volume. Consider de-duplicating warnings per request/response (e.g., track already-sanitized originals) or lowering to INFO/DEBUG if this is expected in normal operation.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with this, I think
DEBUGif at all.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed — changed to
logger.debugin bc1b190.