diff --git a/go/api/database/client.go b/go/api/database/client.go index ea59cc936..9eb5f5564 100644 --- a/go/api/database/client.go +++ b/go/api/database/client.go @@ -43,6 +43,7 @@ type Client interface { GetTool(name string) (*Tool, error) GetToolServer(name string) (*ToolServer, error) GetPushNotification(taskID string, configID string) (*protocol.TaskPushNotificationConfig, error) + GetSubAgentSession(sessionID, toolCallID string) (string, error) // List methods ListTools() ([]Tool, error) diff --git a/go/core/internal/database/client.go b/go/core/internal/database/client.go index d7f7d46fb..8dc3bbccc 100644 --- a/go/core/internal/database/client.go +++ b/go/core/internal/database/client.go @@ -113,6 +113,28 @@ func (c *clientImpl) GetSession(sessionID string, userID string) (*dbpkg.Session Clause{Key: "user_id", Value: userID}) } +func (c *clientImpl) GetSubAgentSession(sessionID, toolCallID string) (string, error) { + var sessionName string + query := c.db.Table("task").Select("session_id") + + if c.db.Name() == "sqlite" { + query = query.Where("json_extract(data, '$.metadata.kagent_caller_session_id') = ?", sessionID). + Where("json_extract(data, '$.metadata.kagent_caller_tool_call_id') = ?", toolCallID) + } else { + query = query.Where("(data::json -> 'metadata' ->> 'kagent_caller_session_id') = ?", sessionID). + Where("(data::json -> 'metadata' ->> 'kagent_caller_tool_call_id') = ?", toolCallID) + } + + err := query.Order("created_at DESC").Limit(1).Pluck("session_id", &sessionName).Error + if err != nil { + return "", err + } + if sessionName == "" { + return "", errors.New("sub-agent session not found") + } + return sessionName, nil +} + // GetAgent retrieves an agent by name and user ID func (c *clientImpl) GetAgent(agentID string) (*dbpkg.Agent, error) { return get[dbpkg.Agent](c.db, Clause{Key: "id", Value: agentID}) diff --git a/go/core/internal/database/client_test.go b/go/core/internal/database/client_test.go index ca77cba36..13ae19136 100644 --- a/go/core/internal/database/client_test.go +++ b/go/core/internal/database/client_test.go @@ -572,3 +572,31 @@ func TestPruneExpiredMemories(t *testing.T) { assert.Contains(t, ids, "prune-hot", "Expired popular memory should have TTL extended and be retained") assert.Contains(t, ids, "prune-live", "Non-expired memory should be retained") } + +func TestGetSubAgentSession(t *testing.T) { + db := setupTestDB(t) + client := NewClient(db) + + sessionID := "parent-session" + toolCallID := "call-1" + subSessionID := "sub-session-1" + + taskData := `{"metadata": {"kagent_caller_session_id": "parent-session", "kagent_caller_tool_call_id": "call-1"}}` + + err := client.(*clientImpl).db.Create(&dbpkg.Task{ + ID: "task-1", + SessionID: subSessionID, + Data: taskData, + CreatedAt: time.Now(), + }).Error + require.NoError(t, err) + + // Test success + foundSessionID, err := client.GetSubAgentSession(sessionID, toolCallID) + require.NoError(t, err) + assert.Equal(t, subSessionID, foundSessionID) + + // Test not found (wrong session id) + _, err = client.GetSubAgentSession("wrong-session", toolCallID) + assert.Error(t, err) +} diff --git a/go/core/internal/database/fake/client.go b/go/core/internal/database/fake/client.go index d40231bf5..e72acaa41 100644 --- a/go/core/internal/database/fake/client.go +++ b/go/core/internal/database/fake/client.go @@ -251,7 +251,47 @@ func (c *InMemoryFakeClient) GetSession(sessionID string, userID string) (*datab return session, nil } -// GetAgent retrieves an agent by name +// GetSubAgentSession finds a subagent session based on parent session ID and tool call ID +func (c *InMemoryFakeClient) GetSubAgentSession(sessionID, toolCallID string) (string, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + var matchedTasks []*database.Task + + for _, task := range c.tasks { + var taskData protocol.Task + if err := json.Unmarshal([]byte(task.Data), &taskData); err != nil { + continue + } + + if taskData.Metadata == nil { + continue + } + + if taskData.Metadata["kagent_caller_session_id"] == sessionID && + taskData.Metadata["kagent_caller_tool_call_id"] == toolCallID { + matchedTasks = append(matchedTasks, task) + } + } + + if len(matchedTasks) == 0 { + return "", gorm.ErrRecordNotFound + } + + // Sort by created_at DESC to get the latest + slices.SortStableFunc(matchedTasks, func(a, b *database.Task) int { + if a.CreatedAt.After(b.CreatedAt) { + return -1 + } + if a.CreatedAt.Before(b.CreatedAt) { + return 1 + } + return 0 + }) + + return matchedTasks[0].SessionID, nil +} + func (c *InMemoryFakeClient) GetAgent(agentName string) (*database.Agent, error) { c.mu.RLock() defer c.mu.RUnlock() diff --git a/go/core/internal/httpserver/handlers/sessions.go b/go/core/internal/httpserver/handlers/sessions.go index 51ff3a807..46840e2cc 100644 --- a/go/core/internal/httpserver/handlers/sessions.go +++ b/go/core/internal/httpserver/handlers/sessions.go @@ -395,6 +395,51 @@ func (h *SessionsHandler) HandleAddEventToSession(w ErrorResponseWriter, r *http RespondWithJSON(w, http.StatusCreated, data) } +func (h *SessionsHandler) HandleGetSubAgentSession(w ErrorResponseWriter, r *http.Request) { + log := ctrllog.FromContext(r.Context()).WithName("sessions-handler").WithValues("operation", "get-sub-agent-session") + + parentSessionID, err := GetPathParam(r, "session_id") + if err != nil { + w.RespondWithError(errors.NewBadRequestError("Missing required path parameter: session_id", err)) + return + } + toolCallID, err := GetPathParam(r, "tool_call_id") + if err != nil { + w.RespondWithError(errors.NewBadRequestError("Missing required path parameter: tool_call_id", err)) + return + } + + userID, err := getUserIDOrAgentUser(r) + if err != nil { + w.RespondWithError(errors.NewBadRequestError("Failed to get user ID", err)) + return + } + + log = log.WithValues("parentSessionID", parentSessionID, "toolCallID", toolCallID, "userID", userID) + + sessionID, err := h.DatabaseService.GetSubAgentSession(parentSessionID, toolCallID) + if err != nil { + w.RespondWithError(errors.NewNotFoundError("Sub-agent session not found", err)) + return + } + + // Now get the session details as standard session response + session, err := h.DatabaseService.GetSession(sessionID, userID) + if err != nil { + w.RespondWithError(errors.NewNotFoundError("Session not found", err)) + return + } + + // We don't return events here for the lookup as it's a lookup, but we can if we want to follow HandleGetSession + // For now let's just return the session to be consistent with creating a session + log.Info("Successfully found sub-agent session", "sessionID", sessionID) + data := api.NewResponse(SessionResponse{ + Session: session, + Events: nil, + }, "Successfully found sub-agent session", false) + RespondWithJSON(w, http.StatusOK, data) +} + func getUserID(r *http.Request) (string, error) { log := ctrllog.Log.WithName("http-helpers") diff --git a/go/core/internal/httpserver/handlers/sessions_test.go b/go/core/internal/httpserver/handlers/sessions_test.go index 65f57c700..c4280f523 100644 --- a/go/core/internal/httpserver/handlers/sessions_test.go +++ b/go/core/internal/httpserver/handlers/sessions_test.go @@ -517,9 +517,72 @@ func TestSessionsHandler(t *testing.T) { req = mux.SetURLVars(req, map[string]string{"session_id": sessionID}) handler.HandleListTasksForSession(responseRecorder, req) - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) assert.NotNil(t, responseRecorder.errorReceived) }) }) + + t.Run("HandleGetSubAgentSession", func(t *testing.T) { + t.Run("Success", func(t *testing.T) { + handler, dbClient, responseRecorder := setupHandler() + sessionID := "parent-session" + toolCallID := "call-1" + subSessionID := "sub-session-1" + userID := "test-user" + + // Create sub-session + dbClient.StoreSession(&database.Session{ + ID: subSessionID, + UserID: userID, + }) + + // Add task with caller metadata + dbClient.AddTask(&database.Task{ + ID: "task-1", + SessionID: subSessionID, + Data: `{"id": "task-1", "metadata": {"kagent_caller_session_id": "parent-session", "kagent_caller_tool_call_id": "call-1"}}`, + CreatedAt: time.Now(), + }) + + req := httptest.NewRequest("GET", "/api/sessions/"+sessionID+"/subagentsessions/"+toolCallID, nil) + req = mux.SetURLVars(req, map[string]string{ + "session_id": sessionID, + "tool_call_id": toolCallID, + }) + req = setUser(req, userID) + handler.HandleGetSubAgentSession(responseRecorder, req) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + var response api.StandardResponse[handlers.SessionResponse] + err := json.Unmarshal(responseRecorder.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, subSessionID, response.Data.Session.ID) + }) + + t.Run("MissingParams", func(t *testing.T) { + handler, _, responseRecorder := setupHandler() + + req := httptest.NewRequest("GET", "/api/sessions/s/toolcalls/c/subagentsession", nil) + // No URL vars set - will fail on first GetPathParam + req = setUser(req, "test-user") + handler.HandleGetSubAgentSession(responseRecorder, req) + + assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) + }) + + t.Run("NotFound", func(t *testing.T) { + handler, _, responseRecorder := setupHandler() + + req := httptest.NewRequest("GET", "/api/sessions/c/toolcalls/d/subagentsession", nil) + req = mux.SetURLVars(req, map[string]string{ + "session_id": "c", + "tool_call_id": "d", + }) + req = setUser(req, "test-user") + handler.HandleGetSubAgentSession(responseRecorder, req) + + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) + }) + }) } diff --git a/go/core/internal/httpserver/server.go b/go/core/internal/httpserver/server.go index 787785789..b989136b3 100644 --- a/go/core/internal/httpserver/server.go +++ b/go/core/internal/httpserver/server.go @@ -214,6 +214,7 @@ func (s *HTTPServer) setupRoutes() { // Sessions - using database handlers s.router.HandleFunc(APIPathSessions, adaptHandler(s.handlers.Sessions.HandleListSessions)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions, adaptHandler(s.handlers.Sessions.HandleCreateSession)).Methods(http.MethodPost) + s.router.HandleFunc(APIPathSessions+"/{session_id}/subagentsessions/{tool_call_id}", adaptHandler(s.handlers.Sessions.HandleGetSubAgentSession)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions+"/agent/{namespace}/{name}", adaptHandler(s.handlers.Sessions.HandleGetSessionsForAgent)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandleGetSession)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions+"/{session_id}/tasks", adaptHandler(s.handlers.Sessions.HandleListTasksForSession)).Methods(http.MethodGet) diff --git a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py index a4ac5e280..280942ca5 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py +++ b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py @@ -40,6 +40,7 @@ from kagent.core.a2a import ( KAGENT_HITL_DECISION_TYPE_APPROVE, KAGENT_HITL_DECISION_TYPE_BATCH, + KAGENT_METADATA_KEY_PREFIX, TaskResultAggregator, extract_ask_user_answers_from_message, extract_batch_decisions_from_message, @@ -475,6 +476,12 @@ async def _handle_request( runner: Runner, run_args: dict[str, Any], ): + # If the caller propagated a user_id via A2A request metadata, + # use it so the child session is owned by the same identity. + user_id_key = get_kagent_metadata_key("user_id") + if user_id_key in context.metadata: + run_args["user_id"] = context.metadata[user_id_key] + # ensure the session exists session = await self._prepare_session(context, run_args, runner) @@ -515,6 +522,16 @@ async def _handle_request( get_kagent_metadata_key("session_id"): run_args["session_id"], } + # Include caller metadata from A2A request metadata if present. + # When a parent agent calls a child via RemoteA2aAgent, the + # a2a_request_meta_provider injects kagent_* keys into the A2A + # wire protocol's MessageSendParams.metadata which arrives here + # via context.metadata. + if context.metadata: + for key, value in context.metadata.items(): + if key.startswith(KAGENT_METADATA_KEY_PREFIX): + run_metadata[key] = value + # publish the task working event await event_queue.enqueue_event( TaskStatusUpdateEvent( diff --git a/python/packages/kagent-adk/src/kagent/adk/_session_linking_plugin.py b/python/packages/kagent-adk/src/kagent/adk/_session_linking_plugin.py new file mode 100644 index 000000000..e9d641ee5 --- /dev/null +++ b/python/packages/kagent-adk/src/kagent/adk/_session_linking_plugin.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import logging +from typing import Any + +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools.agent_tool import AgentTool +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext + +from kagent.core.a2a import ( + KAGENT_METADATA_KEY_CALLER_APP_NAME, + KAGENT_METADATA_KEY_CALLER_SESSION_ID, + KAGENT_METADATA_KEY_CALLER_TOOL_CALL_ID, + get_kagent_metadata_key, +) + +logger = logging.getLogger("kagent_adk." + __name__) + + +class SessionLinkingPlugin(BasePlugin): + """A plugin that propagates parent session metadata to child agents. + + This plugin intercepts tool calls to AgentTool and injects the parent's + session_id and tool_call_id into the session state. + Because AgentTool clones the current session state when creating a child + session, these values are automatically propagated to the sub-agent and + eventually to the remote A2A task metadata. + """ + + def __init__(self, name: str = "session_linking"): + super().__init__(name) + + async def before_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + ) -> None: + """Inject parent metadata into the state before an AgentTool runs.""" + # We only need to do this for tools that spin up sub-sessions (AgentTools) + if isinstance(tool, AgentTool): + invocation_context = tool_context._invocation_context + if not invocation_context: + logger.warning( + "No invocation context found in tool_context for tool %s. Cannot link sessions.", + tool.name, + ) + return + + # Parent metadata keys. We use a prefix that is NOT filtered out by + # AgentTool (which typically filters '_adk'). + parent_metadata = { + get_kagent_metadata_key(KAGENT_METADATA_KEY_CALLER_APP_NAME): invocation_context.app_name, + get_kagent_metadata_key(KAGENT_METADATA_KEY_CALLER_SESSION_ID): invocation_context.session.id, + get_kagent_metadata_key(KAGENT_METADATA_KEY_CALLER_TOOL_CALL_ID): tool_context.function_call_id, + } + + # Update the current session state. This state is then cloned by + # the AgentTool.run_async method to create the child session. + tool_context.state.update(parent_metadata) diff --git a/python/packages/kagent-adk/src/kagent/adk/cli.py b/python/packages/kagent-adk/src/kagent/adk/cli.py index 838d70134..558e44fa7 100644 --- a/python/packages/kagent-adk/src/kagent/adk/cli.py +++ b/python/packages/kagent-adk/src/kagent/adk/cli.py @@ -15,6 +15,7 @@ from kagent.core import KAgentConfig, configure_logging, configure_tracing from . import AgentConfig, KAgentApp +from ._session_linking_plugin import SessionLinkingPlugin from .tools import add_skills_tool_to_agent logger = logging.getLogger(__name__) @@ -60,10 +61,10 @@ def static( with open(os.path.join(filepath, "agent-card.json"), "r") as f: agent_card = json.load(f) agent_card = AgentCard.model_validate(agent_card) - plugins = None + plugins = [SessionLinkingPlugin()] sts_integration = create_sts_integration() if sts_integration: - plugins = [sts_integration] + plugins.append(sts_integration) if agent_config.model.api_key_passthrough: from ._llm_passthrough_plugin import LLMPassthroughPlugin @@ -136,10 +137,10 @@ def run( ): app_cfg = KAgentConfig() - plugins = None + plugins = [SessionLinkingPlugin()] sts_integration = create_sts_integration() if sts_integration: - plugins = [sts_integration] + plugins.append(sts_integration) agent_loader = AgentLoader(agents_dir=working_dir) @@ -206,10 +207,10 @@ def root_agent_factory() -> BaseAgent: async def test_agent(agent_config: AgentConfig, agent_card: AgentCard, task: str): app_cfg = KAgentConfig(url="http://fake-url.example.com", name="test-agent", namespace="kagent") - plugins = None + plugins = [SessionLinkingPlugin()] sts_integration = create_sts_integration() if sts_integration: - plugins = [sts_integration] + plugins.append(sts_integration) def root_agent_factory() -> BaseAgent: root_agent = agent_config.to_agent(app_cfg.name, sts_integration) diff --git a/python/packages/kagent-adk/src/kagent/adk/converters/event_converter.py b/python/packages/kagent-adk/src/kagent/adk/converters/event_converter.py index 0b0d4425f..c2eae8b1e 100644 --- a/python/packages/kagent-adk/src/kagent/adk/converters/event_converter.py +++ b/python/packages/kagent-adk/src/kagent/adk/converters/event_converter.py @@ -17,6 +17,7 @@ A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, A2A_DATA_PART_METADATA_TYPE_KEY, + KAGENT_METADATA_KEY_PREFIX, get_kagent_metadata_key, ) @@ -92,6 +93,12 @@ def _get_context_metadata(event: Event, invocation_context: InvocationContext) - if field_value is not None: metadata[get_kagent_metadata_key(field_name)] = _serialize_metadata_value(field_value) + # Include caller metadata from session state if present + if invocation_context.session and invocation_context.session.state: + for key, value in invocation_context.session.state.items(): + if key.startswith(KAGENT_METADATA_KEY_PREFIX): + metadata[key] = _serialize_metadata_value(value) + return metadata except Exception as e: diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index 26b3df144..d96e802c7 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -2,9 +2,11 @@ from typing import Any, Callable, Literal, Optional, Union import httpx +from a2a.types import Message as A2AMessage from agentsts.adk import ADKTokenPropagationPlugin from google.adk.agents import Agent from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import ToolUnion from google.adk.agents.readonly_context import ReadonlyContext from google.adk.agents.remote_a2a_agent import AGENT_CARD_WELL_KNOWN_PATH, DEFAULT_TIMEOUT, RemoteA2aAgent @@ -19,6 +21,7 @@ from kagent.adk.models._litellm import KAgentLiteLlm from kagent.adk.sandbox_code_executer import SandboxedLocalCodeExecutor from kagent.adk.tools.ask_user_tool import AskUserTool +from kagent.core.a2a import KAGENT_METADATA_KEY_PREFIX, get_kagent_metadata_key from .models import AzureOpenAI as OpenAIAzure from .models import OpenAI as OpenAINative @@ -32,6 +35,29 @@ HEADERS_STATE_KEY = "headers" +def _kagent_metadata_provider(ctx: InvocationContext, _message: A2AMessage) -> dict[str, Any]: + """Extract kagent metadata from session state for A2A request propagation. + + This provider is used by RemoteA2aAgent to inject caller metadata + (e.g. kagent_caller_session_id, kagent_caller_tool_call_id) into the + A2A wire protocol so the receiving child agent can store them in task + metadata for parent-child session correlation. + + It also propagates the parent's user_id so that child agent sessions + are created under the same identity, so that the user can view those + sessions (kagent requires that users can only see their own sessions). + """ + metadata: dict[str, Any] = {} + if ctx.session: + # Propagate the parent's user_id to the child agent session. + metadata[get_kagent_metadata_key("user_id")] = ctx.session.user_id + if ctx.session.state: + for key, value in ctx.session.state.items(): + if key.startswith(KAGENT_METADATA_KEY_PREFIX): + metadata[key] = value + return metadata + + def create_header_provider( allowed_headers: list[str] | None = None, sts_header_provider: Callable[[Optional[ReadonlyContext]], dict[str, str]] | None = None, @@ -371,6 +397,7 @@ async def rewrite_url_to_proxy(request: httpx.Request) -> None: agent_card=f"{remote_agent.url}{AGENT_CARD_WELL_KNOWN_PATH}", description=remote_agent.description, httpx_client=client, + a2a_request_meta_provider=_kagent_metadata_provider, ) tools.append(AgentTool(agent=remote_a2a_agent)) diff --git a/python/packages/kagent-adk/tests/unittests/converters/test_event_converter.py b/python/packages/kagent-adk/tests/unittests/converters/test_event_converter.py index 4279d5914..ee68d0476 100644 --- a/python/packages/kagent-adk/tests/unittests/converters/test_event_converter.py +++ b/python/packages/kagent-adk/tests/unittests/converters/test_event_converter.py @@ -14,6 +14,7 @@ def _create_mock_invocation_context(): context.app_name = "test_app" context.user_id = "test_user" context.session.id = "test_session" + context.session.state = {} return context diff --git a/python/packages/kagent-core/src/kagent/core/a2a/__init__.py b/python/packages/kagent-core/src/kagent/core/a2a/__init__.py index 41846ea49..74b1b4b99 100644 --- a/python/packages/kagent-core/src/kagent/core/a2a/__init__.py +++ b/python/packages/kagent-core/src/kagent/core/a2a/__init__.py @@ -15,6 +15,10 @@ KAGENT_HITL_DECISION_TYPE_REJECT, KAGENT_HITL_DECISIONS_KEY, KAGENT_HITL_REJECTION_REASONS_KEY, + KAGENT_METADATA_KEY_CALLER_APP_NAME, + KAGENT_METADATA_KEY_CALLER_SESSION_ID, + KAGENT_METADATA_KEY_CALLER_TOOL_CALL_ID, + KAGENT_METADATA_KEY_PREFIX, get_kagent_metadata_key, read_metadata_value, ) @@ -36,6 +40,9 @@ "get_kagent_metadata_key", "read_metadata_value", "ADK_METADATA_KEY_PREFIX", + "KAGENT_METADATA_KEY_PREFIX", + "KAGENT_METADATA_KEY_CALLER_SESSION_ID", + "KAGENT_METADATA_KEY_CALLER_TOOL_CALL_ID", "A2A_DATA_PART_METADATA_TYPE_KEY", "A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY", "A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", diff --git a/python/packages/kagent-core/src/kagent/core/a2a/_consts.py b/python/packages/kagent-core/src/kagent/core/a2a/_consts.py index 976ca025c..7b5f1beb2 100644 --- a/python/packages/kagent-core/src/kagent/core/a2a/_consts.py +++ b/python/packages/kagent-core/src/kagent/core/a2a/_consts.py @@ -13,6 +13,12 @@ ADK_METADATA_KEY_PREFIX = "adk_" +# Caller metadata constants +KAGENT_METADATA_KEY_CALLER_APP_NAME = "caller_app_name" +KAGENT_METADATA_KEY_CALLER_SESSION_ID = "caller_session_id" +KAGENT_METADATA_KEY_CALLER_TOOL_CALL_ID = "caller_tool_call_id" + + def get_kagent_metadata_key(key: str) -> str: """Gets the A2A event metadata key for the given key. diff --git a/ui/src/app/actions/sessions.ts b/ui/src/app/actions/sessions.ts index b1fa136d5..4e0a9d6c4 100644 --- a/ui/src/app/actions/sessions.ts +++ b/ui/src/app/actions/sessions.ts @@ -108,3 +108,21 @@ export async function checkSessionExists(sessionId: string): Promise(error, "Error checking session"); } } + +/** + * Gets a sub-agent session by parent session and tool call ID + * @param sessionId The parent session ID + * @param toolCallId The tool call ID + * @returns A promise with the session data + */ +export async function getSubAgentSession(sessionId: string, toolCallId: string): Promise> { + try { + const response = await fetchApi>(`/sessions/${sessionId}/subagentsessions/${toolCallId}`); + if (!response || !response.data) { + throw new Error("Failed to get sub-agent session"); + } + return { message: "Sub-agent session fetched successfully", data: response.data.session }; + } catch (error) { + return createErrorResponse(error, "Error getting sub-agent session"); + } +} diff --git a/ui/src/components/chat/AgentCallDisplay.tsx b/ui/src/components/chat/AgentCallDisplay.tsx index 6ce80d14c..6cc0efb93 100644 --- a/ui/src/components/chat/AgentCallDisplay.tsx +++ b/ui/src/components/chat/AgentCallDisplay.tsx @@ -1,9 +1,12 @@ -import { useMemo, useState } from "react"; +import { useCallback, useMemo, useState } from "react"; import { FunctionCall } from "@/types"; import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card"; import { convertToUserFriendlyName } from "@/lib/utils"; -import { ChevronDown, ChevronUp, MessageSquare, Loader2, AlertCircle, CheckCircle } from "lucide-react"; +import { ChevronDown, ChevronUp, MessageSquare, Loader2, AlertCircle, CheckCircle, ExternalLink } from "lucide-react"; import KagentLogo from "../kagent-logo"; +import { getSubAgentSession } from "@/app/actions/sessions"; +import { toast } from "sonner"; +import { Button } from "@/components/ui/button"; export type AgentCallStatus = "requested" | "executing" | "completed"; @@ -15,15 +18,39 @@ interface AgentCallDisplayProps { }; status?: AgentCallStatus; isError?: boolean; + sessionId?: string; } -const AgentCallDisplay = ({ call, result, status = "requested", isError = false }: AgentCallDisplayProps) => { +const AgentCallDisplay = ({ call, result, status = "requested", isError = false, sessionId }: AgentCallDisplayProps) => { const [areInputsExpanded, setAreInputsExpanded] = useState(false); const [areResultsExpanded, setAreResultsExpanded] = useState(false); const agentDisplay = useMemo(() => convertToUserFriendlyName(call.name), [call.name]); const hasResult = result !== undefined; + const callId = call.id; + + const onOpenSubAgentSession = useCallback(async () => { + try { + // Theoretically we must have a session ID to be rendered, check anyway + if(!sessionId) { + toast.error('No session ID, cannot lookup sub-agent session'); + return; + } + + const response = await getSubAgentSession(sessionId, callId); + if (response.data && response.data.id) { + const subagentSessionUrl = `/agents/${agentDisplay}/chat/${response.data.id}`; + window.open(subagentSessionUrl, '_blank'); + } else { + toast.error('Sub-agent session not found'); + } + } catch (error) { + console.error('Error opening subagent session:', error); + toast.error('Failed to open sub-agent session'); + } + }, [agentDisplay, sessionId, callId]); + const getStatusDisplay = () => { if (isError && status === "executing") { return ( @@ -80,6 +107,14 @@ const AgentCallDisplay = ({ call, result, status = "requested", isError = false
{getStatusDisplay()} +
diff --git a/ui/src/components/chat/ChatInterface.tsx b/ui/src/components/chat/ChatInterface.tsx index cbf01b287..930c303ad 100644 --- a/ui/src/components/chat/ChatInterface.tsx +++ b/ui/src/components/chat/ChatInterface.tsx @@ -648,6 +648,7 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se onReject={handleReject} onAskUserSubmit={handleAskUserSubmit} pendingDecisions={pendingDecisions} + sessionId={sessionId!} /> })} @@ -665,6 +666,7 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se onReject={handleReject} onAskUserSubmit={handleAskUserSubmit} pendingDecisions={pendingDecisions} + sessionId={session?.id || sessionId} /> })} diff --git a/ui/src/components/chat/ChatMessage.tsx b/ui/src/components/chat/ChatMessage.tsx index c73883bf2..eb744036f 100644 --- a/ui/src/components/chat/ChatMessage.tsx +++ b/ui/src/components/chat/ChatMessage.tsx @@ -21,9 +21,10 @@ interface ChatMessageProps { onReject?: (toolCallId: string, reason?: string) => void; onAskUserSubmit?: (answers: Array<{ answer: string[] }>) => void; pendingDecisions?: Record; + sessionId?: string; } -export default function ChatMessage({ message, allMessages, agentContext, onApprove, onReject, onAskUserSubmit, pendingDecisions }: ChatMessageProps) { +export default function ChatMessage({ message, allMessages, agentContext, onApprove, onReject, onAskUserSubmit, pendingDecisions, sessionId }: ChatMessageProps) { const [feedbackDialogOpen, setFeedbackDialogOpen] = useState(false); const [isPositiveFeedback, setIsPositiveFeedback] = useState(true); @@ -120,6 +121,7 @@ export default function ChatMessage({ message, allMessages, agentContext, onAppr onApprove={onApprove} onReject={onReject} pendingDecisions={pendingDecisions} + sessionId={sessionId} />; } @@ -135,7 +137,7 @@ export default function ChatMessage({ message, allMessages, agentContext, onAppr }); if (hasToolCalls) { - return ; + return ; } return null; } diff --git a/ui/src/components/chat/ToolCallDisplay.tsx b/ui/src/components/chat/ToolCallDisplay.tsx index f16440c81..759e641a9 100644 --- a/ui/src/components/chat/ToolCallDisplay.tsx +++ b/ui/src/components/chat/ToolCallDisplay.tsx @@ -12,6 +12,7 @@ interface ToolCallDisplayProps { onApprove?: (toolCallId: string) => void; onReject?: (toolCallId: string, reason?: string) => void; pendingDecisions?: Record; + sessionId?: string; } interface ToolCallState { @@ -162,7 +163,7 @@ const extractToolCallResults = (message: Message): ProcessedToolResultData[] => }; -const ToolCallDisplay = ({ currentMessage, allMessages, onApprove, onReject, pendingDecisions }: ToolCallDisplayProps) => { +const ToolCallDisplay = ({ currentMessage, allMessages, onApprove, onReject, pendingDecisions, sessionId }: ToolCallDisplayProps) => { // Determine which tool call IDs this component instance "owns" by finding, // for each ID introduced by currentMessage, whether currentMessage is the // FIRST message in allMessages that introduces that ID. @@ -184,20 +185,20 @@ const ToolCallDisplay = ({ currentMessage, allMessages, onApprove, onReject, pen } const ownedIds = new Set(currentRequests.map(r => r.id).filter(id => id !== undefined) as string[]); - + // Scan backwards from our index to see if any earlier message already has these IDs. // This avoids a full O(N) scan per component render by aborting early. for (let i = currentIndex - 1; i >= 0; i--) { const msg = allMessages[i]; if (!isToolCallRequestMessage(msg)) continue; - + const prevRequests = extractToolCallRequests(msg); for (const pr of prevRequests) { if (pr.id) { ownedIds.delete(pr.id); } } - + if (ownedIds.size === 0) break; // Early exit if all IDs were claimed by earlier messages } return ownedIds; @@ -314,6 +315,7 @@ const ToolCallDisplay = ({ currentMessage, allMessages, onApprove, onReject, pen result={toolCall.result} status={effectiveStatus === "pending_approval" ? "requested" : effectiveStatus as AgentCallStatus} isError={toolCall.result?.is_error} + sessionId={sessionId} /> ) : (