|
14 | 14 |
|
15 | 15 | import asyncio |
16 | 16 | import datetime |
| 17 | +import logging |
17 | 18 | from typing import Any |
18 | 19 | from typing import Iterable |
19 | 20 | from typing import Optional |
|
26 | 27 | from google.adk.sessions.session import Session |
27 | 28 | from google.genai import types |
28 | 29 | import pytest |
29 | | -from vertexai._genai.types import common as vertex_common_types |
| 30 | +from vertexai import types as vertex_types |
30 | 31 |
|
31 | 32 | MOCK_APP_NAME = 'test-app' |
32 | 33 | MOCK_USER_ID = 'test-user' |
33 | 34 |
|
34 | 35 |
|
35 | 36 | def _supports_generate_memories_metadata() -> bool: |
36 | 37 | return ( |
37 | | - 'metadata' |
38 | | - in vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields |
| 38 | + 'metadata' in vertex_types.GenerateAgentEngineMemoriesConfig.model_fields |
39 | 39 | ) |
40 | 40 |
|
41 | 41 |
|
42 | 42 | def _supports_create_memory_metadata() -> bool: |
43 | | - return 'metadata' in vertex_common_types.AgentEngineMemoryConfig.model_fields |
| 43 | + return 'metadata' in vertex_types.AgentEngineMemoryConfig.model_fields |
44 | 44 |
|
45 | 45 |
|
46 | 46 | def _supports_create_memory_revision_labels() -> bool: |
47 | | - return ( |
48 | | - 'revision_labels' |
49 | | - in vertex_common_types.AgentEngineMemoryConfig.model_fields |
50 | | - ) |
| 47 | + return 'revision_labels' in vertex_types.AgentEngineMemoryConfig.model_fields |
51 | 48 |
|
52 | 49 |
|
53 | 50 | class _AsyncListIterator: |
@@ -208,6 +205,9 @@ def mock_vertexai_client(): |
208 | 205 | mock_async_client.agent_engines.memories.generate = mock.AsyncMock() |
209 | 206 | mock_async_client.agent_engines.memories.create = mock.AsyncMock() |
210 | 207 | mock_async_client.agent_engines.memories.retrieve = mock.AsyncMock() |
| 208 | + mock_async_client.agent_engines.memories.retrieve_profiles = ( |
| 209 | + mock.AsyncMock() |
| 210 | + ) |
211 | 211 | mock_async_client.agent_engines.memories.ingest_events = mock.AsyncMock() |
212 | 212 |
|
213 | 213 | mock_client = mock.MagicMock() |
@@ -305,7 +305,7 @@ async def test_add_events_to_memory_with_explicit_events_and_metadata( |
305 | 305 | source = call_kwargs['direct_contents_source'] |
306 | 306 | assert len(source.events) == 1 |
307 | 307 | assert source.events[0].content.parts[0].text == 'test_content' |
308 | | - vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
| 308 | + vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
309 | 309 |
|
310 | 310 |
|
311 | 311 | @pytest.mark.asyncio |
@@ -336,7 +336,7 @@ async def test_add_events_to_memory_without_session_id( |
336 | 336 | source = call_kwargs['direct_contents_source'] |
337 | 337 | assert len(source.events) == 1 |
338 | 338 | assert source.events[0].content.parts[0].text == 'test_content' |
339 | | - vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
| 339 | + vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
340 | 340 | mock_vertexai_client.agent_engines.memories.create.assert_not_called() |
341 | 341 |
|
342 | 342 |
|
@@ -376,7 +376,7 @@ async def test_add_events_to_memory_merges_metadata_field_and_unknown_keys( |
376 | 376 | source = call_kwargs['direct_contents_source'] |
377 | 377 | assert len(source.events) == 1 |
378 | 378 | assert source.events[0].content.parts[0].text == 'test_content' |
379 | | - vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
| 379 | + vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
380 | 380 |
|
381 | 381 |
|
382 | 382 | @pytest.mark.asyncio |
@@ -407,7 +407,7 @@ async def test_add_events_to_memory_none_wait_for_completion_keeps_default( |
407 | 407 | source = call_kwargs['direct_contents_source'] |
408 | 408 | assert len(source.events) == 1 |
409 | 409 | assert source.events[0].content.parts[0].text == 'test_content' |
410 | | - vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
| 410 | + vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
411 | 411 |
|
412 | 412 |
|
413 | 413 | @pytest.mark.asyncio |
@@ -442,7 +442,7 @@ async def test_add_events_to_memory_ttl_used_when_revision_ttl_is_none( |
442 | 442 | source = call_kwargs['direct_contents_source'] |
443 | 443 | assert len(source.events) == 1 |
444 | 444 | assert source.events[0].content.parts[0].text == 'test_content' |
445 | | - vertex_common_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
| 445 | + vertex_types.GenerateAgentEngineMemoriesConfig(**call_kwargs['config']) |
446 | 446 |
|
447 | 447 |
|
448 | 448 | @pytest.mark.asyncio |
@@ -587,7 +587,7 @@ async def test_add_memory_calls_create( |
587 | 587 | 'config' |
588 | 588 | ] |
589 | 589 | ) |
590 | | - vertex_common_types.AgentEngineMemoryConfig(**create_config) |
| 590 | + vertex_types.AgentEngineMemoryConfig(**create_config) |
591 | 591 |
|
592 | 592 |
|
593 | 593 | @pytest.mark.asyncio |
@@ -634,7 +634,7 @@ async def test_add_memory_enable_consolidation_calls_generate_direct_source( |
634 | 634 | 'config' |
635 | 635 | ] |
636 | 636 | ) |
637 | | - vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) |
| 637 | + vertex_types.GenerateAgentEngineMemoriesConfig(**generate_config) |
638 | 638 |
|
639 | 639 |
|
640 | 640 | @pytest.mark.asyncio |
@@ -768,7 +768,7 @@ async def test_add_memory_calls_create_with_memory_entry_metadata( |
768 | 768 | 'config' |
769 | 769 | ] |
770 | 770 | ) |
771 | | - vertex_common_types.AgentEngineMemoryConfig(**create_config) |
| 771 | + vertex_types.AgentEngineMemoryConfig(**create_config) |
772 | 772 |
|
773 | 773 |
|
774 | 774 | @pytest.mark.asyncio |
@@ -1009,6 +1009,66 @@ async def test_search_memory_empty_results(mock_vertexai_client): |
1009 | 1009 | assert len(result.memories) == 0 |
1010 | 1010 |
|
1011 | 1011 |
|
| 1012 | +@pytest.mark.asyncio |
| 1013 | +async def test_retrieve_profiles(mock_vertexai_client, caplog): |
| 1014 | + """Returns the structured profiles for the scope as a list.""" |
| 1015 | + retrieve_profiles_response = vertex_types.RetrieveProfilesResponse( |
| 1016 | + profiles={ |
| 1017 | + 'user-profile': vertex_types.MemoryProfile( |
| 1018 | + schema_id='user-profile', |
| 1019 | + profile={'name': 'Kim'}, |
| 1020 | + ) |
| 1021 | + } |
| 1022 | + ) |
| 1023 | + mock_vertexai_client.agent_engines.memories.retrieve_profiles.return_value = ( |
| 1024 | + retrieve_profiles_response |
| 1025 | + ) |
| 1026 | + memory_service = mock_vertex_ai_memory_bank_service() |
| 1027 | + |
| 1028 | + with caplog.at_level(logging.INFO): |
| 1029 | + result = await memory_service.retrieve_profiles( |
| 1030 | + app_name=MOCK_APP_NAME, |
| 1031 | + user_id=MOCK_USER_ID, |
| 1032 | + ) |
| 1033 | + |
| 1034 | + mock_vertexai_client.agent_engines.memories.retrieve_profiles.assert_awaited_once_with( |
| 1035 | + name='reasoningEngines/123', |
| 1036 | + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, |
| 1037 | + ) |
| 1038 | + assert 'Retrieved 1 memory profiles.' in caplog.text |
| 1039 | + assert result == [ |
| 1040 | + vertex_types.MemoryProfile( |
| 1041 | + schema_id='user-profile', |
| 1042 | + profile={'name': 'Kim'}, |
| 1043 | + ) |
| 1044 | + ] |
| 1045 | + |
| 1046 | + |
| 1047 | +@pytest.mark.asyncio |
| 1048 | +async def test_retrieve_profiles_empty_results(mock_vertexai_client, caplog): |
| 1049 | + """Returns an empty list when the scope has no profiles.""" |
| 1050 | + retrieve_profiles_response = vertex_types.RetrieveProfilesResponse( |
| 1051 | + profiles=None |
| 1052 | + ) |
| 1053 | + mock_vertexai_client.agent_engines.memories.retrieve_profiles.return_value = ( |
| 1054 | + retrieve_profiles_response |
| 1055 | + ) |
| 1056 | + memory_service = mock_vertex_ai_memory_bank_service() |
| 1057 | + |
| 1058 | + with caplog.at_level(logging.INFO): |
| 1059 | + result = await memory_service.retrieve_profiles( |
| 1060 | + app_name=MOCK_APP_NAME, |
| 1061 | + user_id=MOCK_USER_ID, |
| 1062 | + ) |
| 1063 | + |
| 1064 | + mock_vertexai_client.agent_engines.memories.retrieve_profiles.assert_awaited_once_with( |
| 1065 | + name='reasoningEngines/123', |
| 1066 | + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, |
| 1067 | + ) |
| 1068 | + assert 'Retrieved no memory profiles.' in caplog.text |
| 1069 | + assert not result |
| 1070 | + |
| 1071 | + |
1012 | 1072 | async def test_search_memory_uses_async_client_path(): |
1013 | 1073 | sync_client = mock.MagicMock() |
1014 | 1074 | sync_client.agent_engines.memories.retrieve.side_effect = AssertionError( |
|
0 commit comments