Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go/api/database/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Client interface {
GetTool(name string) (*Tool, error)
GetToolServer(name string) (*ToolServer, error)
GetPushNotification(taskID string, configID string) (*protocol.TaskPushNotificationConfig, error)
GetSubAgentSession(sessionID, toolCallID string) (string, error)

// List methods
ListTools() ([]Tool, error)
Expand Down
22 changes: 22 additions & 0 deletions go/core/internal/database/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,28 @@ func (c *clientImpl) GetSession(sessionID string, userID string) (*dbpkg.Session
Clause{Key: "user_id", Value: userID})
}

func (c *clientImpl) GetSubAgentSession(sessionID, toolCallID string) (string, error) {
var sessionName string
query := c.db.Table("task").Select("session_id")

if c.db.Name() == "sqlite" {
query = query.Where("json_extract(data, '$.metadata.kagent_caller_session_id') = ?", sessionID).
Where("json_extract(data, '$.metadata.kagent_caller_tool_call_id') = ?", toolCallID)
} else {
query = query.Where("(data::json -> 'metadata' ->> 'kagent_caller_session_id') = ?", sessionID).
Where("(data::json -> 'metadata' ->> 'kagent_caller_tool_call_id') = ?", toolCallID)
}

err := query.Order("created_at DESC").Limit(1).Pluck("session_id", &sessionName).Error
if err != nil {
return "", err
}
if sessionName == "" {
return "", errors.New("sub-agent session not found")
}
return sessionName, nil
}

// GetAgent retrieves an agent by name and user ID
func (c *clientImpl) GetAgent(agentID string) (*dbpkg.Agent, error) {
return get[dbpkg.Agent](c.db, Clause{Key: "id", Value: agentID})
Expand Down
28 changes: 28 additions & 0 deletions go/core/internal/database/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,3 +572,31 @@ func TestPruneExpiredMemories(t *testing.T) {
assert.Contains(t, ids, "prune-hot", "Expired popular memory should have TTL extended and be retained")
assert.Contains(t, ids, "prune-live", "Non-expired memory should be retained")
}

func TestGetSubAgentSession(t *testing.T) {
db := setupTestDB(t)
client := NewClient(db)

sessionID := "parent-session"
toolCallID := "call-1"
subSessionID := "sub-session-1"

taskData := `{"metadata": {"kagent_caller_session_id": "parent-session", "kagent_caller_tool_call_id": "call-1"}}`

err := client.(*clientImpl).db.Create(&dbpkg.Task{
ID: "task-1",
SessionID: subSessionID,
Data: taskData,
CreatedAt: time.Now(),
}).Error
require.NoError(t, err)

// Test success
foundSessionID, err := client.GetSubAgentSession(sessionID, toolCallID)
require.NoError(t, err)
assert.Equal(t, subSessionID, foundSessionID)

// Test not found (wrong session id)
_, err = client.GetSubAgentSession("wrong-session", toolCallID)
assert.Error(t, err)
}
42 changes: 41 additions & 1 deletion go/core/internal/database/fake/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,47 @@ func (c *InMemoryFakeClient) GetSession(sessionID string, userID string) (*datab
return session, nil
}

// GetAgent retrieves an agent by name
// GetSubAgentSession finds a subagent session based on parent session ID and tool call ID
func (c *InMemoryFakeClient) GetSubAgentSession(sessionID, toolCallID string) (string, error) {
c.mu.RLock()
defer c.mu.RUnlock()

var matchedTasks []*database.Task

for _, task := range c.tasks {
var taskData protocol.Task
if err := json.Unmarshal([]byte(task.Data), &taskData); err != nil {
continue
}

if taskData.Metadata == nil {
continue
}

if taskData.Metadata["kagent_caller_session_id"] == sessionID &&
taskData.Metadata["kagent_caller_tool_call_id"] == toolCallID {
matchedTasks = append(matchedTasks, task)
}
}

if len(matchedTasks) == 0 {
return "", gorm.ErrRecordNotFound
}

// Sort by created_at DESC to get the latest
slices.SortStableFunc(matchedTasks, func(a, b *database.Task) int {
if a.CreatedAt.After(b.CreatedAt) {
return -1
}
if a.CreatedAt.Before(b.CreatedAt) {
return 1
}
return 0
})

return matchedTasks[0].SessionID, nil
}

func (c *InMemoryFakeClient) GetAgent(agentName string) (*database.Agent, error) {
c.mu.RLock()
defer c.mu.RUnlock()
Expand Down
45 changes: 45 additions & 0 deletions go/core/internal/httpserver/handlers/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,51 @@ func (h *SessionsHandler) HandleAddEventToSession(w ErrorResponseWriter, r *http
RespondWithJSON(w, http.StatusCreated, data)
}

func (h *SessionsHandler) HandleGetSubAgentSession(w ErrorResponseWriter, r *http.Request) {
log := ctrllog.FromContext(r.Context()).WithName("sessions-handler").WithValues("operation", "get-sub-agent-session")

parentSessionID, err := GetPathParam(r, "session_id")
if err != nil {
w.RespondWithError(errors.NewBadRequestError("Missing required path parameter: session_id", err))
return
}
toolCallID, err := GetPathParam(r, "tool_call_id")
if err != nil {
w.RespondWithError(errors.NewBadRequestError("Missing required path parameter: tool_call_id", err))
return
}

userID, err := getUserIDOrAgentUser(r)
if err != nil {
w.RespondWithError(errors.NewBadRequestError("Failed to get user ID", err))
return
}

log = log.WithValues("parentSessionID", parentSessionID, "toolCallID", toolCallID, "userID", userID)

sessionID, err := h.DatabaseService.GetSubAgentSession(parentSessionID, toolCallID)
if err != nil {
w.RespondWithError(errors.NewNotFoundError("Sub-agent session not found", err))
return
}

// Now get the session details as standard session response
session, err := h.DatabaseService.GetSession(sessionID, userID)
if err != nil {
w.RespondWithError(errors.NewNotFoundError("Session not found", err))
return
}

// We don't return events here for the lookup as it's a lookup, but we can if we want to follow HandleGetSession
// For now let's just return the session to be consistent with creating a session
log.Info("Successfully found sub-agent session", "sessionID", sessionID)
data := api.NewResponse(SessionResponse{
Session: session,
Events: nil,
}, "Successfully found sub-agent session", false)
RespondWithJSON(w, http.StatusOK, data)
}

func getUserID(r *http.Request) (string, error) {
log := ctrllog.Log.WithName("http-helpers")

Expand Down
65 changes: 64 additions & 1 deletion go/core/internal/httpserver/handlers/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,72 @@ func TestSessionsHandler(t *testing.T) {
req = mux.SetURLVars(req, map[string]string{"session_id": sessionID})

handler.HandleListTasksForSession(responseRecorder, req)

assert.Equal(t, http.StatusBadRequest, responseRecorder.Code)
assert.NotNil(t, responseRecorder.errorReceived)
})
})

t.Run("HandleGetSubAgentSession", func(t *testing.T) {
t.Run("Success", func(t *testing.T) {
handler, dbClient, responseRecorder := setupHandler()
sessionID := "parent-session"
toolCallID := "call-1"
subSessionID := "sub-session-1"
userID := "test-user"

// Create sub-session
dbClient.StoreSession(&database.Session{
ID: subSessionID,
UserID: userID,
})

// Add task with caller metadata
dbClient.AddTask(&database.Task{
ID: "task-1",
SessionID: subSessionID,
Data: `{"id": "task-1", "metadata": {"kagent_caller_session_id": "parent-session", "kagent_caller_tool_call_id": "call-1"}}`,
CreatedAt: time.Now(),
})

req := httptest.NewRequest("GET", "/api/sessions/"+sessionID+"/subagentsessions/"+toolCallID, nil)
req = mux.SetURLVars(req, map[string]string{
"session_id": sessionID,
"tool_call_id": toolCallID,
})
req = setUser(req, userID)
handler.HandleGetSubAgentSession(responseRecorder, req)

assert.Equal(t, http.StatusOK, responseRecorder.Code)

var response api.StandardResponse[handlers.SessionResponse]
err := json.Unmarshal(responseRecorder.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, subSessionID, response.Data.Session.ID)
})

t.Run("MissingParams", func(t *testing.T) {
handler, _, responseRecorder := setupHandler()

req := httptest.NewRequest("GET", "/api/sessions/s/toolcalls/c/subagentsession", nil)
// No URL vars set - will fail on first GetPathParam
req = setUser(req, "test-user")
handler.HandleGetSubAgentSession(responseRecorder, req)

assert.Equal(t, http.StatusBadRequest, responseRecorder.Code)
})

t.Run("NotFound", func(t *testing.T) {
handler, _, responseRecorder := setupHandler()

req := httptest.NewRequest("GET", "/api/sessions/c/toolcalls/d/subagentsession", nil)
req = mux.SetURLVars(req, map[string]string{
"session_id": "c",
"tool_call_id": "d",
})
req = setUser(req, "test-user")
handler.HandleGetSubAgentSession(responseRecorder, req)

assert.Equal(t, http.StatusNotFound, responseRecorder.Code)
})
})
}
1 change: 1 addition & 0 deletions go/core/internal/httpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ func (s *HTTPServer) setupRoutes() {
// Sessions - using database handlers
s.router.HandleFunc(APIPathSessions, adaptHandler(s.handlers.Sessions.HandleListSessions)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions, adaptHandler(s.handlers.Sessions.HandleCreateSession)).Methods(http.MethodPost)
s.router.HandleFunc(APIPathSessions+"/{session_id}/subagentsessions/{tool_call_id}", adaptHandler(s.handlers.Sessions.HandleGetSubAgentSession)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions+"/agent/{namespace}/{name}", adaptHandler(s.handlers.Sessions.HandleGetSessionsForAgent)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandleGetSession)).Methods(http.MethodGet)
s.router.HandleFunc(APIPathSessions+"/{session_id}/tasks", adaptHandler(s.handlers.Sessions.HandleListTasksForSession)).Methods(http.MethodGet)
Expand Down
17 changes: 17 additions & 0 deletions python/packages/kagent-adk/src/kagent/adk/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from kagent.core.a2a import (
KAGENT_HITL_DECISION_TYPE_APPROVE,
KAGENT_HITL_DECISION_TYPE_BATCH,
KAGENT_METADATA_KEY_PREFIX,
TaskResultAggregator,
extract_ask_user_answers_from_message,
extract_batch_decisions_from_message,
Expand Down Expand Up @@ -475,6 +476,12 @@ async def _handle_request(
runner: Runner,
run_args: dict[str, Any],
):
# If the caller propagated a user_id via A2A request metadata,
# use it so the child session is owned by the same identity.
user_id_key = get_kagent_metadata_key("user_id")
if user_id_key in context.metadata:
run_args["user_id"] = context.metadata[user_id_key]

# ensure the session exists
session = await self._prepare_session(context, run_args, runner)

Expand Down Expand Up @@ -515,6 +522,16 @@ async def _handle_request(
get_kagent_metadata_key("session_id"): run_args["session_id"],
}

# Include caller metadata from A2A request metadata if present.
# When a parent agent calls a child via RemoteA2aAgent, the
# a2a_request_meta_provider injects kagent_* keys into the A2A
# wire protocol's MessageSendParams.metadata which arrives here
# via context.metadata.
if context.metadata:
for key, value in context.metadata.items():
if key.startswith(KAGENT_METADATA_KEY_PREFIX):
run_metadata[key] = value

# publish the task working event
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

import logging
from typing import Any

from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.agent_tool import AgentTool
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext

from kagent.core.a2a import (
KAGENT_METADATA_KEY_CALLER_APP_NAME,
KAGENT_METADATA_KEY_CALLER_SESSION_ID,
KAGENT_METADATA_KEY_CALLER_TOOL_CALL_ID,
get_kagent_metadata_key,
)

logger = logging.getLogger("kagent_adk." + __name__)


class SessionLinkingPlugin(BasePlugin):
"""A plugin that propagates parent session metadata to child agents.

This plugin intercepts tool calls to AgentTool and injects the parent's
session_id and tool_call_id into the session state.
Because AgentTool clones the current session state when creating a child
session, these values are automatically propagated to the sub-agent and
eventually to the remote A2A task metadata.
"""

def __init__(self, name: str = "session_linking"):
super().__init__(name)

async def before_tool_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> None:
"""Inject parent metadata into the state before an AgentTool runs."""
# We only need to do this for tools that spin up sub-sessions (AgentTools)
if isinstance(tool, AgentTool):
invocation_context = tool_context._invocation_context
if not invocation_context:
logger.warning(
"No invocation context found in tool_context for tool %s. Cannot link sessions.",
tool.name,
)
return

# Parent metadata keys. We use a prefix that is NOT filtered out by
# AgentTool (which typically filters '_adk').
parent_metadata = {
get_kagent_metadata_key(KAGENT_METADATA_KEY_CALLER_APP_NAME): invocation_context.app_name,
get_kagent_metadata_key(KAGENT_METADATA_KEY_CALLER_SESSION_ID): invocation_context.session.id,
get_kagent_metadata_key(KAGENT_METADATA_KEY_CALLER_TOOL_CALL_ID): tool_context.function_call_id,
}

# Update the current session state. This state is then cloned by
# the AgentTool.run_async method to create the child session.
tool_context.state.update(parent_metadata)
13 changes: 7 additions & 6 deletions python/packages/kagent-adk/src/kagent/adk/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from kagent.core import KAgentConfig, configure_logging, configure_tracing

from . import AgentConfig, KAgentApp
from ._session_linking_plugin import SessionLinkingPlugin
from .tools import add_skills_tool_to_agent

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,10 +61,10 @@ def static(
with open(os.path.join(filepath, "agent-card.json"), "r") as f:
agent_card = json.load(f)
agent_card = AgentCard.model_validate(agent_card)
plugins = None
plugins = [SessionLinkingPlugin()]
sts_integration = create_sts_integration()
if sts_integration:
plugins = [sts_integration]
plugins.append(sts_integration)

if agent_config.model.api_key_passthrough:
from ._llm_passthrough_plugin import LLMPassthroughPlugin
Expand Down Expand Up @@ -136,10 +137,10 @@ def run(
):
app_cfg = KAgentConfig()

plugins = None
plugins = [SessionLinkingPlugin()]
sts_integration = create_sts_integration()
if sts_integration:
plugins = [sts_integration]
plugins.append(sts_integration)

agent_loader = AgentLoader(agents_dir=working_dir)

Expand Down Expand Up @@ -206,10 +207,10 @@ def root_agent_factory() -> BaseAgent:

async def test_agent(agent_config: AgentConfig, agent_card: AgentCard, task: str):
app_cfg = KAgentConfig(url="http://fake-url.example.com", name="test-agent", namespace="kagent")
plugins = None
plugins = [SessionLinkingPlugin()]
sts_integration = create_sts_integration()
if sts_integration:
plugins = [sts_integration]
plugins.append(sts_integration)

def root_agent_factory() -> BaseAgent:
root_agent = agent_config.to_agent(app_cfg.name, sts_integration)
Expand Down
Loading
Loading