Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from memos.multi_mem_cube.composite_cube import CompositeCubeView
from memos.multi_mem_cube.single_cube import SingleCubeView
from memos.multi_mem_cube.views import MemCubeView
from memos.plugins.hooks import hookable


logger = get_logger(__name__)
Expand All @@ -44,6 +45,7 @@ def __init__(self, dependencies: HandlerDependencies):
"naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent"
)

@hookable("search")
def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse:
"""
Main handler for search memories endpoint.
Expand Down
5 changes: 5 additions & 0 deletions src/memos/api/server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
from memos.api.exceptions import APIExceptionHandler
from memos.api.middleware.request_context import RequestContextMiddleware
from memos.api.routers.server_router import router as server_router
from memos.plugins.manager import plugin_manager


load_dotenv()

plugin_manager.discover()

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
Expand All @@ -38,6 +41,8 @@
# Fallback for unknown errors
app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler)

plugin_manager.init_app(app)


if __name__ == "__main__":
import argparse
Expand Down
20 changes: 20 additions & 0 deletions src/memos/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from memos.plugins.base import MemOSPlugin
from memos.plugins.hook_defs import H, HookSpec, all_hook_specs, define_hook, get_hook_spec
from memos.plugins.hooks import hookable, register_hook, register_hooks, trigger_hook
from memos.plugins.manager import PluginManager, plugin_manager


__all__ = [
"H",
"HookSpec",
"MemOSPlugin",
"PluginManager",
"all_hook_specs",
"define_hook",
"get_hook_spec",
"hookable",
"plugin_manager",
"register_hook",
"register_hooks",
"trigger_hook",
]
72 changes: 72 additions & 0 deletions src/memos/plugins/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""MemOS plugin base class — all plugins must inherit from this class."""

from __future__ import annotations

from typing import TYPE_CHECKING


if TYPE_CHECKING:
from collections.abc import Callable

from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware


class MemOSPlugin:
"""MemOS plugin base class.

Provides three unified registration methods. Plugin developers need only
inherit from this class and register capabilities via self.register_*
in init_app.
"""

name: str = "unnamed"
version: str = "0.0.0"
description: str = ""

_app: FastAPI | None = None

# ------------------------------------------------------------------ #
# Registration methods — called by plugins in init_app
# ------------------------------------------------------------------ #

def register_router(self, router, **kwargs) -> None:
"""Register a router."""
self._app.include_router(router, **kwargs)

def register_middleware(self, middleware_cls: type[BaseHTTPMiddleware], **kwargs) -> None:
"""Register middleware."""
self._app.add_middleware(middleware_cls, **kwargs)

def register_hook(self, name: str, callback: Callable) -> None:
"""Register a single Hook callback."""
from memos.plugins.hooks import register_hook

register_hook(name, callback)

def register_hooks(self, names: list[str], callback: Callable) -> None:
"""Batch-register the same callback to multiple Hook points."""
from memos.plugins.hooks import register_hooks

register_hooks(names, callback)

# ------------------------------------------------------------------ #
# Internal methods — called by PluginManager, plugin developers need not care
# ------------------------------------------------------------------ #

def _bind_app(self, app: FastAPI) -> None:
"""Bind FastAPI instance so that register_* methods are available."""
self._app = app

# ------------------------------------------------------------------ #
# Lifecycle methods — override in subclasses
# ------------------------------------------------------------------ #

def on_load(self) -> None:
"""Called after the plugin is discovered. Used for initialization logic, e.g. checking dependencies, reading config."""

def init_app(self) -> None:
"""Called after FastAPI app is bound. Register routes, middleware, and Hooks via self.register_* here."""

def on_shutdown(self) -> None:
"""Called when the service shuts down. Used for resource cleanup."""
88 changes: 88 additions & 0 deletions src/memos/plugins/hook_defs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Hook declaration registry — single source of truth for CE repo Hook points.

The @hookable decorator automatically declares its before/after Hooks; no need to manually define_hook.
Hooks triggered by custom trigger_hook must be explicitly declared in this file.

Plugin-owned Hooks should be declared within each plugin package, not in this file.
"""

from __future__ import annotations

import logging

from dataclasses import dataclass


logger = logging.getLogger(__name__)

_specs: dict[str, HookSpec] = {}


@dataclass(frozen=True)
class HookSpec:
"""Hook spec definition."""

name: str
description: str
params: list[str]
pipe_key: str | None = None


def define_hook(
name: str,
*,
description: str,
params: list[str],
pipe_key: str | None = None,
) -> None:
"""Declare a Hook point. Skips if already exists (idempotent)."""
if name in _specs:
return
_specs[name] = HookSpec(
name=name,
description=description,
params=params,
pipe_key=pipe_key,
)
logger.debug("Hook defined: %s (pipe_key=%s)", name, pipe_key)


def get_hook_spec(name: str) -> HookSpec | None:
return _specs.get(name)


def all_hook_specs() -> dict[str, HookSpec]:
"""Return all declared Hooks (including @hookable auto-declared + plugin-declared)."""
return dict(_specs)


# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# CE Hook name constants
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━


class H:
"""CE Hook name constants. Plugin-owned Hook constants should be defined within the plugin package."""

# @hookable("add") — AddHandler.handle_add_memories
ADD_BEFORE = "add.before"
ADD_AFTER = "add.after"

# @hookable("search") — SearchHandler.handle_search_memories
SEARCH_BEFORE = "search.before"
SEARCH_AFTER = "search.after"

# Custom Hook (manually triggered via trigger_hook)
ADD_MEMORIES_POST_PROCESS = "add.memories.post_process"


# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# CE custom Hook declarations (@hookable-generated ones need not be declared here)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

define_hook(
H.ADD_MEMORIES_POST_PROCESS,
description="Post-process result after add_memories returns, before constructing Response",
params=["request", "result"],
pipe_key="result",
)
124 changes: 124 additions & 0 deletions src/memos/plugins/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Hook runtime — registration, triggering, and @hookable decorator."""

from __future__ import annotations

import asyncio
import logging

from collections import defaultdict
from functools import wraps
from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
from collections.abc import Callable


logger = logging.getLogger(__name__)

_hooks: dict[str, list[Callable]] = defaultdict(list)


def register_hook(name: str, callback: Callable) -> None:
"""Register a hook callback. Undeclared hook names will log a warning."""
from memos.plugins.hook_defs import get_hook_spec

if get_hook_spec(name) is None:
logger.warning(
"Registering callback for undeclared hook: %s (callback=%s)",
name,
getattr(callback, "__qualname__", repr(callback)),
)
_hooks[name].append(callback)
logger.debug(
"Hook registered: %s -> %s",
name,
getattr(callback, "__qualname__", repr(callback)),
)


def register_hooks(names: list[str], callback: Callable) -> None:
"""Batch-register the same callback to multiple hook points."""
for name in names:
register_hook(name, callback)


def trigger_hook(name: str, **kwargs: Any) -> Any:
"""Trigger a hook, invoking all registered callbacks in order.

- Zero overhead when no callbacks are registered
- Undeclared hook names will log a warning and be skipped
- pipe_key is auto-fetched from HookSpec, supports piped return value passing
"""
from memos.plugins.hook_defs import get_hook_spec

spec = get_hook_spec(name)
if spec is None:
logger.warning("Undeclared hook triggered: %s — ignored", name)
return None

pipe_key = spec.pipe_key

for cb in _hooks.get(name, []):
try:
rv = cb(**kwargs)
if pipe_key is not None and rv is not None:
kwargs[pipe_key] = rv
except Exception:
logger.exception(
"Hook %s callback %s failed",
name,
getattr(cb, "__qualname__", repr(cb)),
)

return kwargs.get(pipe_key) if pipe_key else None


def hookable(name: str):
"""Decorator: automatically triggers name.before / name.after hook before and after the method.

Auto-declares before/after Hooks (idempotent); no need to manually define_hook in hook_defs.py.
Supports piped return values: before can modify request, after can modify result.
Compatible with both sync and async methods.
"""
from memos.plugins.hook_defs import define_hook

define_hook(
f"{name}.before",
description=f"Before {name} executes; can modify request",
params=["request"],
pipe_key="request",
)
define_hook(
f"{name}.after",
description=f"After {name} executes; can modify result",
params=["request", "result"],
pipe_key="result",
)

def decorator(func):
if asyncio.iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(self, request, *args, **kwargs):
rv = trigger_hook(f"{name}.before", request=request)
request = rv if rv is not None else request
result = await func(self, request, *args, **kwargs)
rv = trigger_hook(f"{name}.after", request=request, result=result)
result = rv if rv is not None else result
return result

return async_wrapper

@wraps(func)
def sync_wrapper(self, request, *args, **kwargs):
rv = trigger_hook(f"{name}.before", request=request)
request = rv if rv is not None else request
result = func(self, request, *args, **kwargs)
rv = trigger_hook(f"{name}.after", request=request, result=result)
result = rv if rv is not None else result
return result

return sync_wrapper

return decorator
Loading