From fe8c89f37c05b5743501b7a642b1111d9bba31b4 Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sat, 16 May 2026 18:22:51 +0530 Subject: [PATCH 1/7] feat: integrate Mem0 with ADK memory service - Add Mem0MemoryService implementing ADK BaseMemoryService interface - Register mem0:// URI scheme for memory service factory - Add PreloadMemoryTool for automatic memory injection on each turn - Remove delete_all_memories tool to prevent accidental data loss - Enhance health endpoint with memory service status - Initialize Mem0 client at server startup (fixes #67) --- src/blacki/agent.py | 3 + src/blacki/memory/__init__.py | 2 - src/blacki/memory/mem0_memory_service.py | 97 +++++++++++++++ src/blacki/memory/tools.py | 39 ------ src/blacki/prompt.py | 2 - src/blacki/registry.py | 2 - src/blacki/server.py | 48 ++++++-- tests/memory/test_mem0_memory_service.py | 150 +++++++++++++++++++++++ tests/memory/test_tools.py | 30 ----- tests/test_registry.py | 12 +- 10 files changed, 297 insertions(+), 88 deletions(-) create mode 100644 src/blacki/memory/mem0_memory_service.py create mode 100644 tests/memory/test_mem0_memory_service.py diff --git a/src/blacki/agent.py b/src/blacki/agent.py index 7cbcd05..03eaea7 100644 --- a/src/blacki/agent.py +++ b/src/blacki/agent.py @@ -155,8 +155,11 @@ def create_agent() -> LlmAgent: Returns: Configured LlmAgent instance. """ + from google.adk.tools.preload_memory_tool import preload_memory_tool + tool_config = build_tool_config_from_env() agent_tools = build_tools(tool_config) + agent_tools.append(preload_memory_tool) before_tool_callbacks: list[Any] = [logging_callbacks.before_tool] after_model_callbacks: list[Any] = [logging_callbacks.after_model] diff --git a/src/blacki/memory/__init__.py b/src/blacki/memory/__init__.py index 2d4f6b8..1b36ee0 100644 --- a/src/blacki/memory/__init__.py +++ b/src/blacki/memory/__init__.py @@ -6,7 +6,6 @@ reset_memory_client, ) from .tools import ( - delete_all_memories, delete_memory, get_all_memories, get_memory, @@ -16,7 +15,6 @@ ) __all__ = [ - "delete_all_memories", "delete_memory", "get_all_memories", "get_memory", diff --git a/src/blacki/memory/mem0_memory_service.py b/src/blacki/memory/mem0_memory_service.py new file mode 100644 index 0000000..3d0467c --- /dev/null +++ b/src/blacki/memory/mem0_memory_service.py @@ -0,0 +1,97 @@ +"""Memory service that bridges Mem0 to ADK's BaseMemoryService interface.""" + +from __future__ import annotations + +import logging +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING + +from google.adk.events.event import Event +from google.adk.memory.base_memory_service import ( + BaseMemoryService, + SearchMemoryResponse, +) +from google.adk.memory.memory_entry import MemoryEntry +from google.adk.sessions.session import Session +from google.genai import types + +if TYPE_CHECKING: + from mem0 import Memory + +logger = logging.getLogger(__name__) + + +class Mem0MemoryService(BaseMemoryService): + """Memory service backed by Mem0 OSS. + + Wraps the existing Mem0 client to provide ADK-compatible memory operations. + Memories are managed manually via save_memory tool (no automatic session ingestion). + """ + + def __init__(self, client: Memory): + self._client = client + + async def add_session_to_memory(self, session: Session) -> None: + """Not used - user chose manual memory management via save_memory tool.""" + pass + + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: str | None = None, + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + """Not used - user chose manual memory management.""" + pass + + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + """Search memories via Mem0 and convert to ADK format. + + Args: + app_name: The application name (used as part of composite user_id). + user_id: The user identifier. + query: The search query. + + Returns: + SearchMemoryResponse with matching MemoryEntry objects. + """ + from .config import get_search_limit + + mem0_user_id = f"{app_name}/{user_id}" + limit = get_search_limit() + + try: + result = self._client.search(query=query, user_id=mem0_user_id, limit=limit) + except Exception: + logger.exception("Failed to search memories for user %s", mem0_user_id) + return SearchMemoryResponse(memories=[]) + + memories: list[MemoryEntry] = [] + for m in result.get("results", []): + memory_text = m.get("memory", "") + if not memory_text: + continue + + memories.append( + MemoryEntry( + content=types.Content( + role="user", + parts=[types.Part(text=memory_text)], + ), + id=m.get("id"), + ) + ) + + logger.debug( + "Found %d memories for query '%s' (user: %s)", + len(memories), + query[:30], + mem0_user_id, + ) + + return SearchMemoryResponse(memories=memories) diff --git a/src/blacki/memory/tools.py b/src/blacki/memory/tools.py index 81ef92b..140edf4 100644 --- a/src/blacki/memory/tools.py +++ b/src/blacki/memory/tools.py @@ -382,42 +382,3 @@ async def delete_memory( "status": "error", "error": f"Failed to delete memory: {e}", } - - -async def delete_all_memories( - tool_context: ToolContext, - user_id: str | None = None, -) -> dict[str, Any]: - """Delete all memories for a user. - - Use this tool with caution when a user wants to wipe all their stored - memories. This operation cannot be undone. - - Args: - tool_context: ADK tool context. - user_id: Unique identifier for the user. Defaults to MEM0_USER_ID env var. - - Returns: - Dictionary with status and result message. - """ - _ = tool_context - - client = get_memory_client() - if client is None: - return _memory_service_unavailable_response() - - user_id = user_id or get_default_user_id() - - try: - client.delete_all(user_id=user_id) - logger.warning("Deleted all memories for user %s", user_id) - return { - "status": "success", - "message": f"All memories deleted for user {user_id}.", - } - except Exception as e: - logger.exception("Failed to delete all memories for user %s", user_id) - return { - "status": "error", - "error": f"Failed to delete all memories: {e}", - } diff --git a/src/blacki/prompt.py b/src/blacki/prompt.py index ff07e98..4535d71 100644 --- a/src/blacki/prompt.py +++ b/src/blacki/prompt.py @@ -92,8 +92,6 @@ def return_instruction_root() -> str: information. You need the memory_id from search or list operations. - Use delete_memory when the user asks to forget specific information. You need the memory_id from search or list operations. -- Use delete_all_memories with caution when the user wants to wipe all - their stored memories. Confirm before executing. - All memory operations are scoped to the user_id. Memories are private and isolated per user. diff --git a/src/blacki/registry.py b/src/blacki/registry.py index 9b842d1..284e906 100644 --- a/src/blacki/registry.py +++ b/src/blacki/registry.py @@ -202,7 +202,6 @@ def _build_memory_tools() -> list[Any]: """Build memory tools.""" try: from blacki.memory import ( - delete_all_memories, delete_memory, get_all_memories, get_memory, @@ -218,7 +217,6 @@ def _build_memory_tools() -> list[Any]: get_memory, update_memory, delete_memory, - delete_all_memories, ] except ImportError as e: # pragma: no cover logger.warning("Failed to load Memory tools: %s", e) diff --git a/src/blacki/server.py b/src/blacki/server.py index ecb06a8..5b605ff 100644 --- a/src/blacki/server.py +++ b/src/blacki/server.py @@ -10,10 +10,13 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path +from typing import Any import uvicorn from fastapi import FastAPI from google.adk.cli.fast_api import get_fast_api_app +from google.adk.cli.service_registry import get_service_registry +from google.adk.memory.base_memory_service import BaseMemoryService from openinference.instrumentation.google_adk import GoogleADKInstrumentor from .adk_runtime import ( @@ -133,12 +136,35 @@ async def _stop_reminder_scheduler() -> None: session_uri = build_session_service_uri(env) session_db_kwargs = build_session_db_kwargs(env) + +def _create_mem0_memory_service(uri: str, **kwargs: Any) -> BaseMemoryService: + """Factory for mem0:// URI scheme. + + Returns Mem0MemoryService if client is available, InMemoryMemoryService otherwise. + """ + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + + from blacki.memory.config import get_memory_client + + client = get_memory_client() + if client is None: + logger.info("Mem0 client not available, using in-memory memory service") + return InMemoryMemoryService() + + from blacki.memory.mem0_memory_service import Mem0MemoryService + + logger.info("Mem0 memory service initialized") + return Mem0MemoryService(client) + + +get_service_registry().register_memory_service("mem0", _create_mem0_memory_service) + app: FastAPI = get_fast_api_app( agents_dir=AGENT_DIR, session_service_uri=session_uri, session_db_kwargs=session_db_kwargs, artifact_service_uri=None, - memory_service_uri=None, + memory_service_uri="mem0://", allow_origins=env.allow_origins_list, web=env.serve_web_interface, reload_agents=env.reload_agents, @@ -195,22 +221,30 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]: @app.get("/health") -async def health() -> dict[str, str]: +async def health() -> dict[str, Any]: """Health check endpoint for container orchestration. Returns: dict with status key indicating service health. """ - checks: list[str] = [] + from blacki.memory.config import get_memory_client + + checks: dict[str, str] = {} + if _container is not None: try: await _container.pool.fetchval("SELECT 1") + checks["database"] = "healthy" except Exception: - checks.append("database:unreachable") + checks["database"] = "unhealthy" + + client = get_memory_client() + checks["memory_service"] = "healthy" if client else "unavailable" + + all_ok = all(v in ("healthy", "unavailable") for v in checks.values()) + status = "ok" if all_ok else "degraded" - if checks: - return {"status": "degraded", "details": "; ".join(checks)} - return {"status": "ok"} + return {"status": status, "checks": checks} def main() -> None: diff --git a/tests/memory/test_mem0_memory_service.py b/tests/memory/test_mem0_memory_service.py new file mode 100644 index 0000000..302dd67 --- /dev/null +++ b/tests/memory/test_mem0_memory_service.py @@ -0,0 +1,150 @@ +"""Tests for Mem0MemoryService.""" + +from unittest.mock import MagicMock + +import pytest +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry + +from blacki.memory.config import reset_memory_client +from blacki.memory.mem0_memory_service import Mem0MemoryService + + +class TestMem0MemoryService: + """Tests for Mem0MemoryService class.""" + + @pytest.fixture(autouse=True) + def reset_client(self) -> None: + """Reset the memory client before each test.""" + reset_memory_client() + + def test_init(self) -> None: + """Should initialize with Mem0 client.""" + mock_client = MagicMock() + service = Mem0MemoryService(mock_client) + assert service._client is mock_client + + @pytest.mark.asyncio + async def test_search_memory_success(self) -> None: + """Should search memories and convert to MemoryEntry objects.""" + mock_client = MagicMock() + mock_client.search.return_value = { + "results": [ + {"id": "mem_1", "memory": "User likes pizza", "score": 0.95}, + {"id": "mem_2", "memory": "User prefers tea", "score": 0.85}, + ] + } + + service = Mem0MemoryService(mock_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="food preferences" + ) + + assert isinstance(response, SearchMemoryResponse) + assert len(response.memories) == 2 + assert all(isinstance(m, MemoryEntry) for m in response.memories) + assert response.memories[0].content.parts is not None + assert response.memories[1].content.parts is not None + assert response.memories[0].content.parts[0].text == "User likes pizza" + assert response.memories[1].content.parts[0].text == "User prefers tea" + + mock_client.search.assert_called_once() + call_kwargs = mock_client.search.call_args[1] + assert call_kwargs["query"] == "food preferences" + assert call_kwargs["user_id"] == "test_app/test_user" + + @pytest.mark.asyncio + async def test_search_memory_empty_results(self) -> None: + """Should return empty list when no memories found.""" + mock_client = MagicMock() + mock_client.search.return_value = {"results": []} + + service = Mem0MemoryService(mock_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="nonexistent" + ) + + assert isinstance(response, SearchMemoryResponse) + assert len(response.memories) == 0 + + @pytest.mark.asyncio + async def test_search_memory_skips_empty_text(self) -> None: + """Should skip results with empty memory text.""" + mock_client = MagicMock() + mock_client.search.return_value = { + "results": [ + {"id": "mem_1", "memory": "Valid memory", "score": 0.95}, + {"id": "mem_2", "memory": "", "score": 0.85}, + {"id": "mem_3", "memory": "Another valid", "score": 0.75}, + ] + } + + service = Mem0MemoryService(mock_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="test" + ) + + assert len(response.memories) == 2 + assert response.memories[0].content.parts is not None + assert response.memories[1].content.parts is not None + assert response.memories[0].content.parts[0].text == "Valid memory" + assert response.memories[1].content.parts[0].text == "Another valid" + + @pytest.mark.asyncio + async def test_search_memory_handles_exception(self) -> None: + """Should return empty list on search failure.""" + mock_client = MagicMock() + mock_client.search.side_effect = Exception("Connection failed") + + service = Mem0MemoryService(mock_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="test" + ) + + assert isinstance(response, SearchMemoryResponse) + assert len(response.memories) == 0 + + @pytest.mark.asyncio + async def test_add_session_to_memory_noop(self) -> None: + """Should do nothing for add_session_to_memory.""" + from google.adk.sessions.session import Session + + mock_client = MagicMock() + mock_session = MagicMock(spec=Session) + service = Mem0MemoryService(mock_client) + + await service.add_session_to_memory(mock_session) + + mock_client.assert_not_called() + + @pytest.mark.asyncio + async def test_add_events_to_memory_noop(self) -> None: + """Should do nothing for add_events_to_memory.""" + mock_client = MagicMock() + service = Mem0MemoryService(mock_client) + + await service.add_events_to_memory( + app_name="test_app", + user_id="test_user", + events=[], + ) + + mock_client.assert_not_called() + + @pytest.mark.asyncio + async def test_search_memory_uses_custom_limit( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Should use configured search limit.""" + monkeypatch.setenv("MEM0_SEARCH_LIMIT", "10") + + mock_client = MagicMock() + mock_client.search.return_value = {"results": []} + + service = Mem0MemoryService(mock_client) + await service.search_memory( + app_name="test_app", user_id="test_user", query="test" + ) + + call_kwargs = mock_client.search.call_args[1] + assert call_kwargs["limit"] == 10 diff --git a/tests/memory/test_tools.py b/tests/memory/test_tools.py index a173b76..f60e249 100644 --- a/tests/memory/test_tools.py +++ b/tests/memory/test_tools.py @@ -9,7 +9,6 @@ from blacki.memory.config import reset_memory_client from blacki.memory.tools import ( - delete_all_memories, delete_memory, get_all_memories, get_memory, @@ -323,32 +322,3 @@ async def test_delete_memory_empty_id( assert result["status"] == "error" assert "non-empty" in result["error"].lower() - - -class TestDeleteAllMemories: - """Tests for delete_all_memories function.""" - - @staticmethod - def _tool_context() -> ToolContext: - return cast(ToolContext, MockToolContext(state=MockState({}))) - - @pytest.fixture(autouse=True) - def reset_client(self) -> None: - """Reset the memory client before each test.""" - reset_memory_client() - - @pytest.mark.asyncio - async def test_delete_all_memories_success( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Should delete all memories for a user.""" - monkeypatch.setenv("MEM0_API_KEY", "test_key") - tool_context = self._tool_context() - - mock_client = MagicMock() - - with patch("blacki.memory.tools.get_memory_client", return_value=mock_client): - result = await delete_all_memories(tool_context, user_id="test_user") - - assert result["status"] == "success" - mock_client.delete_all.assert_called_once_with(user_id="test_user") diff --git a/tests/test_registry.py b/tests/test_registry.py index 2143cc8..60c7681 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -42,7 +42,7 @@ def test_empty_config_returns_memory_tools(self) -> None: config = ToolConfig() tools = build_tools(config) - assert len(tools) == 9 + assert len(tools) == 8 def test_brave_search_tools_added(self) -> None: """Should add Brave Search tools when API key provided.""" @@ -50,7 +50,7 @@ def test_brave_search_tools_added(self) -> None: tools = build_tools(config) - assert len(tools) == 10 + assert len(tools) == 9 def test_database_tools_added(self) -> None: """Should add database-backed tools when database URL provided.""" @@ -66,7 +66,7 @@ def test_sandbox_tools_added(self) -> None: tools = build_tools(config) - assert len(tools) == 15 + assert len(tools) == 14 def test_weather_tools_disabled(self) -> None: """Should not add weather tools when disabled.""" @@ -74,7 +74,7 @@ def test_weather_tools_disabled(self) -> None: tools = build_tools(config) - assert len(tools) == 7 + assert len(tools) == 6 def test_all_tools_with_full_config(self) -> None: """Should include all tools with full configuration.""" @@ -98,7 +98,7 @@ def test_build_brave_search_tools_import_error(self) -> None: config = ToolConfig(brave_search_api_key="test-key") tools = build_tools(config) - assert len(tools) == 9 + assert len(tools) == 8 class TestBuildToolConfigFromEnv: @@ -252,7 +252,7 @@ def test_returns_tools_when_available(self) -> None: tools = _build_memory_tools() - assert len(tools) == 7 + assert len(tools) == 6 class TestBuildSkillTools: From d79e4c185fd2645690a49eb55222dc28f5ae12cd Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sat, 16 May 2026 18:32:28 +0530 Subject: [PATCH 2/7] fix: use plain user_id in Mem0MemoryService search The user_id should match what explicit memory tools use, not be prefixed with app_name. This ensures PreloadMemoryTool finds memories saved via save_memory tool. --- src/blacki/memory/mem0_memory_service.py | 2 +- tests/memory/test_mem0_memory_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blacki/memory/mem0_memory_service.py b/src/blacki/memory/mem0_memory_service.py index 3d0467c..2d3ccf0 100644 --- a/src/blacki/memory/mem0_memory_service.py +++ b/src/blacki/memory/mem0_memory_service.py @@ -62,7 +62,7 @@ async def search_memory( """ from .config import get_search_limit - mem0_user_id = f"{app_name}/{user_id}" + mem0_user_id = user_id limit = get_search_limit() try: diff --git a/tests/memory/test_mem0_memory_service.py b/tests/memory/test_mem0_memory_service.py index 302dd67..64d92a2 100644 --- a/tests/memory/test_mem0_memory_service.py +++ b/tests/memory/test_mem0_memory_service.py @@ -51,7 +51,7 @@ async def test_search_memory_success(self) -> None: mock_client.search.assert_called_once() call_kwargs = mock_client.search.call_args[1] assert call_kwargs["query"] == "food preferences" - assert call_kwargs["user_id"] == "test_app/test_user" + assert call_kwargs["user_id"] == "test_user" @pytest.mark.asyncio async def test_search_memory_empty_results(self) -> None: From 7f7e4ba986c4c82038c5d3c0be311930fc758db4 Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sat, 16 May 2026 18:41:23 +0530 Subject: [PATCH 3/7] fix: pass memory_service to Runner in AdkRuntime Telegram bot was not receiving preloaded memories because AdkRuntime created Runner without memory_service. PreloadMemoryTool's call to search_memory() would fail silently, preventing automatic memory injection. --- src/blacki/adk_runtime.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/blacki/adk_runtime.py b/src/blacki/adk_runtime.py index 6884cd6..5b2d2da 100644 --- a/src/blacki/adk_runtime.py +++ b/src/blacki/adk_runtime.py @@ -9,6 +9,7 @@ from google.adk.agents.run_config import RunConfig, StreamingMode from google.adk.events import Event +from google.adk.memory.base_memory_service import BaseMemoryService from google.adk.runners import Runner from google.adk.sessions import Session from google.adk.sessions.base_session_service import BaseSessionService @@ -99,7 +100,11 @@ class SessionLocator: class AdkRuntime: """Small helper around ADK Runner and SessionService.""" - def __init__(self, session_service: BaseSessionService) -> None: + def __init__( + self, + session_service: BaseSessionService, + memory_service: BaseMemoryService | None = None, + ) -> None: from .agent import app as agent_app self.app = agent_app @@ -109,6 +114,7 @@ def __init__(self, session_service: BaseSessionService) -> None: app=self.app, app_name=self.app_name, session_service=self.session_service, + memory_service=memory_service, auto_create_session=False, ) @@ -377,7 +383,15 @@ def create_adk_runtime(env: ServerEnv) -> AdkRuntime: session_db_kwargs=session_db_kwargs, agent_dir=env.agent_dir, ) - return AdkRuntime(session_service=session_service) + from google.adk.cli.service_registry import get_service_registry + + memory_service = get_service_registry().create_memory_service( + "mem0://", agents_dir=str(Path(env.agent_dir).resolve()) + ) + return AdkRuntime( + session_service=session_service, + memory_service=memory_service, + ) def _build_session_state( From 8dafdee118a0fb1a4ef6298d20076d7c5d06a732 Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sat, 16 May 2026 22:48:33 +0530 Subject: [PATCH 4/7] fix: address PR review and CI failures - Use asyncio.to_thread to avoid blocking event loop in search_memory - Handle various result types (dict, list, None) from Mem0 search - Fix reminder tests to pass user_id=None for no-user-id tests - Add tests for edge cases in Mem0MemoryService --- src/blacki/memory/mem0_memory_service.py | 62 +++++++++++++----------- src/blacki/memory/tools.py | 7 ++- tests/conftest.py | 2 + tests/memory/test_mem0_memory_service.py | 51 +++++++++++++++++++ tests/reminders/test_reminder_tools.py | 6 +-- 5 files changed, 94 insertions(+), 34 deletions(-) diff --git a/src/blacki/memory/mem0_memory_service.py b/src/blacki/memory/mem0_memory_service.py index 2d3ccf0..521f138 100644 --- a/src/blacki/memory/mem0_memory_service.py +++ b/src/blacki/memory/mem0_memory_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING @@ -53,7 +54,7 @@ async def search_memory( """Search memories via Mem0 and convert to ADK format. Args: - app_name: The application name (used as part of composite user_id). + app_name: The application name (unused, user_id is passed directly). user_id: The user identifier. query: The search query. @@ -62,36 +63,43 @@ async def search_memory( """ from .config import get_search_limit - mem0_user_id = user_id limit = get_search_limit() try: - result = self._client.search(query=query, user_id=mem0_user_id, limit=limit) - except Exception: - logger.exception("Failed to search memories for user %s", mem0_user_id) - return SearchMemoryResponse(memories=[]) + result = await asyncio.to_thread( + self._client.search, query=query, user_id=user_id, limit=limit + ) - memories: list[MemoryEntry] = [] - for m in result.get("results", []): - memory_text = m.get("memory", "") - if not memory_text: - continue - - memories.append( - MemoryEntry( - content=types.Content( - role="user", - parts=[types.Part(text=memory_text)], - ), - id=m.get("id"), + raw_results = ( + result.get("results", []) if isinstance(result, dict) else result + ) or [] + + memories: list[MemoryEntry] = [] + for m in raw_results: + if not isinstance(m, dict): + continue + memory_text = m.get("memory", "") + if not memory_text: + continue + + memories.append( + MemoryEntry( + content=types.Content( + role="user", + parts=[types.Part(text=memory_text)], + ), + id=m.get("id"), + ) ) - ) - logger.debug( - "Found %d memories for query '%s' (user: %s)", - len(memories), - query[:30], - mem0_user_id, - ) + logger.debug( + "Found %d memories for query '%s' (user: %s)", + len(memories), + query[:30], + user_id, + ) + return SearchMemoryResponse(memories=memories) - return SearchMemoryResponse(memories=memories) + except Exception: + logger.exception("Failed to search memories for user %s", user_id) + return SearchMemoryResponse(memories=[]) diff --git a/src/blacki/memory/tools.py b/src/blacki/memory/tools.py index 140edf4..551733d 100644 --- a/src/blacki/memory/tools.py +++ b/src/blacki/memory/tools.py @@ -8,7 +8,6 @@ from google.adk.tools import ToolContext from .config import ( - get_default_user_id, get_memory_client, get_memory_client_error, get_search_limit, @@ -60,7 +59,7 @@ async def save_memory( "error": "Memory text must be a non-empty string.", } - user_id = user_id or get_default_user_id() + user_id = user_id or tool_context.user_id try: result = client.add(text, user_id=user_id) @@ -111,7 +110,7 @@ async def search_memory( "results": [], } - user_id = user_id or get_default_user_id() + user_id = user_id or tool_context.user_id limit = limit or get_search_limit() try: @@ -179,7 +178,7 @@ async def get_all_memories( if client is None: return _memory_service_unavailable_response({"results": []}) - user_id = user_id or get_default_user_id() + user_id = user_id or tool_context.user_id if page > 3: logger.warning( diff --git a/tests/conftest.py b/tests/conftest.py index 916b302..350d937 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -129,6 +129,7 @@ def __init__( state: MockState | None = None, user_content: MockContent | None = None, actions: MockEventActions | None = None, + user_id: str | None = "test_user", ) -> None: """Initialize mock tool context.""" self.agent_name = agent_name @@ -136,6 +137,7 @@ def __init__( self.state = state if state is not None else MockState() self.user_content = user_content self.actions = actions if actions is not None else MockEventActions() + self.user_id = user_id class MockBaseTool: diff --git a/tests/memory/test_mem0_memory_service.py b/tests/memory/test_mem0_memory_service.py index 64d92a2..7d141e2 100644 --- a/tests/memory/test_mem0_memory_service.py +++ b/tests/memory/test_mem0_memory_service.py @@ -148,3 +148,54 @@ async def test_search_memory_uses_custom_limit( call_kwargs = mock_client.search.call_args[1] assert call_kwargs["limit"] == 10 + + @pytest.mark.asyncio + async def test_search_memory_handles_non_dict_results(self) -> None: + """Should skip non-dict items in results.""" + mock_client = MagicMock() + mock_client.search.return_value = { + "results": [ + {"id": "mem_1", "memory": "Valid memory"}, + "invalid_string_item", + 123, + None, + ] + } + + service = Mem0MemoryService(mock_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="test" + ) + + assert len(response.memories) == 1 + assert response.memories[0].content.parts is not None + assert response.memories[0].content.parts[0].text == "Valid memory" + + @pytest.mark.asyncio + async def test_search_memory_handles_list_result(self) -> None: + """Should handle direct list result from Mem0.""" + mock_client = MagicMock() + mock_client.search.return_value = [ + {"id": "mem_1", "memory": "List memory 1"}, + {"id": "mem_2", "memory": "List memory 2"}, + ] + + service = Mem0MemoryService(mock_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="test" + ) + + assert len(response.memories) == 2 + + @pytest.mark.asyncio + async def test_search_memory_handles_none_result(self) -> None: + """Should handle None result from Mem0.""" + mock_client = MagicMock() + mock_client.search.return_value = None + + service = Mem0MemoryService(mock_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="test" + ) + + assert len(response.memories) == 0 diff --git a/tests/reminders/test_reminder_tools.py b/tests/reminders/test_reminder_tools.py index fd3d223..dc1e0e1 100644 --- a/tests/reminders/test_reminder_tools.py +++ b/tests/reminders/test_reminder_tools.py @@ -63,7 +63,7 @@ async def test_schedule_one_time_reminder(self, mock_scheduler: MagicMock) -> No async def test_schedule_reminder_no_user_id(self) -> None: """Should return error if user_id not in context.""" state = MockState({}) - tool_context = cast(MagicMock, MockToolContext(state=state)) + tool_context = cast(MagicMock, MockToolContext(state=state, user_id=None)) result = await schedule_reminder( tool_context=tool_context, @@ -257,7 +257,7 @@ async def test_list_reminders_with_items(self, mock_scheduler: MagicMock) -> Non async def test_list_reminders_no_user_id(self) -> None: """Should return error if user_id not in context.""" state = MockState({}) - tool_context = cast(MagicMock, MockToolContext(state=state)) + tool_context = cast(MagicMock, MockToolContext(state=state, user_id=None)) result = await list_reminders(tool_context=tool_context) @@ -323,7 +323,7 @@ async def test_cancel_reminder_not_found(self, mock_scheduler: MagicMock) -> Non async def test_cancel_reminder_no_user_id(self) -> None: """Should return error if user_id not in context.""" state = MockState({}) - tool_context = cast(MagicMock, MockToolContext(state=state)) + tool_context = cast(MagicMock, MockToolContext(state=state, user_id=None)) result = await cancel_reminder( tool_context=tool_context, From 4bb019ace53c80fa1711f2d1b9c14d337cde68d6 Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sat, 16 May 2026 23:58:09 +0530 Subject: [PATCH 5/7] fix: address PR review feedback - Move mem0 service registration to adk_runtime.py for test safety - Update health check to return degraded when configured but unavailable --- src/blacki/adk_runtime.py | 26 ++++++++++++++++++++++++-- src/blacki/server.py | 38 ++++++++++---------------------------- tests/test_adk_runtime.py | 15 +++++++++++++++ 3 files changed, 49 insertions(+), 30 deletions(-) diff --git a/src/blacki/adk_runtime.py b/src/blacki/adk_runtime.py index 5b2d2da..4cc75b5 100644 --- a/src/blacki/adk_runtime.py +++ b/src/blacki/adk_runtime.py @@ -8,6 +8,7 @@ from typing import Any from google.adk.agents.run_config import RunConfig, StreamingMode +from google.adk.cli.service_registry import get_service_registry from google.adk.events import Event from google.adk.memory.base_memory_service import BaseMemoryService from google.adk.runners import Runner @@ -20,6 +21,29 @@ logger = logging.getLogger(__name__) + +def _create_mem0_memory_service(uri: str, **kwargs: Any) -> BaseMemoryService: + """Factory for mem0:// URI scheme. + + Returns Mem0MemoryService if client is available, InMemoryMemoryService otherwise. + """ + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + + from blacki.memory.config import get_memory_client + + client = get_memory_client() + if client is None: + logger.info("Mem0 client not available, using in-memory memory service") + return InMemoryMemoryService() + + from blacki.memory.mem0_memory_service import Mem0MemoryService + + logger.info("Mem0 memory service initialized") + return Mem0MemoryService(client) + + +get_service_registry().register_memory_service("mem0", _create_mem0_memory_service) + DEFAULT_EMPTY_RESPONSE = "I apologize, but I couldn't generate a response." SESSION_VERSION_SEPARATOR = "-v" @@ -383,8 +407,6 @@ def create_adk_runtime(env: ServerEnv) -> AdkRuntime: session_db_kwargs=session_db_kwargs, agent_dir=env.agent_dir, ) - from google.adk.cli.service_registry import get_service_registry - memory_service = get_service_registry().create_memory_service( "mem0://", agents_dir=str(Path(env.agent_dir).resolve()) ) diff --git a/src/blacki/server.py b/src/blacki/server.py index 5b605ff..44fc0ea 100644 --- a/src/blacki/server.py +++ b/src/blacki/server.py @@ -15,8 +15,6 @@ import uvicorn from fastapi import FastAPI from google.adk.cli.fast_api import get_fast_api_app -from google.adk.cli.service_registry import get_service_registry -from google.adk.memory.base_memory_service import BaseMemoryService from openinference.instrumentation.google_adk import GoogleADKInstrumentor from .adk_runtime import ( @@ -136,29 +134,6 @@ async def _stop_reminder_scheduler() -> None: session_uri = build_session_service_uri(env) session_db_kwargs = build_session_db_kwargs(env) - -def _create_mem0_memory_service(uri: str, **kwargs: Any) -> BaseMemoryService: - """Factory for mem0:// URI scheme. - - Returns Mem0MemoryService if client is available, InMemoryMemoryService otherwise. - """ - from google.adk.memory.in_memory_memory_service import InMemoryMemoryService - - from blacki.memory.config import get_memory_client - - client = get_memory_client() - if client is None: - logger.info("Mem0 client not available, using in-memory memory service") - return InMemoryMemoryService() - - from blacki.memory.mem0_memory_service import Mem0MemoryService - - logger.info("Mem0 memory service initialized") - return Mem0MemoryService(client) - - -get_service_registry().register_memory_service("mem0", _create_mem0_memory_service) - app: FastAPI = get_fast_api_app( agents_dir=AGENT_DIR, session_service_uri=session_uri, @@ -227,7 +202,7 @@ async def health() -> dict[str, Any]: Returns: dict with status key indicating service health. """ - from blacki.memory.config import get_memory_client + from blacki.memory.config import get_memory_client, get_memory_client_error checks: dict[str, str] = {} @@ -239,9 +214,16 @@ async def health() -> dict[str, Any]: checks["database"] = "unhealthy" client = get_memory_client() - checks["memory_service"] = "healthy" if client else "unavailable" + error = get_memory_client_error() + + if client: + checks["memory_service"] = "healthy" + elif error: + checks["memory_service"] = "degraded" + else: + checks["memory_service"] = "unavailable" - all_ok = all(v in ("healthy", "unavailable") for v in checks.values()) + all_ok = all(v == "healthy" for v in checks.values()) status = "ok" if all_ok else "degraded" return {"status": status, "checks": checks} diff --git a/tests/test_adk_runtime.py b/tests/test_adk_runtime.py index 2fdf893..94de708 100644 --- a/tests/test_adk_runtime.py +++ b/tests/test_adk_runtime.py @@ -299,6 +299,21 @@ def test_create_adk_runtime_uses_env_configuration(tmp_path: Path) -> None: assert isinstance(runtime.session_service, DatabaseSessionService) +def test_create_adk_runtime_falls_back_to_in_memory_when_mem0_unavailable( + tmp_path: Path, +) -> None: + """Test runtime falls back to InMemoryMemoryService when Mem0 unavailable.""" + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + + env = _build_server_env() + env.agent_dir = str(tmp_path) + + with patch("blacki.memory.config.get_memory_client", return_value=None): + runtime = create_adk_runtime(env) + + assert isinstance(runtime.runner.memory_service, InMemoryMemoryService) + + def test_extract_session_version_rejects_invalid_format() -> None: """Test that malformed versioned session IDs fail fast.""" with pytest.raises(ValueError, match="Unexpected session id format"): From 35ef6647dde9058959ecaf41ebe398133dce99e1 Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sun, 17 May 2026 00:15:34 +0530 Subject: [PATCH 6/7] test: add test for Mem0MemoryService happy path --- tests/test_adk_runtime.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_adk_runtime.py b/tests/test_adk_runtime.py index 94de708..fca27ed 100644 --- a/tests/test_adk_runtime.py +++ b/tests/test_adk_runtime.py @@ -299,6 +299,24 @@ def test_create_adk_runtime_uses_env_configuration(tmp_path: Path) -> None: assert isinstance(runtime.session_service, DatabaseSessionService) +def test_create_adk_runtime_uses_mem0_when_client_available( + tmp_path: Path, +) -> None: + """Test runtime uses Mem0MemoryService when client is available.""" + from unittest.mock import MagicMock + + from blacki.memory.mem0_memory_service import Mem0MemoryService + + env = _build_server_env() + env.agent_dir = str(tmp_path) + + mock_client = MagicMock() + with patch("blacki.memory.config.get_memory_client", return_value=mock_client): + runtime = create_adk_runtime(env) + + assert isinstance(runtime.runner.memory_service, Mem0MemoryService) + + def test_create_adk_runtime_falls_back_to_in_memory_when_mem0_unavailable( tmp_path: Path, ) -> None: From 91261bdc6391e07c6dd75390066d639e2b806606 Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sun, 17 May 2026 00:21:00 +0530 Subject: [PATCH 7/7] fix: make date tests timezone-independent Use specific dates instead of relative dates like 'yesterday' to avoid timezone-dependent test failures when CI runs near IST midnight. --- tests/calories/test_tools.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/calories/test_tools.py b/tests/calories/test_tools.py index 6424e96..90583a6 100644 --- a/tests/calories/test_tools.py +++ b/tests/calories/test_tools.py @@ -1,5 +1,4 @@ # mypy: disable-error-code="no-untyped-def" -from datetime import date from unittest.mock import AsyncMock, create_autospec, patch import pytest @@ -185,12 +184,12 @@ async def test_log_meal_with_past_date( mock_tool_context, description="apple", estimated_calories=95, - date="yesterday", + date="2026-04-25", ) assert result["status"] == "success" entry = mock_storage.add_entry.call_args[0][0] - assert entry.logged_date != str(date.today()) + assert entry.logged_date == "2026-04-25" @pytest.mark.asyncio @@ -229,12 +228,12 @@ async def test_edit_meal_with_date(mock_get_storage, mock_tool_context) -> None: mock_get_storage.return_value = mock_storage mock_storage.update_entry.return_value = True - result = await edit_meal(mock_tool_context, entry_id=1, date="yesterday") + result = await edit_meal(mock_tool_context, entry_id=1, date="2026-04-15") assert result["status"] == "success" call_kwargs = mock_storage.update_entry.call_args[1] assert "logged_date" in call_kwargs - assert call_kwargs["logged_date"] != str(date.today()) + assert call_kwargs["logged_date"] == "2026-04-15" @pytest.mark.asyncio