Skip to content

Commit fb2b3af

Browse files
GWealecopybara-github
authored andcommitted
feat(memory): add Vertex AI load_profiles tool
Add VertexAiLoadProfilesTool for explicit agent access to structured user profiles from Vertex AI Memory Bank, backed by a new VertexAiMemoryBankService.retrieve_profiles method. Profiles are a Vertex backend capability (a scope-keyed lookup), not a memory-interface concept, so BaseMemoryService is unchanged. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 938264248
1 parent 7b87f91 commit fb2b3af

5 files changed

Lines changed: 290 additions & 27 deletions

File tree

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
if TYPE_CHECKING:
3535
import vertexai
36+
from vertexai import types as vertex_types
3637

3738
from ..events.event import Event
3839
from ..sessions.session import Session
@@ -107,22 +108,21 @@ def _should_use_generate_memories(
107108
def _supports_generate_memories_metadata() -> bool:
108109
"""Returns whether installed Vertex SDK supports config.metadata."""
109110
try:
110-
from vertexai._genai.types import common as vertex_common_types
111+
from vertexai import types as vertex_types
111112
except ImportError:
112113
return False
113114
return (
114-
'metadata'
115-
in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields
115+
'metadata' in vertex_types.GenerateAgentEngineMemoriesConfig.model_fields
116116
)
117117

118118

119119
def _supports_create_memory_metadata() -> bool:
120120
"""Returns whether installed Vertex SDK supports create config.metadata."""
121121
try:
122-
from vertexai._genai.types import common as vertex_common_types
122+
from vertexai import types as vertex_types
123123
except ImportError:
124124
return False
125-
return 'metadata' in vertex_common_types.AgentEngineMemoryConfig.model_fields
125+
return 'metadata' in vertex_types.AgentEngineMemoryConfig.model_fields
126126

127127

128128
@lru_cache(maxsize=1)
@@ -133,14 +133,12 @@ def _get_generate_memories_config_keys() -> frozenset[str]:
133133
allowlist to preserve compatibility when introspection is unavailable.
134134
"""
135135
try:
136-
from vertexai._genai.types import common as vertex_common_types
136+
from vertexai import types as vertex_types
137137
except ImportError:
138138
return _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS
139139

140140
try:
141-
model_fields = (
142-
vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields
143-
)
141+
model_fields = vertex_types.GenerateAgentEngineMemoriesConfig.model_fields
144142
except AttributeError:
145143
return _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS
146144

@@ -157,12 +155,12 @@ def _get_create_memory_config_keys() -> frozenset[str]:
157155
allowlist to preserve compatibility when introspection is unavailable.
158156
"""
159157
try:
160-
from vertexai._genai.types import common as vertex_common_types
158+
from vertexai import types as vertex_types
161159
except ImportError:
162160
return _CREATE_MEMORY_CONFIG_FALLBACK_KEYS
163161

164162
try:
165-
model_fields = vertex_common_types.AgentEngineMemoryConfig.model_fields
163+
model_fields = vertex_types.AgentEngineMemoryConfig.model_fields
166164
except AttributeError:
167165
return _CREATE_MEMORY_CONFIG_FALLBACK_KEYS
168166

@@ -574,6 +572,39 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str):
574572
)
575573
return SearchMemoryResponse(memories=memory_events)
576574

575+
async def retrieve_profiles(
576+
self,
577+
*,
578+
app_name: str,
579+
user_id: str,
580+
) -> list[vertex_types.MemoryProfile]:
581+
"""Retrieves structured user profiles for the scope, one per schema.
582+
583+
Profiles are a Vertex Memory Bank capability distinct from memory search:
584+
a scope-keyed lookup, not a semantic query.
585+
586+
Args:
587+
app_name: The application name for the profile scope.
588+
user_id: The user ID for the profile scope.
589+
590+
Returns:
591+
The structured profiles for the scope, one per registered schema.
592+
"""
593+
api_client = self._get_api_client()
594+
response = await api_client.agent_engines.memories.retrieve_profiles(
595+
name='reasoningEngines/' + self._agent_engine_id,
596+
scope={
597+
'app_name': app_name,
598+
'user_id': user_id,
599+
},
600+
)
601+
profiles = list((response.profiles or {}).values())
602+
if profiles:
603+
logger.info('Retrieved %d memory profiles.', len(profiles))
604+
else:
605+
logger.info('Retrieved no memory profiles.')
606+
return profiles
607+
577608
def _get_api_client(self) -> vertexai.AsyncClient:
578609
"""Instantiates an API client for the given project and location.
579610

src/google/adk/tools/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .transfer_to_agent_tool import transfer_to_agent
4343
from .transfer_to_agent_tool import TransferToAgentTool
4444
from .url_context_tool import url_context
45+
from .vertex_ai_load_profiles_tool import VertexAiLoadProfilesTool
4546
from .vertex_ai_search_tool import VertexAiSearchTool
4647

4748
# If you are adding a new tool to this file, please make sure you add it to the
@@ -89,6 +90,10 @@
8990
'TransferToAgentTool',
9091
),
9192
'url_context': ('.url_context_tool', 'url_context'),
93+
'VertexAiLoadProfilesTool': (
94+
'.vertex_ai_load_profiles_tool',
95+
'VertexAiLoadProfilesTool',
96+
),
9297
'VertexAiSearchTool': ('.vertex_ai_search_tool', 'VertexAiSearchTool'),
9398
'MCPToolset': ('.mcp_tool.mcp_toolset', 'MCPToolset'),
9499
'McpToolset': ('.mcp_tool.mcp_toolset', 'McpToolset'),
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any
18+
from typing import TYPE_CHECKING
19+
20+
from google.genai import types
21+
from typing_extensions import override
22+
23+
from ..features import FeatureName
24+
from ..features import is_feature_enabled
25+
from .function_tool import FunctionTool
26+
from .tool_context import ToolContext
27+
28+
if TYPE_CHECKING:
29+
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
30+
31+
32+
class VertexAiLoadProfilesTool(FunctionTool):
33+
"""A tool that loads a user's structured profiles from Vertex Memory Bank."""
34+
35+
def __init__(self, memory_service: VertexAiMemoryBankService):
36+
super().__init__(self.load_profiles)
37+
self._memory_service = memory_service
38+
39+
async def load_profiles(self, tool_context: ToolContext) -> dict[str, Any]:
40+
"""Loads structured user profiles for the current user."""
41+
profiles = await self._memory_service.retrieve_profiles(
42+
app_name=tool_context.session.app_name,
43+
user_id=tool_context.user_id,
44+
)
45+
return {
46+
'profiles': [profile.profile for profile in profiles if profile.profile]
47+
}
48+
49+
@override
50+
def _get_declaration(self) -> types.FunctionDeclaration | None:
51+
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
52+
return types.FunctionDeclaration(
53+
name=self.name,
54+
description=self.description,
55+
parameters_json_schema={
56+
'type': 'object',
57+
'properties': {},
58+
},
59+
)
60+
return types.FunctionDeclaration(
61+
name=self.name,
62+
description=self.description,
63+
parameters=types.Schema(
64+
type=types.Type.OBJECT,
65+
properties={},
66+
),
67+
)

tests/unittests/memory/test_vertex_ai_memory_bank_service.py

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import asyncio
1616
import datetime
17+
import logging
1718
from typing import Any
1819
from typing import Iterable
1920
from typing import Optional
@@ -26,28 +27,24 @@
2627
from google.adk.sessions.session import Session
2728
from google.genai import types
2829
import pytest
29-
from vertexai._genai.types import common as vertex_common_types
30+
from vertexai import types as vertex_types
3031

3132
MOCK_APP_NAME = 'test-app'
3233
MOCK_USER_ID = 'test-user'
3334

3435

3536
def _supports_generate_memories_metadata() -> bool:
3637
return (
37-
'metadata'
38-
in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields
38+
'metadata' in vertex_types.GenerateAgentEngineMemoriesConfig.model_fields
3939
)
4040

4141

4242
def _supports_create_memory_metadata() -> bool:
43-
return 'metadata' in vertex_common_types.AgentEngineMemoryConfig.model_fields
43+
return 'metadata' in vertex_types.AgentEngineMemoryConfig.model_fields
4444

4545

4646
def _supports_create_memory_revision_labels() -> bool:
47-
return (
48-
'revision_labels'
49-
in vertex_common_types.AgentEngineMemoryConfig.model_fields
50-
)
47+
return 'revision_labels' in vertex_types.AgentEngineMemoryConfig.model_fields
5148

5249

5350
class _AsyncListIterator:
@@ -208,6 +205,9 @@ def mock_vertexai_client():
208205
mock_async_client.agent_engines.memories.generate = mock.AsyncMock()
209206
mock_async_client.agent_engines.memories.create = mock.AsyncMock()
210207
mock_async_client.agent_engines.memories.retrieve = mock.AsyncMock()
208+
mock_async_client.agent_engines.memories.retrieve_profiles = (
209+
mock.AsyncMock()
210+
)
211211
mock_async_client.agent_engines.memories.ingest_events = mock.AsyncMock()
212212

213213
mock_client = mock.MagicMock()
@@ -305,7 +305,7 @@ async def test_add_events_to_memory_with_explicit_events_and_metadata(
305305
source = call_kwargs['direct_contents_source']
306306
assert len(source.events) == 1
307307
assert source.events[0].content.parts[0].text == 'test_content'
308-
vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
308+
vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
309309

310310

311311
@pytest.mark.asyncio
@@ -336,7 +336,7 @@ async def test_add_events_to_memory_without_session_id(
336336
source = call_kwargs['direct_contents_source']
337337
assert len(source.events) == 1
338338
assert source.events[0].content.parts[0].text == 'test_content'
339-
vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
339+
vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
340340
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
341341

342342

@@ -376,7 +376,7 @@ async def test_add_events_to_memory_merges_metadata_field_and_unknown_keys(
376376
source = call_kwargs['direct_contents_source']
377377
assert len(source.events) == 1
378378
assert source.events[0].content.parts[0].text == 'test_content'
379-
vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
379+
vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
380380

381381

382382
@pytest.mark.asyncio
@@ -407,7 +407,7 @@ async def test_add_events_to_memory_none_wait_for_completion_keeps_default(
407407
source = call_kwargs['direct_contents_source']
408408
assert len(source.events) == 1
409409
assert source.events[0].content.parts[0].text == 'test_content'
410-
vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
410+
vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
411411

412412

413413
@pytest.mark.asyncio
@@ -442,7 +442,7 @@ async def test_add_events_to_memory_ttl_used_when_revision_ttl_is_none(
442442
source = call_kwargs['direct_contents_source']
443443
assert len(source.events) == 1
444444
assert source.events[0].content.parts[0].text == 'test_content'
445-
vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
445+
vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config'])
446446

447447

448448
@pytest.mark.asyncio
@@ -587,7 +587,7 @@ async def test_add_memory_calls_create(
587587
'config'
588588
]
589589
)
590-
vertex_common_types.AgentEngineMemoryConfig(**create_config)
590+
vertex_types.AgentEngineMemoryConfig(**create_config)
591591

592592

593593
@pytest.mark.asyncio
@@ -634,7 +634,7 @@ async def test_add_memory_enable_consolidation_calls_generate_direct_source(
634634
'config'
635635
]
636636
)
637-
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
637+
vertex_types.GenerateAgentEngineMemoriesConfig(**generate_config)
638638

639639

640640
@pytest.mark.asyncio
@@ -768,7 +768,7 @@ async def test_add_memory_calls_create_with_memory_entry_metadata(
768768
'config'
769769
]
770770
)
771-
vertex_common_types.AgentEngineMemoryConfig(**create_config)
771+
vertex_types.AgentEngineMemoryConfig(**create_config)
772772

773773

774774
@pytest.mark.asyncio
@@ -1009,6 +1009,66 @@ async def test_search_memory_empty_results(mock_vertexai_client):
10091009
assert len(result.memories) == 0
10101010

10111011

1012+
@pytest.mark.asyncio
1013+
async def test_retrieve_profiles(mock_vertexai_client, caplog):
1014+
"""Returns the structured profiles for the scope as a list."""
1015+
retrieve_profiles_response = vertex_types.RetrieveProfilesResponse(
1016+
profiles={
1017+
'user-profile': vertex_types.MemoryProfile(
1018+
schema_id='user-profile',
1019+
profile={'name': 'Kim'},
1020+
)
1021+
}
1022+
)
1023+
mock_vertexai_client.agent_engines.memories.retrieve_profiles.return_value = (
1024+
retrieve_profiles_response
1025+
)
1026+
memory_service = mock_vertex_ai_memory_bank_service()
1027+
1028+
with caplog.at_level(logging.INFO):
1029+
result = await memory_service.retrieve_profiles(
1030+
app_name=MOCK_APP_NAME,
1031+
user_id=MOCK_USER_ID,
1032+
)
1033+
1034+
mock_vertexai_client.agent_engines.memories.retrieve_profiles.assert_awaited_once_with(
1035+
name='reasoningEngines/123',
1036+
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
1037+
)
1038+
assert 'Retrieved 1 memory profiles.' in caplog.text
1039+
assert result == [
1040+
vertex_types.MemoryProfile(
1041+
schema_id='user-profile',
1042+
profile={'name': 'Kim'},
1043+
)
1044+
]
1045+
1046+
1047+
@pytest.mark.asyncio
1048+
async def test_retrieve_profiles_empty_results(mock_vertexai_client, caplog):
1049+
"""Returns an empty list when the scope has no profiles."""
1050+
retrieve_profiles_response = vertex_types.RetrieveProfilesResponse(
1051+
profiles=None
1052+
)
1053+
mock_vertexai_client.agent_engines.memories.retrieve_profiles.return_value = (
1054+
retrieve_profiles_response
1055+
)
1056+
memory_service = mock_vertex_ai_memory_bank_service()
1057+
1058+
with caplog.at_level(logging.INFO):
1059+
result = await memory_service.retrieve_profiles(
1060+
app_name=MOCK_APP_NAME,
1061+
user_id=MOCK_USER_ID,
1062+
)
1063+
1064+
mock_vertexai_client.agent_engines.memories.retrieve_profiles.assert_awaited_once_with(
1065+
name='reasoningEngines/123',
1066+
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
1067+
)
1068+
assert 'Retrieved no memory profiles.' in caplog.text
1069+
assert not result
1070+
1071+
10121072
async def test_search_memory_uses_async_client_path():
10131073
sync_client = mock.MagicMock()
10141074
sync_client.agent_engines.memories.retrieve.side_effect = AssertionError(

0 commit comments

Comments
 (0)