Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions src/blacki/adk_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from typing import Any

from google.adk.agents.run_config import RunConfig, StreamingMode
from google.adk.cli.service_registry import get_service_registry
from google.adk.events import Event
from google.adk.memory.base_memory_service import BaseMemoryService
from google.adk.runners import Runner
from google.adk.sessions import Session
from google.adk.sessions.base_session_service import BaseSessionService
Expand All @@ -19,6 +21,29 @@

logger = logging.getLogger(__name__)


def _create_mem0_memory_service(uri: str, **kwargs: Any) -> BaseMemoryService:
"""Factory for mem0:// URI scheme.

Returns Mem0MemoryService if client is available, InMemoryMemoryService otherwise.
"""
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService

from blacki.memory.config import get_memory_client

client = get_memory_client()
if client is None:
logger.info("Mem0 client not available, using in-memory memory service")
return InMemoryMemoryService()

from blacki.memory.mem0_memory_service import Mem0MemoryService

logger.info("Mem0 memory service initialized")
return Mem0MemoryService(client)


get_service_registry().register_memory_service("mem0", _create_mem0_memory_service)

DEFAULT_EMPTY_RESPONSE = "I apologize, but I couldn't generate a response."
SESSION_VERSION_SEPARATOR = "-v"

Expand Down Expand Up @@ -99,7 +124,11 @@ class SessionLocator:
class AdkRuntime:
"""Small helper around ADK Runner and SessionService."""

def __init__(self, session_service: BaseSessionService) -> None:
def __init__(
self,
session_service: BaseSessionService,
memory_service: BaseMemoryService | None = None,
) -> None:
from .agent import app as agent_app

self.app = agent_app
Expand All @@ -109,6 +138,7 @@ def __init__(self, session_service: BaseSessionService) -> None:
app=self.app,
app_name=self.app_name,
session_service=self.session_service,
memory_service=memory_service,
auto_create_session=False,
)

Expand Down Expand Up @@ -377,7 +407,13 @@ def create_adk_runtime(env: ServerEnv) -> AdkRuntime:
session_db_kwargs=session_db_kwargs,
agent_dir=env.agent_dir,
)
return AdkRuntime(session_service=session_service)
memory_service = get_service_registry().create_memory_service(
"mem0://", agents_dir=str(Path(env.agent_dir).resolve())
)
Comment on lines +410 to +412
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The create_memory_service call for the mem0:// scheme relies on a registration that currently happens in src/blacki/server.py. If create_adk_runtime is called from a context where server.py has not been imported (e.g., a standalone script or certain test configurations), this will raise a ValueError because the scheme won't be registered in the global service_registry. Consider moving the registration to a more central location or ensuring it happens lazily within this function if not already present.

return AdkRuntime(
session_service=session_service,
memory_service=memory_service,
)


def _build_session_state(
Expand Down
3 changes: 3 additions & 0 deletions src/blacki/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,11 @@ def create_agent() -> LlmAgent:
Returns:
Configured LlmAgent instance.
"""
from google.adk.tools.preload_memory_tool import preload_memory_tool

tool_config = build_tool_config_from_env()
agent_tools = build_tools(tool_config)
agent_tools.append(preload_memory_tool)

before_tool_callbacks: list[Any] = [logging_callbacks.before_tool]
after_model_callbacks: list[Any] = [logging_callbacks.after_model]
Expand Down
2 changes: 0 additions & 2 deletions src/blacki/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
reset_memory_client,
)
from .tools import (
delete_all_memories,
delete_memory,
get_all_memories,
get_memory,
Expand All @@ -16,7 +15,6 @@
)

__all__ = [
"delete_all_memories",
"delete_memory",
"get_all_memories",
"get_memory",
Expand Down
105 changes: 105 additions & 0 deletions src/blacki/memory/mem0_memory_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Memory service that bridges Mem0 to ADK's BaseMemoryService interface."""

from __future__ import annotations

import asyncio
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING

from google.adk.events.event import Event
from google.adk.memory.base_memory_service import (
BaseMemoryService,
SearchMemoryResponse,
)
from google.adk.memory.memory_entry import MemoryEntry
from google.adk.sessions.session import Session
from google.genai import types

if TYPE_CHECKING:
from mem0 import Memory

logger = logging.getLogger(__name__)


class Mem0MemoryService(BaseMemoryService):
"""Memory service backed by Mem0 OSS.

Wraps the existing Mem0 client to provide ADK-compatible memory operations.
Memories are managed manually via save_memory tool (no automatic session ingestion).
"""

def __init__(self, client: Memory):
self._client = client

async def add_session_to_memory(self, session: Session) -> None:
"""Not used - user chose manual memory management via save_memory tool."""
pass

async def add_events_to_memory(
self,
*,
app_name: str,
user_id: str,
events: Sequence[Event],
session_id: str | None = None,
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Not used - user chose manual memory management."""
pass

async def search_memory(
self, *, app_name: str, user_id: str, query: str
) -> SearchMemoryResponse:
"""Search memories via Mem0 and convert to ADK format.

Args:
app_name: The application name (unused, user_id is passed directly).
user_id: The user identifier.
query: The search query.

Returns:
SearchMemoryResponse with matching MemoryEntry objects.
"""
from .config import get_search_limit

limit = get_search_limit()

try:
result = await asyncio.to_thread(
self._client.search, query=query, user_id=user_id, limit=limit
)

raw_results = (
result.get("results", []) if isinstance(result, dict) else result
) or []

memories: list[MemoryEntry] = []
for m in raw_results:
if not isinstance(m, dict):
continue
memory_text = m.get("memory", "")
if not memory_text:
continue

memories.append(
MemoryEntry(
content=types.Content(
role="user",
parts=[types.Part(text=memory_text)],
),
id=m.get("id"),
)
)

logger.debug(
"Found %d memories for query '%s' (user: %s)",
len(memories),
query[:30],
user_id,
)
return SearchMemoryResponse(memories=memories)

except Exception:
logger.exception("Failed to search memories for user %s", user_id)
return SearchMemoryResponse(memories=[])
46 changes: 3 additions & 43 deletions src/blacki/memory/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from google.adk.tools import ToolContext

from .config import (
get_default_user_id,
get_memory_client,
get_memory_client_error,
get_search_limit,
Expand Down Expand Up @@ -60,7 +59,7 @@ async def save_memory(
"error": "Memory text must be a non-empty string.",
}

user_id = user_id or get_default_user_id()
user_id = user_id or tool_context.user_id

try:
result = client.add(text, user_id=user_id)
Expand Down Expand Up @@ -111,7 +110,7 @@ async def search_memory(
"results": [],
}

user_id = user_id or get_default_user_id()
user_id = user_id or tool_context.user_id
limit = limit or get_search_limit()

try:
Expand Down Expand Up @@ -179,7 +178,7 @@ async def get_all_memories(
if client is None:
return _memory_service_unavailable_response({"results": []})

user_id = user_id or get_default_user_id()
user_id = user_id or tool_context.user_id

if page > 3:
logger.warning(
Expand Down Expand Up @@ -382,42 +381,3 @@ async def delete_memory(
"status": "error",
"error": f"Failed to delete memory: {e}",
}


async def delete_all_memories(
tool_context: ToolContext,
user_id: str | None = None,
) -> dict[str, Any]:
"""Delete all memories for a user.

Use this tool with caution when a user wants to wipe all their stored
memories. This operation cannot be undone.

Args:
tool_context: ADK tool context.
user_id: Unique identifier for the user. Defaults to MEM0_USER_ID env var.

Returns:
Dictionary with status and result message.
"""
_ = tool_context

client = get_memory_client()
if client is None:
return _memory_service_unavailable_response()

user_id = user_id or get_default_user_id()

try:
client.delete_all(user_id=user_id)
logger.warning("Deleted all memories for user %s", user_id)
return {
"status": "success",
"message": f"All memories deleted for user {user_id}.",
}
except Exception as e:
logger.exception("Failed to delete all memories for user %s", user_id)
return {
"status": "error",
"error": f"Failed to delete all memories: {e}",
}
2 changes: 0 additions & 2 deletions src/blacki/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ def return_instruction_root() -> str:
information. You need the memory_id from search or list operations.
- Use delete_memory when the user asks to forget specific information.
You need the memory_id from search or list operations.
- Use delete_all_memories with caution when the user wants to wipe all
their stored memories. Confirm before executing.
- All memory operations are scoped to the user_id. Memories are private
and isolated per user.
</memory_spec>
Expand Down
2 changes: 0 additions & 2 deletions src/blacki/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def _build_memory_tools() -> list[Any]:
"""Build memory tools."""
try:
from blacki.memory import (
delete_all_memories,
delete_memory,
get_all_memories,
get_memory,
Expand All @@ -218,7 +217,6 @@ def _build_memory_tools() -> list[Any]:
get_memory,
update_memory,
delete_memory,
delete_all_memories,
]
except ImportError as e: # pragma: no cover
logger.warning("Failed to load Memory tools: %s", e)
Expand Down
30 changes: 23 additions & 7 deletions src/blacki/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any

import uvicorn
from fastapi import FastAPI
Expand Down Expand Up @@ -138,7 +139,7 @@ async def _stop_reminder_scheduler() -> None:
session_service_uri=session_uri,
session_db_kwargs=session_db_kwargs,
artifact_service_uri=None,
memory_service_uri=None,
memory_service_uri="mem0://",
allow_origins=env.allow_origins_list,
web=env.serve_web_interface,
reload_agents=env.reload_agents,
Expand Down Expand Up @@ -195,22 +196,37 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:


@app.get("/health")
async def health() -> dict[str, str]:
async def health() -> dict[str, Any]:
"""Health check endpoint for container orchestration.

Returns:
dict with status key indicating service health.
"""
checks: list[str] = []
from blacki.memory.config import get_memory_client, get_memory_client_error

checks: dict[str, str] = {}

if _container is not None:
try:
await _container.pool.fetchval("SELECT 1")
checks["database"] = "healthy"
except Exception:
checks.append("database:unreachable")
checks["database"] = "unhealthy"

client = get_memory_client()
error = get_memory_client_error()

if client:
checks["memory_service"] = "healthy"
elif error:
checks["memory_service"] = "degraded"
else:
checks["memory_service"] = "unavailable"

all_ok = all(v == "healthy" for v in checks.values())
status = "ok" if all_ok else "degraded"

if checks:
return {"status": "degraded", "details": "; ".join(checks)}
return {"status": "ok"}
return {"status": status, "checks": checks}


def main() -> None:
Expand Down
Loading
Loading