From 12afde31970272eb2cf14bf15787f95fc1f27815 Mon Sep 17 00:00:00 2001 From: Simon Zhu Date: Thu, 5 Mar 2026 00:07:08 -0500 Subject: [PATCH 1/3] feat(agentsts): move STS token exchange to before_tool_callback Move STS token exchange from before_run_callback (eager, once per invocation) to before_tool_callback (lazy, per MCP tool call). This avoids the sync/async problem in header_provider and only performs the exchange when an MCP tool is actually invoked. - before_run_callback now only extracts and stores the subject token - before_tool_callback exchanges on first McpTool call per session - Non-MCP tools (memory, AgentTool, etc.) are skipped - Cached tokens are reused for subsequent MCP calls in same session - Sets the stage for per-audience token exchange Co-Authored-By: Claude Opus 4.6 Signed-off-by: Simon Zhu --- .../agentsts-adk/src/agentsts/adk/_base.py | 82 ++++++++++++---- .../tests/test_adk_integration.py | 96 ++++++++++++++++--- 2 files changed, 145 insertions(+), 33 deletions(-) diff --git a/python/packages/agentsts-adk/src/agentsts/adk/_base.py b/python/packages/agentsts-adk/src/agentsts/adk/_base.py index 38a66bd34..503f2aacf 100644 --- a/python/packages/agentsts-adk/src/agentsts/adk/_base.py +++ b/python/packages/agentsts-adk/src/agentsts/adk/_base.py @@ -14,6 +14,7 @@ from google.adk.sessions.session import Session from google.adk.tools.base_tool import BaseTool from google.adk.tools.mcp_tool import MCPTool +from google.adk.tools.mcp_tool.mcp_tool import McpTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from google.adk.tools.tool_context import ToolContext from typing_extensions import override @@ -40,7 +41,18 @@ def __init__( class ADKTokenPropagationPlugin(BasePlugin): - """Plugin for propagating STS tokens to ADK tools.""" + """Plugin for propagating STS tokens to ADK tools. + + Token exchange lifecycle: + 1. before_run_callback: extracts the subject token from request headers + and stores it for the duration of the invocation. + 2. before_tool_callback: when an MCP tool is about to be called and STS + is configured, exchanges the subject token (async) and caches the + access token so header_provider can read it. + 3. header_provider (sync): returns cached access token as Authorization + header -- called by McpToolset/McpTool during MCP session setup. + 4. after_run_callback: cleans up all cached state. + """ def __init__(self, sts_integration: Optional[STSIntegrationBase] = None): """Initialize the token propagation plugin. @@ -51,6 +63,7 @@ def __init__(self, sts_integration: Optional[STSIntegrationBase] = None): super().__init__("ADKTokenPropagationPlugin") self.sts_integration = sts_integration self.token_cache: Dict[str, str] = {} + self._subject_tokens: Dict[str, str] = {} def add_to_agent(self, agent: BaseAgent): """ @@ -70,7 +83,6 @@ def add_to_agent(self, agent: BaseAgent): logger.debug("Updated tool connection params to include access token from STS server") def header_provider(self, readonly_context: Optional[ReadonlyContext]) -> Dict[str, str]: - # access save token access_token = self.token_cache.get(self.cache_key(readonly_context._invocation_context), "") if not access_token: return {} @@ -85,25 +97,58 @@ async def before_run_callback( *, invocation_context: InvocationContext, ) -> Optional[dict]: - """Propagate token to model before execution.""" + """Extract and store the subject token for later exchange.""" headers = invocation_context.session.state.get(HEADERS_KEY, None) subject_token = _extract_jwt_from_headers(headers) if not subject_token: logger.debug("No subject token found in headers for token propagation") return None - if self.sts_integration: - try: - subject_token = await self.sts_integration.exchange_token( - subject_token=subject_token, - subject_token_type=TokenType.JWT, - actor_token=self.sts_integration._actor_token, - actor_token_type=TokenType.JWT if self.sts_integration._actor_token else None, - ) - except Exception as e: - logger.warning(f"STS token exchange failed: {e}") - return None - # no sts, just propagate the subject token upstream - self.token_cache[self.cache_key(invocation_context)] = subject_token + key = self.cache_key(invocation_context) + self._subject_tokens[key] = subject_token + if not self.sts_integration: + # No STS -- propagate the subject token directly so + # header_provider can return it on the first tool call. + self.token_cache[key] = subject_token + return None + + @override + async def before_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + ) -> Optional[dict]: + """Exchange the subject token via STS before each MCP tool call.""" + if not self.sts_integration: + return None + # Only exchange tokens for MCP tool calls. Other tool types + # (memory tools, AgentTool, etc.) don't use header_provider and + # have their own auth mechanisms, so exchanging here would be a + # wasted HTTP round-trip to the STS. + if not isinstance(tool, McpTool): + return None + + key = self.cache_key(tool_context._invocation_context) + # Already exchanged for this session + if key in self.token_cache: + return None + + subject_token = self._subject_tokens.get(key) + if not subject_token: + return None + + try: + access_token = await self.sts_integration.exchange_token( + subject_token=subject_token, + subject_token_type=TokenType.JWT, + actor_token=self.sts_integration._actor_token, + actor_token_type=TokenType.JWT if self.sts_integration._actor_token else None, + ) + self.token_cache[key] = access_token + except Exception as e: + logger.warning(f"STS token exchange failed: {e}") + return None def cache_key(self, invocation_context: InvocationContext) -> str: @@ -116,8 +161,9 @@ async def after_run_callback( *, invocation_context: InvocationContext, ) -> Optional[dict]: - # delete token after run - self.token_cache.pop(self.cache_key(invocation_context), None) + key = self.cache_key(invocation_context) + self.token_cache.pop(key, None) + self._subject_tokens.pop(key, None) return None diff --git a/python/packages/agentsts-adk/tests/test_adk_integration.py b/python/packages/agentsts-adk/tests/test_adk_integration.py index 5f03897aa..4beab2c33 100644 --- a/python/packages/agentsts-adk/tests/test_adk_integration.py +++ b/python/packages/agentsts-adk/tests/test_adk_integration.py @@ -4,6 +4,7 @@ import pytest from google.adk.agents import LlmAgent +from google.adk.tools.mcp_tool.mcp_tool import McpTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from agentsts.adk import ADKSTSIntegration, ADKTokenPropagationPlugin @@ -31,6 +32,16 @@ def _make_readonly_context(self, invocation_context): readonly_context._invocation_context = invocation_context return readonly_context + def _make_tool_context(self, invocation_context): + tool_context = Mock() + tool_context._invocation_context = invocation_context + return tool_context + + def _make_mcp_tool(self): + tool = Mock(spec=McpTool) + tool.name = "test_tool" + return tool + def test_init(self): mock_sts_integration = Mock() plugin = ADKTokenPropagationPlugin(mock_sts_integration) @@ -77,49 +88,102 @@ async def test_downstream_token_propagation_without_sts(self): @pytest.mark.asyncio async def test_sts_token_exchange_success(self): - """Case: STS integration exchanges token -> access token cached and returned by header provider.""" + """Case: STS integration -- before_run stores subject token, before_tool_callback exchanges it.""" sts = Mock(spec=ADKSTSIntegration) sts._actor_token = "actor-token" sts.exchange_token = AsyncMock(return_value="access-token-XYZ") plugin = ADKTokenPropagationPlugin(sts) ic = self._make_invocation_context("sess-3", headers={"Authorization": "Bearer original-subject"}) - with patch("agentsts.adk._base.logger") as mock_logger: - result = await plugin.before_run_callback(invocation_context=ic) - assert result is None - sts.exchange_token.assert_called_once_with( - subject_token="original-subject", - subject_token_type=TokenType.JWT, - actor_token="actor-token", - actor_token_type=TokenType.JWT, - ) - # optional debug log length check - mock_logger.debug.assert_called() # at least one debug log + + # before_run_callback should store the subject token but NOT exchange + result = await plugin.before_run_callback(invocation_context=ic) + assert result is None + sts.exchange_token.assert_not_called() + assert "sess-3" not in plugin.token_cache + assert plugin._subject_tokens["sess-3"] == "original-subject" + + # before_tool_callback should exchange on first MCP tool call + tool = self._make_mcp_tool() + tc = self._make_tool_context(ic) + result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) + assert result is None + sts.exchange_token.assert_called_once_with( + subject_token="original-subject", + subject_token_type=TokenType.JWT, + actor_token="actor-token", + actor_token_type=TokenType.JWT, + ) assert plugin.token_cache["sess-3"] == "access-token-XYZ" + # header_provider should return the exchanged token ro_ctx = self._make_readonly_context(ic) headers = plugin.header_provider(ro_ctx) assert headers == {"Authorization": "Bearer access-token-XYZ"} + # second tool call should not exchange again (cached) + sts.exchange_token.reset_mock() + result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) + assert result is None + sts.exchange_token.assert_not_called() + await plugin.after_run_callback(invocation_context=ic) assert "sess-3" not in plugin.token_cache + assert "sess-3" not in plugin._subject_tokens @pytest.mark.asyncio async def test_sts_token_exchange_failure(self): - """Case: STS exchange raises -> no cache entry, graceful warning.""" + """Case: STS exchange raises in before_tool_callback -> no cache entry, graceful warning.""" sts = Mock(spec=ADKSTSIntegration) sts._actor_token = "actor-token" sts.exchange_token = AsyncMock(side_effect=Exception("boom")) plugin = ADKTokenPropagationPlugin(sts) ic = self._make_invocation_context("sess-4", headers={"Authorization": "Bearer original-subject"}) + + await plugin.before_run_callback(invocation_context=ic) + assert plugin._subject_tokens["sess-4"] == "original-subject" + + tool = self._make_mcp_tool() + tc = self._make_tool_context(ic) with patch("agentsts.adk._base.logger") as mock_logger: - result = await plugin.before_run_callback(invocation_context=ic) + result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) assert result is None mock_logger.warning.assert_called_once() assert "sess-4" not in plugin.token_cache + # header provider should yield empty dict ro_ctx = self._make_readonly_context(ic) assert plugin.header_provider(ro_ctx) == {} + @pytest.mark.asyncio + async def test_before_tool_callback_skips_non_mcp_tools(self): + """Case: before_tool_callback ignores non-MCP tools.""" + sts = Mock(spec=ADKSTSIntegration) + sts._actor_token = "actor-token" + sts.exchange_token = AsyncMock(return_value="access-token") + plugin = ADKTokenPropagationPlugin(sts) + ic = self._make_invocation_context("sess-7", headers={"Authorization": "Bearer subj"}) + await plugin.before_run_callback(invocation_context=ic) + + non_mcp_tool = Mock() # not a McpTool + tc = self._make_tool_context(ic) + result = await plugin.before_tool_callback(tool=non_mcp_tool, tool_args={}, tool_context=tc) + assert result is None + sts.exchange_token.assert_not_called() + + @pytest.mark.asyncio + async def test_before_tool_callback_no_sts(self): + """Case: before_tool_callback is a no-op without STS integration.""" + plugin = ADKTokenPropagationPlugin(sts_integration=None) + ic = self._make_invocation_context("sess-8", headers={"Authorization": "Bearer subj"}) + await plugin.before_run_callback(invocation_context=ic) + + tool = self._make_mcp_tool() + tc = self._make_tool_context(ic) + result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) + assert result is None + # token_cache should still have the subject token from before_run + assert plugin.token_cache["sess-8"] == "subj" + def test_header_provider_no_entry(self): """Case: header_provider called with no cached token -> returns empty dict.""" plugin = ADKTokenPropagationPlugin() @@ -131,13 +195,15 @@ def test_header_provider_no_entry(self): @pytest.mark.asyncio async def test_after_run_callback_removes_token(self): - """Case: after_run_callback removes cached token.""" + """Case: after_run_callback removes cached token and subject token.""" plugin = ADKTokenPropagationPlugin() ic = self._make_invocation_context("sess-6", headers={"Authorization": "Bearer AAA"}) await plugin.before_run_callback(invocation_context=ic) assert "sess-6" in plugin.token_cache + assert "sess-6" in plugin._subject_tokens await plugin.after_run_callback(invocation_context=ic) assert "sess-6" not in plugin.token_cache + assert "sess-6" not in plugin._subject_tokens def test_extract_jwt_from_headers_success(self): """Test successful JWT extraction from headers.""" From da6dd4c8d2958bef989ea4ce247c30873a5d5d29 Mon Sep 17 00:00:00 2001 From: Simon Zhu Date: Thu, 5 Mar 2026 03:44:04 -0500 Subject: [PATCH 2/3] implement audience Signed-off-by: Simon Zhu --- go/api/adk/types.go | 2 + .../config/crd/bases/kagent.dev_agents.yaml | 7 + .../bases/kagent.dev_remotemcpservers.yaml | 7 + go/api/v1alpha2/agent_types.go | 7 + go/api/v1alpha2/remotemcpserver_types.go | 7 + go/api/v1alpha2/zz_generated.deepcopy.go | 10 + .../translator/agent/adk_api_translator.go | 10 + .../inputs/agent_with_sts_audience.yaml | 51 +++ .../agent_with_sts_audience_override.yaml | 52 +++ .../outputs/agent_with_sts_audience.json | 298 ++++++++++++++++++ .../agent_with_sts_audience_override.json | 298 ++++++++++++++++++ .../templates/kagent.dev_agents.yaml | 7 + .../kagent.dev_remotemcpservers.yaml | 7 + .../agentsts-adk/src/agentsts/adk/_base.py | 99 ++++-- .../tests/test_adk_integration.py | 223 +++++++++---- .../kagent-adk/src/kagent/adk/_mcp_toolset.py | 4 + .../kagent-adk/src/kagent/adk/types.py | 41 ++- .../unittests/test_header_propagation.py | 73 ++++- 18 files changed, 1100 insertions(+), 103 deletions(-) create mode 100644 go/core/internal/controller/translator/agent/testdata/inputs/agent_with_sts_audience.yaml create mode 100644 go/core/internal/controller/translator/agent/testdata/inputs/agent_with_sts_audience_override.yaml create mode 100644 go/core/internal/controller/translator/agent/testdata/outputs/agent_with_sts_audience.json create mode 100644 go/core/internal/controller/translator/agent/testdata/outputs/agent_with_sts_audience_override.json diff --git a/go/api/adk/types.go b/go/api/adk/types.go index aee673f09..c7ba7dcfa 100644 --- a/go/api/adk/types.go +++ b/go/api/adk/types.go @@ -24,6 +24,7 @@ type HttpMcpServerConfig struct { Tools []string `json:"tools"` AllowedHeaders []string `json:"allowed_headers,omitempty"` RequireApproval []string `json:"require_approval,omitempty"` + STSAudience string `json:"sts_audience,omitempty"` } type SseConnectionParams struct { @@ -42,6 +43,7 @@ type SseMcpServerConfig struct { Tools []string `json:"tools"` AllowedHeaders []string `json:"allowed_headers,omitempty"` RequireApproval []string `json:"require_approval,omitempty"` + STSAudience string `json:"sts_audience,omitempty"` } type Model interface { diff --git a/go/api/config/crd/bases/kagent.dev_agents.yaml b/go/api/config/crd/bases/kagent.dev_agents.yaml index 8b735e616..67824a2f2 100644 --- a/go/api/config/crd/bases/kagent.dev_agents.yaml +++ b/go/api/config/crd/bases/kagent.dev_agents.yaml @@ -10103,6 +10103,13 @@ spec: type: string maxItems: 50 type: array + stsAudience: + description: |- + STSAudience overrides the audience value for STS token exchange when + calling this MCP tool server from this agent. If not set, falls back + to the audience configured on the RemoteMCPServer. If neither is set, + no audience is passed in the token exchange request. + type: string toolNames: description: |- The names of the tools to be provided by the ToolServer diff --git a/go/api/config/crd/bases/kagent.dev_remotemcpservers.yaml b/go/api/config/crd/bases/kagent.dev_remotemcpservers.yaml index 534c27b35..a8a0267b5 100644 --- a/go/api/config/crd/bases/kagent.dev_remotemcpservers.yaml +++ b/go/api/config/crd/bases/kagent.dev_remotemcpservers.yaml @@ -171,6 +171,13 @@ spec: type: string sseReadTimeout: type: string + stsAudience: + description: |- + STSAudience specifies the audience value to include in STS token exchange + requests when this MCP server is called. This scopes the issued token to + this specific service. If not set, no audience is passed in the token + exchange request. Can be overridden per-agent via McpServerTool.stsAudience. + type: string terminateOnClose: default: true type: boolean diff --git a/go/api/v1alpha2/agent_types.go b/go/api/v1alpha2/agent_types.go index 81c68bdf6..8c7c95baa 100644 --- a/go/api/v1alpha2/agent_types.go +++ b/go/api/v1alpha2/agent_types.go @@ -413,6 +413,13 @@ type McpServerTool struct { // Example: ["x-user-email", "x-tenant-id"] // +optional AllowedHeaders []string `json:"allowedHeaders,omitempty"` + + // STSAudience overrides the audience value for STS token exchange when + // calling this MCP tool server from this agent. If not set, falls back + // to the audience configured on the RemoteMCPServer. If neither is set, + // no audience is passed in the token exchange request. + // +optional + STSAudience *string `json:"stsAudience,omitempty"` } type TypedLocalReference struct { diff --git a/go/api/v1alpha2/remotemcpserver_types.go b/go/api/v1alpha2/remotemcpserver_types.go index f8a355894..2d9ddd35a 100644 --- a/go/api/v1alpha2/remotemcpserver_types.go +++ b/go/api/v1alpha2/remotemcpserver_types.go @@ -59,6 +59,13 @@ type RemoteMCPServerSpec struct { // See: https://gateway-api.sigs.k8s.io/guides/multiple-ns/#cross-namespace-routing // +optional AllowedNamespaces *AllowedNamespaces `json:"allowedNamespaces,omitempty"` + + // STSAudience specifies the audience value to include in STS token exchange + // requests when this MCP server is called. This scopes the issued token to + // this specific service. If not set, no audience is passed in the token + // exchange request. Can be overridden per-agent via McpServerTool.stsAudience. + // +optional + STSAudience *string `json:"stsAudience,omitempty"` } var _ sql.Scanner = (*RemoteMCPServerSpec)(nil) diff --git a/go/api/v1alpha2/zz_generated.deepcopy.go b/go/api/v1alpha2/zz_generated.deepcopy.go index 52b0309ec..e48cd9634 100644 --- a/go/api/v1alpha2/zz_generated.deepcopy.go +++ b/go/api/v1alpha2/zz_generated.deepcopy.go @@ -603,6 +603,11 @@ func (in *McpServerTool) DeepCopyInto(out *McpServerTool) { *out = make([]string, len(*in)) copy(*out, *in) } + if in.STSAudience != nil { + in, out := &in.STSAudience, &out.STSAudience + *out = new(string) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new McpServerTool. @@ -1070,6 +1075,11 @@ func (in *RemoteMCPServerSpec) DeepCopyInto(out *RemoteMCPServerSpec) { *out = new(AllowedNamespaces) (*in).DeepCopyInto(*out) } + if in.STSAudience != nil { + in, out := &in.STSAudience, &out.STSAudience + *out = new(string) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RemoteMCPServerSpec. diff --git a/go/core/internal/controller/translator/agent/adk_api_translator.go b/go/core/internal/controller/translator/agent/adk_api_translator.go index 6c7870aa3..89308953a 100644 --- a/go/core/internal/controller/translator/agent/adk_api_translator.go +++ b/go/core/internal/controller/translator/agent/adk_api_translator.go @@ -1330,6 +1330,14 @@ func (a *adkApiTranslator) translateMCPServerTarget(ctx context.Context, agent * } func (a *adkApiTranslator) translateRemoteMCPServerTarget(ctx context.Context, agent *adk.AgentConfig, remoteMcpServer *v1alpha2.RemoteMCPServer, mcpServerTool *v1alpha2.McpServerTool, agentHeaders map[string]string, proxyURL string) error { + // Resolve STS audience: McpServerTool override > RemoteMCPServer default + stsAudience := "" + if mcpServerTool.STSAudience != nil && *mcpServerTool.STSAudience != "" { + stsAudience = *mcpServerTool.STSAudience + } else if remoteMcpServer.Spec.STSAudience != nil { + stsAudience = *remoteMcpServer.Spec.STSAudience + } + switch remoteMcpServer.Spec.Protocol { case v1alpha2.RemoteMCPServerProtocolSse: tool, err := a.translateSseHttpTool(ctx, remoteMcpServer, agentHeaders, proxyURL) @@ -1341,6 +1349,7 @@ func (a *adkApiTranslator) translateRemoteMCPServerTarget(ctx context.Context, a Tools: mcpServerTool.ToolNames, AllowedHeaders: mcpServerTool.AllowedHeaders, RequireApproval: mcpServerTool.RequireApproval, + STSAudience: stsAudience, }) default: tool, err := a.translateStreamableHttpTool(ctx, remoteMcpServer, agentHeaders, proxyURL) @@ -1352,6 +1361,7 @@ func (a *adkApiTranslator) translateRemoteMCPServerTarget(ctx context.Context, a Tools: mcpServerTool.ToolNames, AllowedHeaders: mcpServerTool.AllowedHeaders, RequireApproval: mcpServerTool.RequireApproval, + STSAudience: stsAudience, }) } return nil diff --git a/go/core/internal/controller/translator/agent/testdata/inputs/agent_with_sts_audience.yaml b/go/core/internal/controller/translator/agent/testdata/inputs/agent_with_sts_audience.yaml new file mode 100644 index 000000000..ac18ae330 --- /dev/null +++ b/go/core/internal/controller/translator/agent/testdata/inputs/agent_with_sts_audience.yaml @@ -0,0 +1,51 @@ +operation: translateAgent +targetObject: agent +namespace: test +objects: + - apiVersion: v1 + kind: Secret + metadata: + name: openai-secret + namespace: test + data: + api-key: c2stdGVzdC1hcGkta2V5 # base64 encoded "sk-test-api-key" + - apiVersion: kagent.dev/v1alpha2 + kind: ModelConfig + metadata: + name: test-model + namespace: test + spec: + provider: OpenAI + model: gpt-4o + apiKeySecret: openai-secret + apiKeySecretKey: api-key + - apiVersion: kagent.dev/v1alpha2 + kind: Agent + metadata: + name: agent + namespace: test + spec: + type: Declarative + declarative: + description: An agent with STS audience on RemoteMCPServer + systemMessage: You are a helpful assistant. + modelConfig: test-model + tools: + - type: McpServer + mcpServer: + name: toolserver + kind: RemoteMCPServer + apiGroup: kagent.dev + toolNames: + - tool1 + - tool2 + - apiVersion: kagent.dev/v1alpha2 + kind: RemoteMCPServer + metadata: + name: toolserver + namespace: test + spec: + url: http://mcp-server.test:8080/mcp + description: "Test MCP Server with STS audience" + protocol: STREAMABLE_HTTP + stsAudience: "https://mcp-server.example.com" diff --git a/go/core/internal/controller/translator/agent/testdata/inputs/agent_with_sts_audience_override.yaml b/go/core/internal/controller/translator/agent/testdata/inputs/agent_with_sts_audience_override.yaml new file mode 100644 index 000000000..8e70ba5c3 --- /dev/null +++ b/go/core/internal/controller/translator/agent/testdata/inputs/agent_with_sts_audience_override.yaml @@ -0,0 +1,52 @@ +operation: translateAgent +targetObject: agent +namespace: test +objects: + - apiVersion: v1 + kind: Secret + metadata: + name: openai-secret + namespace: test + data: + api-key: c2stdGVzdC1hcGkta2V5 # base64 encoded "sk-test-api-key" + - apiVersion: kagent.dev/v1alpha2 + kind: ModelConfig + metadata: + name: test-model + namespace: test + spec: + provider: OpenAI + model: gpt-4o + apiKeySecret: openai-secret + apiKeySecretKey: api-key + - apiVersion: kagent.dev/v1alpha2 + kind: Agent + metadata: + name: agent + namespace: test + spec: + type: Declarative + declarative: + description: An agent with STS audience override on McpServerTool + systemMessage: You are a helpful assistant. + modelConfig: test-model + tools: + - type: McpServer + mcpServer: + name: toolserver + kind: RemoteMCPServer + apiGroup: kagent.dev + toolNames: + - tool1 + - tool2 + stsAudience: "https://agent-specific-audience.example.com" + - apiVersion: kagent.dev/v1alpha2 + kind: RemoteMCPServer + metadata: + name: toolserver + namespace: test + spec: + url: http://mcp-server.test:8080/mcp + description: "Test MCP Server with STS audience that gets overridden" + protocol: STREAMABLE_HTTP + stsAudience: "https://default-audience.example.com" diff --git a/go/core/internal/controller/translator/agent/testdata/outputs/agent_with_sts_audience.json b/go/core/internal/controller/translator/agent/testdata/outputs/agent_with_sts_audience.json new file mode 100644 index 000000000..871489a3f --- /dev/null +++ b/go/core/internal/controller/translator/agent/testdata/outputs/agent_with_sts_audience.json @@ -0,0 +1,298 @@ +{ + "agentCard": { + "capabilities": { + "pushNotifications": false, + "stateTransitionHistory": true, + "streaming": true + }, + "defaultInputModes": [ + "text" + ], + "defaultOutputModes": [ + "text" + ], + "description": "", + "name": "agent", + "skills": null, + "url": "http://agent.test:8080", + "version": "" + }, + "config": { + "description": "", + "http_tools": [ + { + "params": { + "headers": {}, + "url": "http://mcp-server.test:8080/mcp" + }, + "sts_audience": "https://mcp-server.example.com", + "tools": [ + "tool1", + "tool2" + ] + } + ], + "instruction": "You are a helpful assistant.", + "model": { + "base_url": "", + "model": "gpt-4o", + "type": "openai" + }, + "stream": false + }, + "manifest": [ + { + "apiVersion": "v1", + "kind": "Secret", + "metadata": { + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + }, + "name": "agent", + "namespace": "test", + "ownerReferences": [ + { + "apiVersion": "kagent.dev/v1alpha2", + "blockOwnerDeletion": true, + "controller": true, + "kind": "Agent", + "name": "agent", + "uid": "" + } + ] + }, + "stringData": { + "agent-card.json": "{\"name\":\"agent\",\"description\":\"\",\"url\":\"http://agent.test:8080\",\"version\":\"\",\"capabilities\":{\"streaming\":true,\"pushNotifications\":false,\"stateTransitionHistory\":true},\"defaultInputModes\":[\"text\"],\"defaultOutputModes\":[\"text\"],\"skills\":[]}", + "config.json": "{\"model\":{\"type\":\"openai\",\"model\":\"gpt-4o\",\"base_url\":\"\"},\"description\":\"\",\"instruction\":\"You are a helpful assistant.\",\"http_tools\":[{\"params\":{\"url\":\"http://mcp-server.test:8080/mcp\",\"headers\":{}},\"tools\":[\"tool1\",\"tool2\"],\"sts_audience\":\"https://mcp-server.example.com\"}],\"stream\":false}" + } + }, + { + "apiVersion": "v1", + "kind": "ServiceAccount", + "metadata": { + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + }, + "name": "agent", + "namespace": "test", + "ownerReferences": [ + { + "apiVersion": "kagent.dev/v1alpha2", + "blockOwnerDeletion": true, + "controller": true, + "kind": "Agent", + "name": "agent", + "uid": "" + } + ] + } + }, + { + "apiVersion": "apps/v1", + "kind": "Deployment", + "metadata": { + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + }, + "name": "agent", + "namespace": "test", + "ownerReferences": [ + { + "apiVersion": "kagent.dev/v1alpha2", + "blockOwnerDeletion": true, + "controller": true, + "kind": "Agent", + "name": "agent", + "uid": "" + } + ] + }, + "spec": { + "selector": { + "matchLabels": { + "app": "kagent", + "kagent": "agent" + } + }, + "strategy": { + "rollingUpdate": { + "maxSurge": 1, + "maxUnavailable": 0 + }, + "type": "RollingUpdate" + }, + "template": { + "metadata": { + "annotations": { + "kagent.dev/config-hash": "3177752694121230378" + }, + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + } + }, + "spec": { + "containers": [ + { + "args": [ + "--host", + "0.0.0.0", + "--port", + "8080", + "--filepath", + "/config" + ], + "env": [ + { + "name": "OPENAI_API_KEY", + "valueFrom": { + "secretKeyRef": { + "key": "api-key", + "name": "openai-secret" + } + } + }, + { + "name": "KAGENT_NAMESPACE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.namespace" + } + } + }, + { + "name": "KAGENT_NAME", + "value": "agent" + }, + { + "name": "KAGENT_URL", + "value": "http://kagent-controller.kagent:8083" + } + ], + "image": "cr.kagent.dev/kagent-dev/kagent/app:dev", + "imagePullPolicy": "IfNotPresent", + "name": "kagent", + "ports": [ + { + "containerPort": 8080, + "name": "http" + } + ], + "readinessProbe": { + "httpGet": { + "path": "/.well-known/agent-card.json", + "port": "http" + }, + "initialDelaySeconds": 15, + "periodSeconds": 15, + "timeoutSeconds": 15 + }, + "resources": { + "limits": { + "cpu": "2", + "memory": "1Gi" + }, + "requests": { + "cpu": "100m", + "memory": "384Mi" + } + }, + "volumeMounts": [ + { + "mountPath": "/config", + "name": "config" + }, + { + "mountPath": "/var/run/secrets/tokens", + "name": "kagent-token" + } + ] + } + ], + "serviceAccountName": "agent", + "volumes": [ + { + "name": "config", + "secret": { + "secretName": "agent" + } + }, + { + "name": "kagent-token", + "projected": { + "sources": [ + { + "serviceAccountToken": { + "audience": "kagent", + "expirationSeconds": 3600, + "path": "kagent-token" + } + } + ] + } + } + ] + } + } + }, + "status": {} + }, + { + "apiVersion": "v1", + "kind": "Service", + "metadata": { + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + }, + "name": "agent", + "namespace": "test", + "ownerReferences": [ + { + "apiVersion": "kagent.dev/v1alpha2", + "blockOwnerDeletion": true, + "controller": true, + "kind": "Agent", + "name": "agent", + "uid": "" + } + ] + }, + "spec": { + "ports": [ + { + "name": "http", + "port": 8080, + "targetPort": 8080 + } + ], + "selector": { + "app": "kagent", + "kagent": "agent" + }, + "type": "ClusterIP" + }, + "status": { + "loadBalancer": {} + } + } + ] +} \ No newline at end of file diff --git a/go/core/internal/controller/translator/agent/testdata/outputs/agent_with_sts_audience_override.json b/go/core/internal/controller/translator/agent/testdata/outputs/agent_with_sts_audience_override.json new file mode 100644 index 000000000..d01ca0dc5 --- /dev/null +++ b/go/core/internal/controller/translator/agent/testdata/outputs/agent_with_sts_audience_override.json @@ -0,0 +1,298 @@ +{ + "agentCard": { + "capabilities": { + "pushNotifications": false, + "stateTransitionHistory": true, + "streaming": true + }, + "defaultInputModes": [ + "text" + ], + "defaultOutputModes": [ + "text" + ], + "description": "", + "name": "agent", + "skills": null, + "url": "http://agent.test:8080", + "version": "" + }, + "config": { + "description": "", + "http_tools": [ + { + "params": { + "headers": {}, + "url": "http://mcp-server.test:8080/mcp" + }, + "sts_audience": "https://agent-specific-audience.example.com", + "tools": [ + "tool1", + "tool2" + ] + } + ], + "instruction": "You are a helpful assistant.", + "model": { + "base_url": "", + "model": "gpt-4o", + "type": "openai" + }, + "stream": false + }, + "manifest": [ + { + "apiVersion": "v1", + "kind": "Secret", + "metadata": { + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + }, + "name": "agent", + "namespace": "test", + "ownerReferences": [ + { + "apiVersion": "kagent.dev/v1alpha2", + "blockOwnerDeletion": true, + "controller": true, + "kind": "Agent", + "name": "agent", + "uid": "" + } + ] + }, + "stringData": { + "agent-card.json": "{\"name\":\"agent\",\"description\":\"\",\"url\":\"http://agent.test:8080\",\"version\":\"\",\"capabilities\":{\"streaming\":true,\"pushNotifications\":false,\"stateTransitionHistory\":true},\"defaultInputModes\":[\"text\"],\"defaultOutputModes\":[\"text\"],\"skills\":[]}", + "config.json": "{\"model\":{\"type\":\"openai\",\"model\":\"gpt-4o\",\"base_url\":\"\"},\"description\":\"\",\"instruction\":\"You are a helpful assistant.\",\"http_tools\":[{\"params\":{\"url\":\"http://mcp-server.test:8080/mcp\",\"headers\":{}},\"tools\":[\"tool1\",\"tool2\"],\"sts_audience\":\"https://agent-specific-audience.example.com\"}],\"stream\":false}" + } + }, + { + "apiVersion": "v1", + "kind": "ServiceAccount", + "metadata": { + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + }, + "name": "agent", + "namespace": "test", + "ownerReferences": [ + { + "apiVersion": "kagent.dev/v1alpha2", + "blockOwnerDeletion": true, + "controller": true, + "kind": "Agent", + "name": "agent", + "uid": "" + } + ] + } + }, + { + "apiVersion": "apps/v1", + "kind": "Deployment", + "metadata": { + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + }, + "name": "agent", + "namespace": "test", + "ownerReferences": [ + { + "apiVersion": "kagent.dev/v1alpha2", + "blockOwnerDeletion": true, + "controller": true, + "kind": "Agent", + "name": "agent", + "uid": "" + } + ] + }, + "spec": { + "selector": { + "matchLabels": { + "app": "kagent", + "kagent": "agent" + } + }, + "strategy": { + "rollingUpdate": { + "maxSurge": 1, + "maxUnavailable": 0 + }, + "type": "RollingUpdate" + }, + "template": { + "metadata": { + "annotations": { + "kagent.dev/config-hash": "14897446789976215587" + }, + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + } + }, + "spec": { + "containers": [ + { + "args": [ + "--host", + "0.0.0.0", + "--port", + "8080", + "--filepath", + "/config" + ], + "env": [ + { + "name": "OPENAI_API_KEY", + "valueFrom": { + "secretKeyRef": { + "key": "api-key", + "name": "openai-secret" + } + } + }, + { + "name": "KAGENT_NAMESPACE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.namespace" + } + } + }, + { + "name": "KAGENT_NAME", + "value": "agent" + }, + { + "name": "KAGENT_URL", + "value": "http://kagent-controller.kagent:8083" + } + ], + "image": "cr.kagent.dev/kagent-dev/kagent/app:dev", + "imagePullPolicy": "IfNotPresent", + "name": "kagent", + "ports": [ + { + "containerPort": 8080, + "name": "http" + } + ], + "readinessProbe": { + "httpGet": { + "path": "/.well-known/agent-card.json", + "port": "http" + }, + "initialDelaySeconds": 15, + "periodSeconds": 15, + "timeoutSeconds": 15 + }, + "resources": { + "limits": { + "cpu": "2", + "memory": "1Gi" + }, + "requests": { + "cpu": "100m", + "memory": "384Mi" + } + }, + "volumeMounts": [ + { + "mountPath": "/config", + "name": "config" + }, + { + "mountPath": "/var/run/secrets/tokens", + "name": "kagent-token" + } + ] + } + ], + "serviceAccountName": "agent", + "volumes": [ + { + "name": "config", + "secret": { + "secretName": "agent" + } + }, + { + "name": "kagent-token", + "projected": { + "sources": [ + { + "serviceAccountToken": { + "audience": "kagent", + "expirationSeconds": 3600, + "path": "kagent-token" + } + } + ] + } + } + ] + } + } + }, + "status": {} + }, + { + "apiVersion": "v1", + "kind": "Service", + "metadata": { + "labels": { + "app": "kagent", + "app.kubernetes.io/managed-by": "kagent", + "app.kubernetes.io/name": "agent", + "app.kubernetes.io/part-of": "kagent", + "kagent": "agent" + }, + "name": "agent", + "namespace": "test", + "ownerReferences": [ + { + "apiVersion": "kagent.dev/v1alpha2", + "blockOwnerDeletion": true, + "controller": true, + "kind": "Agent", + "name": "agent", + "uid": "" + } + ] + }, + "spec": { + "ports": [ + { + "name": "http", + "port": 8080, + "targetPort": 8080 + } + ], + "selector": { + "app": "kagent", + "kagent": "agent" + }, + "type": "ClusterIP" + }, + "status": { + "loadBalancer": {} + } + } + ] +} \ No newline at end of file diff --git a/helm/kagent-crds/templates/kagent.dev_agents.yaml b/helm/kagent-crds/templates/kagent.dev_agents.yaml index 8b735e616..67824a2f2 100644 --- a/helm/kagent-crds/templates/kagent.dev_agents.yaml +++ b/helm/kagent-crds/templates/kagent.dev_agents.yaml @@ -10103,6 +10103,13 @@ spec: type: string maxItems: 50 type: array + stsAudience: + description: |- + STSAudience overrides the audience value for STS token exchange when + calling this MCP tool server from this agent. If not set, falls back + to the audience configured on the RemoteMCPServer. If neither is set, + no audience is passed in the token exchange request. + type: string toolNames: description: |- The names of the tools to be provided by the ToolServer diff --git a/helm/kagent-crds/templates/kagent.dev_remotemcpservers.yaml b/helm/kagent-crds/templates/kagent.dev_remotemcpservers.yaml index 534c27b35..a8a0267b5 100644 --- a/helm/kagent-crds/templates/kagent.dev_remotemcpservers.yaml +++ b/helm/kagent-crds/templates/kagent.dev_remotemcpservers.yaml @@ -171,6 +171,13 @@ spec: type: string sseReadTimeout: type: string + stsAudience: + description: |- + STSAudience specifies the audience value to include in STS token exchange + requests when this MCP server is called. This scopes the issued token to + this specific service. If not set, no audience is passed in the token + exchange request. Can be overridden per-agent via McpServerTool.stsAudience. + type: string terminateOnClose: default: true type: boolean diff --git a/python/packages/agentsts-adk/src/agentsts/adk/_base.py b/python/packages/agentsts-adk/src/agentsts/adk/_base.py index 503f2aacf..d38b0a30a 100644 --- a/python/packages/agentsts-adk/src/agentsts/adk/_base.py +++ b/python/packages/agentsts-adk/src/agentsts/adk/_base.py @@ -14,7 +14,6 @@ from google.adk.sessions.session import Session from google.adk.tools.base_tool import BaseTool from google.adk.tools.mcp_tool import MCPTool -from google.adk.tools.mcp_tool.mcp_tool import McpTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from google.adk.tools.tool_context import ToolContext from typing_extensions import override @@ -43,15 +42,8 @@ def __init__( class ADKTokenPropagationPlugin(BasePlugin): """Plugin for propagating STS tokens to ADK tools. - Token exchange lifecycle: - 1. before_run_callback: extracts the subject token from request headers - and stores it for the duration of the invocation. - 2. before_tool_callback: when an MCP tool is about to be called and STS - is configured, exchanges the subject token (async) and caches the - access token so header_provider can read it. - 3. header_provider (sync): returns cached access token as Authorization - header -- called by McpToolset/McpTool during MCP session setup. - 4. after_run_callback: cleans up all cached state. + Supports audience-scoped token exchange: each MCP toolset can declare + an STS audience so the exchanged token is scoped to that specific service. """ def __init__(self, sts_integration: Optional[STSIntegrationBase] = None): @@ -62,8 +54,18 @@ def __init__(self, sts_integration: Optional[STSIntegrationBase] = None): """ super().__init__("ADKTokenPropagationPlugin") self.sts_integration = sts_integration - self.token_cache: Dict[str, str] = {} - self._subject_tokens: Dict[str, str] = {} + self.token_cache: Dict[str, str] = {} # cache_key -> access_token + self._subject_tokens: Dict[str, str] = {} # session_id -> subject_token + self._audience_map: Dict[int, str] = {} # id(session_manager) -> audience + + def register_toolset(self, toolset: MCPToolset, audience: Optional[str]) -> None: + """Register a toolset's session manager -> audience mapping. + + Called during agent setup (to_agent) so the before_tool_callback + can resolve the audience for each MCP tool at runtime. + """ + if audience and hasattr(toolset, "_mcp_session_manager"): + self._audience_map[id(toolset._mcp_session_manager)] = audience def add_to_agent(self, agent: BaseAgent): """ @@ -82,8 +84,17 @@ def add_to_agent(self, agent: BaseAgent): mcp_toolset._header_provider = self.header_provider logger.debug("Updated tool connection params to include access token from STS server") - def header_provider(self, readonly_context: Optional[ReadonlyContext]) -> Dict[str, str]: - access_token = self.token_cache.get(self.cache_key(readonly_context._invocation_context), "") + def header_provider( + self, readonly_context: Optional[ReadonlyContext], audience: Optional[str] = None + ) -> Dict[str, str]: + """Return cached access token as an Authorization header. + + Args: + readonly_context: ADK readonly context with invocation info. + audience: Optional audience to look up the correct scoped token. + """ + key = self.cache_key(readonly_context._invocation_context, audience) + access_token = self.token_cache.get(key, "") if not access_token: return {} @@ -97,18 +108,25 @@ async def before_run_callback( *, invocation_context: InvocationContext, ) -> Optional[dict]: - """Extract and store the subject token for later exchange.""" + """Extract and store the subject token from request headers. + + Token exchange is deferred to before_tool_callback so each tool + can exchange with its own audience. + """ headers = invocation_context.session.state.get(HEADERS_KEY, None) subject_token = _extract_jwt_from_headers(headers) if not subject_token: logger.debug("No subject token found in headers for token propagation") return None + key = self.cache_key(invocation_context) self._subject_tokens[key] = subject_token + + # When there is no STS integration, propagate the subject token + # directly under the empty-audience cache key (backward compat). if not self.sts_integration: - # No STS -- propagate the subject token directly so - # header_provider can return it on the first tool call. self.token_cache[key] = subject_token + return None @override @@ -119,22 +137,24 @@ async def before_tool_callback( tool_args: dict[str, Any], tool_context: ToolContext, ) -> Optional[dict]: - """Exchange the subject token via STS before each MCP tool call.""" + """Exchange token with audience scoping on first MCP tool call. + + Only fires for MCPTool instances. Looks up the audience from the + tool's session manager and exchanges (or reuses) a scoped token. + """ if not self.sts_integration: return None - # Only exchange tokens for MCP tool calls. Other tool types - # (memory tools, AgentTool, etc.) don't use header_provider and - # have their own auth mechanisms, so exchanging here would be a - # wasted HTTP round-trip to the STS. - if not isinstance(tool, McpTool): + if not isinstance(tool, MCPTool): return None - key = self.cache_key(tool_context._invocation_context) - # Already exchanged for this session + invocation_context = tool_context._invocation_context + audience = self._resolve_audience(tool) + key = self.cache_key(invocation_context, audience) + if key in self.token_cache: - return None + return None # already exchanged for this (session, audience) - subject_token = self._subject_tokens.get(key) + subject_token = self._subject_tokens.get(self.cache_key(invocation_context)) if not subject_token: return None @@ -144,16 +164,26 @@ async def before_tool_callback( subject_token_type=TokenType.JWT, actor_token=self.sts_integration._actor_token, actor_token_type=TokenType.JWT if self.sts_integration._actor_token else None, + audience=audience if audience else None, ) self.token_cache[key] = access_token except Exception as e: - logger.warning(f"STS token exchange failed: {e}") + logger.warning("STS token exchange failed for audience '%s': %s", audience, e) return None - def cache_key(self, invocation_context: InvocationContext) -> str: - """Generate a cache key based on the session ID.""" - return invocation_context.session.id + def _resolve_audience(self, tool: BaseTool) -> str: + """Look up audience from the tool's session manager.""" + if hasattr(tool, "_mcp_session_manager"): + return self._audience_map.get(id(tool._mcp_session_manager), "") + return "" + + def cache_key(self, invocation_context: InvocationContext, audience: Optional[str] = None) -> str: + """Generate a cache key based on session ID and audience.""" + session_id = invocation_context.session.id + if audience: + return f"{session_id}:{audience}" + return session_id @override async def after_run_callback( @@ -161,8 +191,13 @@ async def after_run_callback( *, invocation_context: InvocationContext, ) -> Optional[dict]: + """Clean up all cached tokens for the completed session.""" key = self.cache_key(invocation_context) - self.token_cache.pop(key, None) + keys_to_remove = [ + k for k in self.token_cache if k == key or k.startswith(f"{key}:") + ] + for k in keys_to_remove: + self.token_cache.pop(k, None) self._subject_tokens.pop(key, None) return None diff --git a/python/packages/agentsts-adk/tests/test_adk_integration.py b/python/packages/agentsts-adk/tests/test_adk_integration.py index 4beab2c33..f7f5f187e 100644 --- a/python/packages/agentsts-adk/tests/test_adk_integration.py +++ b/python/packages/agentsts-adk/tests/test_adk_integration.py @@ -4,7 +4,7 @@ import pytest from google.adk.agents import LlmAgent -from google.adk.tools.mcp_tool.mcp_tool import McpTool +from google.adk.tools.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from agentsts.adk import ADKSTSIntegration, ADKTokenPropagationPlugin @@ -15,7 +15,7 @@ class TestADKTokenPropagationPlugin: - """Unit tests for token propagation plugin covering: none, downstream, and STS exchange.""" + """Unit tests for token propagation plugin covering: none, downstream, STS exchange, and audience scoping.""" def _make_invocation_context(self, session_id: str, headers: dict | None): session = Mock() @@ -37,9 +37,10 @@ def _make_tool_context(self, invocation_context): tool_context._invocation_context = invocation_context return tool_context - def _make_mcp_tool(self): - tool = Mock(spec=McpTool) - tool.name = "test_tool" + def _make_mcp_tool(self, session_manager=None): + tool = Mock(spec=MCPTool) + if session_manager is not None: + tool._mcp_session_manager = session_manager return tool def test_init(self): @@ -48,6 +49,30 @@ def test_init(self): assert plugin.name == "ADKTokenPropagationPlugin" assert plugin.sts_integration is mock_sts_integration assert plugin.token_cache == {} + assert plugin._subject_tokens == {} + assert plugin._audience_map == {} + + def test_register_toolset(self): + """Registering a toolset maps its session manager id to audience.""" + plugin = ADKTokenPropagationPlugin() + toolset = Mock(spec=MCPToolset) + session_manager = Mock() + toolset._mcp_session_manager = session_manager + + plugin.register_toolset(toolset, "https://audience.example.com") + assert plugin._audience_map[id(session_manager)] == "https://audience.example.com" + + def test_register_toolset_no_audience(self): + """Registering with None or empty audience is a no-op.""" + plugin = ADKTokenPropagationPlugin() + toolset = Mock(spec=MCPToolset) + toolset._mcp_session_manager = Mock() + + plugin.register_toolset(toolset, None) + assert plugin._audience_map == {} + + plugin.register_toolset(toolset, "") + assert plugin._audience_map == {} @pytest.mark.asyncio async def test_before_run_callback_no_headers(self): @@ -59,6 +84,7 @@ async def test_before_run_callback_no_headers(self): assert result is None mock_logger.debug.assert_called_once_with("No subject token found in headers for token propagation") assert plugin.token_cache == {} + assert plugin._subject_tokens == {} @pytest.mark.asyncio async def test_downstream_token_propagation_without_sts(self): @@ -67,6 +93,9 @@ async def test_downstream_token_propagation_without_sts(self): ic = self._make_invocation_context("sess-2", headers={"Authorization": "Bearer subj-token-123"}) result = await plugin.before_run_callback(invocation_context=ic) assert result is None + # Subject token stored for later + assert plugin._subject_tokens["sess-2"] == "subj-token-123" + # Without STS, subject token is directly cached under session key assert plugin.token_cache["sess-2"] == "subj-token-123" # propagate toolset @@ -85,26 +114,28 @@ async def test_downstream_token_propagation_without_sts(self): # cleanup await plugin.after_run_callback(invocation_context=ic) assert "sess-2" not in plugin.token_cache + assert "sess-2" not in plugin._subject_tokens @pytest.mark.asyncio - async def test_sts_token_exchange_success(self): - """Case: STS integration -- before_run stores subject token, before_tool_callback exchanges it.""" + async def test_before_tool_callback_with_sts_exchange(self): + """Case: STS integration exchanges token on first MCP tool call.""" sts = Mock(spec=ADKSTSIntegration) sts._actor_token = "actor-token" sts.exchange_token = AsyncMock(return_value="access-token-XYZ") plugin = ADKTokenPropagationPlugin(sts) - ic = self._make_invocation_context("sess-3", headers={"Authorization": "Bearer original-subject"}) - # before_run_callback should store the subject token but NOT exchange - result = await plugin.before_run_callback(invocation_context=ic) - assert result is None - sts.exchange_token.assert_not_called() - assert "sess-3" not in plugin.token_cache + ic = self._make_invocation_context("sess-3", headers={"Authorization": "Bearer original-subject"}) + # before_run stores subject token + await plugin.before_run_callback(invocation_context=ic) assert plugin._subject_tokens["sess-3"] == "original-subject" + # No exchange yet (STS present, exchange deferred to before_tool_callback) + assert plugin.token_cache == {} - # before_tool_callback should exchange on first MCP tool call - tool = self._make_mcp_tool() + # before_tool_callback triggers exchange + session_manager = Mock() + tool = self._make_mcp_tool(session_manager=session_manager) tc = self._make_tool_context(ic) + result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) assert result is None sts.exchange_token.assert_called_once_with( @@ -112,98 +143,182 @@ async def test_sts_token_exchange_success(self): subject_token_type=TokenType.JWT, actor_token="actor-token", actor_token_type=TokenType.JWT, + audience=None, # no audience registered for this tool ) assert plugin.token_cache["sess-3"] == "access-token-XYZ" - # header_provider should return the exchanged token ro_ctx = self._make_readonly_context(ic) headers = plugin.header_provider(ro_ctx) assert headers == {"Authorization": "Bearer access-token-XYZ"} - # second tool call should not exchange again (cached) - sts.exchange_token.reset_mock() - result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) - assert result is None - sts.exchange_token.assert_not_called() - await plugin.after_run_callback(invocation_context=ic) assert "sess-3" not in plugin.token_cache - assert "sess-3" not in plugin._subject_tokens @pytest.mark.asyncio - async def test_sts_token_exchange_failure(self): - """Case: STS exchange raises in before_tool_callback -> no cache entry, graceful warning.""" + async def test_before_tool_callback_with_audience(self): + """Case: STS exchange with audience-scoped token.""" sts = Mock(spec=ADKSTSIntegration) sts._actor_token = "actor-token" - sts.exchange_token = AsyncMock(side_effect=Exception("boom")) + sts.exchange_token = AsyncMock(return_value="scoped-token-ABC") plugin = ADKTokenPropagationPlugin(sts) - ic = self._make_invocation_context("sess-4", headers={"Authorization": "Bearer original-subject"}) + # Register a toolset with an audience + session_manager = Mock() + toolset = Mock(spec=MCPToolset) + toolset._mcp_session_manager = session_manager + plugin.register_toolset(toolset, "https://my-service.example.com") + + ic = self._make_invocation_context("sess-aud", headers={"Authorization": "Bearer subj"}) await plugin.before_run_callback(invocation_context=ic) - assert plugin._subject_tokens["sess-4"] == "original-subject" - tool = self._make_mcp_tool() + # Tool using that session_manager + tool = self._make_mcp_tool(session_manager=session_manager) tc = self._make_tool_context(ic) - with patch("agentsts.adk._base.logger") as mock_logger: - result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) - assert result is None - mock_logger.warning.assert_called_once() - assert "sess-4" not in plugin.token_cache - # header provider should yield empty dict + result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) + assert result is None + sts.exchange_token.assert_called_once_with( + subject_token="subj", + subject_token_type=TokenType.JWT, + actor_token="actor-token", + actor_token_type=TokenType.JWT, + audience="https://my-service.example.com", + ) + assert plugin.token_cache["sess-aud:https://my-service.example.com"] == "scoped-token-ABC" + + # header_provider with audience returns scoped token ro_ctx = self._make_readonly_context(ic) - assert plugin.header_provider(ro_ctx) == {} + headers = plugin.header_provider(ro_ctx, audience="https://my-service.example.com") + assert headers == {"Authorization": "Bearer scoped-token-ABC"} + + @pytest.mark.asyncio + async def test_per_audience_cache_isolation(self): + """Tools with different audiences get different tokens.""" + call_count = 0 + + async def mock_exchange(**kwargs): + nonlocal call_count + call_count += 1 + audience = kwargs.get("audience", "") + return f"token-for-{audience}-{call_count}" + + sts = Mock(spec=ADKSTSIntegration) + sts._actor_token = "actor" + sts.exchange_token = AsyncMock(side_effect=mock_exchange) + plugin = ADKTokenPropagationPlugin(sts) + + # Two toolsets with different audiences + sm_a = Mock() + toolset_a = Mock(spec=MCPToolset) + toolset_a._mcp_session_manager = sm_a + plugin.register_toolset(toolset_a, "audience-A") + + sm_b = Mock() + toolset_b = Mock(spec=MCPToolset) + toolset_b._mcp_session_manager = sm_b + plugin.register_toolset(toolset_b, "audience-B") + + ic = self._make_invocation_context("sess-iso", headers={"Authorization": "Bearer subj"}) + await plugin.before_run_callback(invocation_context=ic) + + tc = self._make_tool_context(ic) + + # First tool call (audience A) + tool_a = self._make_mcp_tool(session_manager=sm_a) + await plugin.before_tool_callback(tool=tool_a, tool_args={}, tool_context=tc) + assert plugin.token_cache["sess-iso:audience-A"] == "token-for-audience-A-1" + + # Second tool call (audience B) + tool_b = self._make_mcp_tool(session_manager=sm_b) + await plugin.before_tool_callback(tool=tool_b, tool_args={}, tool_context=tc) + assert plugin.token_cache["sess-iso:audience-B"] == "token-for-audience-B-2" + + # Repeat call for audience A should NOT re-exchange (cached) + await plugin.before_tool_callback(tool=tool_a, tool_args={}, tool_context=tc) + assert sts.exchange_token.call_count == 2 # still only 2 calls + + # header_provider returns correct token per audience + ro_ctx = self._make_readonly_context(ic) + assert plugin.header_provider(ro_ctx, audience="audience-A") == {"Authorization": "Bearer token-for-audience-A-1"} + assert plugin.header_provider(ro_ctx, audience="audience-B") == {"Authorization": "Bearer token-for-audience-B-2"} @pytest.mark.asyncio async def test_before_tool_callback_skips_non_mcp_tools(self): - """Case: before_tool_callback ignores non-MCP tools.""" + """Non-MCPTool tools are skipped.""" sts = Mock(spec=ADKSTSIntegration) - sts._actor_token = "actor-token" - sts.exchange_token = AsyncMock(return_value="access-token") plugin = ADKTokenPropagationPlugin(sts) - ic = self._make_invocation_context("sess-7", headers={"Authorization": "Bearer subj"}) + + ic = self._make_invocation_context("sess-skip", headers={"Authorization": "Bearer subj"}) await plugin.before_run_callback(invocation_context=ic) - non_mcp_tool = Mock() # not a McpTool + non_mcp_tool = Mock() # not an MCPTool tc = self._make_tool_context(ic) result = await plugin.before_tool_callback(tool=non_mcp_tool, tool_args={}, tool_context=tc) assert result is None - sts.exchange_token.assert_not_called() + assert plugin.token_cache == {} # no exchange happened @pytest.mark.asyncio async def test_before_tool_callback_no_sts(self): - """Case: before_tool_callback is a no-op without STS integration.""" + """Without STS integration, before_tool_callback is a no-op.""" plugin = ADKTokenPropagationPlugin(sts_integration=None) - ic = self._make_invocation_context("sess-8", headers={"Authorization": "Bearer subj"}) + + ic = self._make_invocation_context("sess-no-sts", headers={"Authorization": "Bearer subj"}) await plugin.before_run_callback(invocation_context=ic) - tool = self._make_mcp_tool() + tool = self._make_mcp_tool(session_manager=Mock()) tc = self._make_tool_context(ic) result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) assert result is None - # token_cache should still have the subject token from before_run - assert plugin.token_cache["sess-8"] == "subj" + + @pytest.mark.asyncio + async def test_sts_token_exchange_failure(self): + """Case: STS exchange raises -> no cache entry, graceful warning.""" + sts = Mock(spec=ADKSTSIntegration) + sts._actor_token = "actor-token" + sts.exchange_token = AsyncMock(side_effect=Exception("boom")) + plugin = ADKTokenPropagationPlugin(sts) + ic = self._make_invocation_context("sess-4", headers={"Authorization": "Bearer original-subject"}) + await plugin.before_run_callback(invocation_context=ic) + + tool = self._make_mcp_tool(session_manager=Mock()) + tc = self._make_tool_context(ic) + with patch("agentsts.adk._base.logger") as mock_logger: + result = await plugin.before_tool_callback(tool=tool, tool_args={}, tool_context=tc) + assert result is None + mock_logger.warning.assert_called_once() + assert "sess-4" not in plugin.token_cache + + # header provider should yield empty dict + ro_ctx = self._make_readonly_context(ic) + assert plugin.header_provider(ro_ctx) == {} def test_header_provider_no_entry(self): """Case: header_provider called with no cached token -> returns empty dict.""" plugin = ADKTokenPropagationPlugin() ic = self._make_invocation_context("sess-5", headers=None) ro_ctx = self._make_readonly_context(ic) - # token_cache intentionally missing key -> KeyError would occur; simulate by setting empty string - plugin.token_cache["sess-5"] = "" # empty token should result in {} assert plugin.header_provider(ro_ctx) == {} @pytest.mark.asyncio - async def test_after_run_callback_removes_token(self): - """Case: after_run_callback removes cached token and subject token.""" + async def test_after_run_callback_removes_all_audience_tokens(self): + """after_run_callback removes all cached tokens for the session across audiences.""" plugin = ADKTokenPropagationPlugin() ic = self._make_invocation_context("sess-6", headers={"Authorization": "Bearer AAA"}) - await plugin.before_run_callback(invocation_context=ic) - assert "sess-6" in plugin.token_cache - assert "sess-6" in plugin._subject_tokens + # Simulate multiple cached tokens for different audiences + plugin.token_cache["sess-6"] = "token-default" + plugin.token_cache["sess-6:aud-A"] = "token-A" + plugin.token_cache["sess-6:aud-B"] = "token-B" + plugin.token_cache["sess-7:aud-A"] = "other-session-token" + plugin._subject_tokens["sess-6"] = "AAA" + await plugin.after_run_callback(invocation_context=ic) + + # All sess-6 entries removed (both bare key and audience-scoped keys) assert "sess-6" not in plugin.token_cache + assert not any(k.startswith("sess-6:") for k in plugin.token_cache) assert "sess-6" not in plugin._subject_tokens + # Other sessions untouched + assert plugin.token_cache["sess-7:aud-A"] == "other-session-token" def test_extract_jwt_from_headers_success(self): """Test successful JWT extraction from headers.""" diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py index 26c4c6df7..b2e8e555a 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py +++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py @@ -25,6 +25,10 @@ class KAgentMcpToolset(McpToolset): implementation may not catch and propagate without enough context. """ + def __init__(self, *, sts_audience: str | None = None, **kwargs): + super().__init__(**kwargs) + self.sts_audience = sts_audience + async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]: try: return await super().get_tools(readonly_context) diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index 26b3df144..077b58870 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -34,7 +34,8 @@ def create_header_provider( allowed_headers: list[str] | None = None, - sts_header_provider: Callable[[Optional[ReadonlyContext]], dict[str, str]] | None = None, + sts_header_provider: Callable[[Optional[ReadonlyContext], Optional[str]], dict[str, str]] | None = None, + sts_audience: str | None = None, ) -> Callable[[Optional[ReadonlyContext]], dict[str, str]] | None: """Create a header provider that combines STS tokens and allowed headers. @@ -47,6 +48,10 @@ def create_header_provider( will take precedence over any matching allowed headers. This is a security measure to prevent request headers from overwriting authentication tokens generated by STS. + Signature: (readonly_context, audience) -> headers + sts_audience: Optional audience value for STS token scoping for this toolset. + Passed to the STS header provider to retrieve the correct + audience-scoped token. Returns: A header provider function, or None if no headers need to be propagated. @@ -71,7 +76,7 @@ def header_provider(readonly_context: Optional[ReadonlyContext]) -> dict[str, st # allowed headers from overwriting authentication tokens) # Use case-insensitive replacement to handle header name case variations if sts_header_provider: - sts_headers = sts_header_provider(readonly_context) + sts_headers = sts_header_provider(readonly_context, sts_audience) if sts_headers: for sts_key, sts_value in sts_headers.items(): # Remove any existing header with same name (case-insensitive) @@ -144,6 +149,7 @@ class HttpMcpServerConfig(BaseModel): tools: list[str] = Field(default_factory=list) allowed_headers: list[str] | None = None # Headers to forward from A2A request to MCP calls require_approval: list[str] | None = None # Tools requiring human approval before execution + sts_audience: str | None = None # Audience for STS token exchange scoping class SseMcpServerConfig(BaseModel): @@ -151,6 +157,7 @@ class SseMcpServerConfig(BaseModel): tools: list[str] = Field(default_factory=list) allowed_headers: list[str] | None = None # Headers to forward from A2A request to MCP calls require_approval: list[str] | None = None # Tools requiring human approval before execution + sts_audience: str | None = None # Audience for STS token exchange scoping class RemoteAgentConfig(BaseModel): @@ -278,14 +285,17 @@ def to_agent(self, name: str, sts_integration: Optional[ADKTokenPropagationPlugi tool_header_provider = create_header_provider( allowed_headers=http_tool.allowed_headers, sts_header_provider=sts_header_provider, + sts_audience=http_tool.sts_audience, ) - tools.append( - KAgentMcpToolset( - connection_params=http_tool.params, - tool_filter=http_tool.tools, - header_provider=tool_header_provider, - ) + toolset = KAgentMcpToolset( + connection_params=http_tool.params, + tool_filter=http_tool.tools, + header_provider=tool_header_provider, + sts_audience=http_tool.sts_audience, ) + tools.append(toolset) + if sts_integration and http_tool.sts_audience: + sts_integration.register_toolset(toolset, http_tool.sts_audience) if http_tool.require_approval: tools_requiring_approval.update(http_tool.require_approval) if self.sse_tools: @@ -294,14 +304,17 @@ def to_agent(self, name: str, sts_integration: Optional[ADKTokenPropagationPlugi tool_header_provider = create_header_provider( allowed_headers=sse_tool.allowed_headers, sts_header_provider=sts_header_provider, + sts_audience=sse_tool.sts_audience, ) - tools.append( - KAgentMcpToolset( - connection_params=sse_tool.params, - tool_filter=sse_tool.tools, - header_provider=tool_header_provider, - ) + toolset = KAgentMcpToolset( + connection_params=sse_tool.params, + tool_filter=sse_tool.tools, + header_provider=tool_header_provider, + sts_audience=sse_tool.sts_audience, ) + tools.append(toolset) + if sts_integration and sse_tool.sts_audience: + sts_integration.register_toolset(toolset, sse_tool.sts_audience) if sse_tool.require_approval: tools_requiring_approval.update(sse_tool.require_approval) if self.remote_agents: diff --git a/python/packages/kagent-adk/tests/unittests/test_header_propagation.py b/python/packages/kagent-adk/tests/unittests/test_header_propagation.py index d7c3930f4..52eeead85 100644 --- a/python/packages/kagent-adk/tests/unittests/test_header_propagation.py +++ b/python/packages/kagent-adk/tests/unittests/test_header_propagation.py @@ -74,7 +74,7 @@ def test_case_insensitive_header_matching(self): def test_combines_sts_and_allowed_headers(self): """Test that STS headers and allowed headers are combined.""" - def mock_sts_provider(ctx): + def mock_sts_provider(ctx, audience=None): return {"Authorization": "Bearer token123"} provider = create_header_provider( @@ -105,7 +105,7 @@ def test_sts_headers_take_precedence_over_allowed_headers(self): for header names to verify case-insensitive handling. """ - def mock_sts_provider(ctx): + def mock_sts_provider(ctx, audience=None): # STS returns "Authorization" with capital A return {"Authorization": "Bearer sts-token"} @@ -135,7 +135,7 @@ def mock_sts_provider(ctx): def test_sts_only_when_no_allowed_headers(self): """Test that only STS headers are returned when no allowed headers.""" - def mock_sts_provider(ctx): + def mock_sts_provider(ctx, audience=None): return {"Authorization": "Bearer token123"} provider = create_header_provider(sts_header_provider=mock_sts_provider) @@ -163,6 +163,43 @@ def test_handles_none_context(self): assert headers == {} + def test_sts_audience_forwarded_to_provider(self): + """Test that sts_audience is forwarded to the STS header provider.""" + received_audiences = [] + + def mock_sts_provider(ctx, audience=None): + received_audiences.append(audience) + return {"Authorization": f"Bearer token-for-{audience}"} + + provider = create_header_provider( + sts_header_provider=mock_sts_provider, + sts_audience="https://my-service.example.com", + ) + assert provider is not None + + context = MockReadonlyContext(state={"headers": {}}) + headers = provider(context) + assert headers == {"Authorization": "Bearer token-for-https://my-service.example.com"} + assert received_audiences == ["https://my-service.example.com"] + + def test_sts_audience_none_when_not_specified(self): + """Test that audience is None when not specified.""" + received_audiences = [] + + def mock_sts_provider(ctx, audience=None): + received_audiences.append(audience) + return {"Authorization": "Bearer token"} + + provider = create_header_provider( + sts_header_provider=mock_sts_provider, + ) + assert provider is not None + + context = MockReadonlyContext(state={"headers": {}}) + provider(context) + assert received_audiences == [None] + + class TestMcpServerConfigAllowedHeaders: """Tests for allowed_headers field in MCP server configs.""" @@ -195,3 +232,33 @@ def test_sse_mcp_config_allowed_headers_default_none(self): params=SseConnectionParams(url="http://localhost:8080"), ) assert config.allowed_headers is None + + def test_http_mcp_config_has_sts_audience(self): + """Test that HttpMcpServerConfig has sts_audience field.""" + config = HttpMcpServerConfig( + params=StreamableHTTPConnectionParams(url="http://localhost:8080"), + sts_audience="https://my-service.example.com", + ) + assert config.sts_audience == "https://my-service.example.com" + + def test_http_mcp_config_sts_audience_default_none(self): + """Test that sts_audience defaults to None.""" + config = HttpMcpServerConfig( + params=StreamableHTTPConnectionParams(url="http://localhost:8080"), + ) + assert config.sts_audience is None + + def test_sse_mcp_config_has_sts_audience(self): + """Test that SseMcpServerConfig has sts_audience field.""" + config = SseMcpServerConfig( + params=SseConnectionParams(url="http://localhost:8080"), + sts_audience="https://my-service.example.com", + ) + assert config.sts_audience == "https://my-service.example.com" + + def test_sse_mcp_config_sts_audience_default_none(self): + """Test that sts_audience defaults to None.""" + config = SseMcpServerConfig( + params=SseConnectionParams(url="http://localhost:8080"), + ) + assert config.sts_audience is None From cc07ac7f77bfa34336b78403bb61355fb8dabaa0 Mon Sep 17 00:00:00 2001 From: Simon Zhu Date: Thu, 5 Mar 2026 03:54:11 -0500 Subject: [PATCH 3/3] fix: remove dead code and make mock STS server audience-aware - Remove unused imports from _base.py (BaseAgent, LlmAgent, AuthCredential, etc.) - Remove add_to_agent method that would overwrite audience-aware closures created by create_header_provider in types.py - Remove LlmAgent import and add_to_agent test code from test_adk_integration.py - Make mock STS server include 'aud' claim in generated tokens when audience is provided, improving E2E test coverage for per-audience token exchange Co-Authored-By: Claude Opus 4.6 Signed-off-by: Simon Zhu --- go/core/test/e2e/mocks/mock_sts_server.go | 8 +++++-- .../agentsts-adk/src/agentsts/adk/_base.py | 23 ------------------- .../tests/test_adk_integration.py | 9 -------- 3 files changed, 6 insertions(+), 34 deletions(-) diff --git a/go/core/test/e2e/mocks/mock_sts_server.go b/go/core/test/e2e/mocks/mock_sts_server.go index fc30b9b66..3c8724311 100644 --- a/go/core/test/e2e/mocks/mock_sts_server.go +++ b/go/core/test/e2e/mocks/mock_sts_server.go @@ -178,7 +178,7 @@ func (m *MockSTSServer) handleTokenExchange(w http.ResponseWriter, r *http.Reque return } - accessToken, err := m.generateMockAccessToken(req.SubjectToken) + accessToken, err := m.generateMockAccessToken(req.SubjectToken, req.Audience) if err != nil { http.Error(w, fmt.Sprintf("Error generating mock access token: %v", err), http.StatusBadRequest) return @@ -200,7 +200,7 @@ func (m *MockSTSServer) handleTokenExchange(w http.ResponseWriter, r *http.Reque m.requests = append(m.requests, req) } -func (m *MockSTSServer) generateMockAccessToken(subjectToken string) (string, error) { +func (m *MockSTSServer) generateMockAccessToken(subjectToken string, audience string) (string, error) { // Try to parse JWT token to extract subject claim subject, err := extractSubjectFromJWT(subjectToken) if err != nil { @@ -219,6 +219,10 @@ func (m *MockSTSServer) generateMockAccessToken(subjectToken string) (string, er "iss": "mock-sts-server", } + if audience != "" { + tokenData["aud"] = audience + } + // For testing purposes, we'll return a simple JSON string // In a real implementation, this would be a signed JWT tokenBytes, err := json.Marshal(tokenData) diff --git a/python/packages/agentsts-adk/src/agentsts/adk/_base.py b/python/packages/agentsts-adk/src/agentsts/adk/_base.py index d38b0a30a..4bfefde58 100644 --- a/python/packages/agentsts-adk/src/agentsts/adk/_base.py +++ b/python/packages/agentsts-adk/src/agentsts/adk/_base.py @@ -3,15 +3,9 @@ import logging from typing import Any, Dict, Optional -from google.adk.agents import BaseAgent, LlmAgent from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.readonly_context import ReadonlyContext -from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, HttpAuth, HttpCredentials -from google.adk.events.event import Event from google.adk.plugins.base_plugin import BasePlugin -from google.adk.runners import Runner -from google.adk.sessions import BaseSessionService -from google.adk.sessions.session import Session from google.adk.tools.base_tool import BaseTool from google.adk.tools.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset @@ -67,23 +61,6 @@ def register_toolset(self, toolset: MCPToolset, audience: Optional[str]) -> None if audience and hasattr(toolset, "_mcp_session_manager"): self._audience_map[id(toolset._mcp_session_manager)] = audience - def add_to_agent(self, agent: BaseAgent): - """ - Add the plugin to an ADK LLM agent by updating its MCP toolset - Call this once when setting up the agent; do not call it at runtime. - """ - if not isinstance(agent, LlmAgent): - return - - if not agent.tools: - return - - for tool in agent.tools: - if isinstance(tool, MCPToolset): - mcp_toolset = tool - mcp_toolset._header_provider = self.header_provider - logger.debug("Updated tool connection params to include access token from STS server") - def header_provider( self, readonly_context: Optional[ReadonlyContext], audience: Optional[str] = None ) -> Dict[str, str]: diff --git a/python/packages/agentsts-adk/tests/test_adk_integration.py b/python/packages/agentsts-adk/tests/test_adk_integration.py index f7f5f187e..513ddfbde 100644 --- a/python/packages/agentsts-adk/tests/test_adk_integration.py +++ b/python/packages/agentsts-adk/tests/test_adk_integration.py @@ -3,7 +3,6 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from google.adk.agents import LlmAgent from google.adk.tools.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset @@ -98,14 +97,6 @@ async def test_downstream_token_propagation_without_sts(self): # Without STS, subject token is directly cached under session key assert plugin.token_cache["sess-2"] == "subj-token-123" - # propagate toolset - mcp_toolset = Mock(spec=MCPToolset) - agent = Mock(spec=LlmAgent) - agent.tools = [mcp_toolset] - plugin.add_to_agent(agent) - # The toolset._header_provider should be callable - assert callable(mcp_toolset._header_provider) - # header provider should return subject token ro_ctx = self._make_readonly_context(ic) headers = plugin.header_provider(ro_ctx)