From f124842371b06b75d27f202eed9b24b412dc7770 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 17:53:30 +0800 Subject: [PATCH 1/4] feat: add plugin system with hook, middleware, and route registration - Add MemOSPlugin base class with unified register_router/register_middleware/register_hook APIs - Add hook runtime: register_hook, trigger_hook, @hookable decorator with pipeline support - Add HookSpec declaration registry (hook_defs) for formalized hook definitions - Add PluginManager for entry_points-based plugin discovery and lifecycle management - Integrate plugin_manager into server_api (discover + init_app) - Add @hookable("add") and @hookable("search") to handlers - Add custom trigger_hook for add.memories.post_process in add_handler - Add comprehensive plugin framework tests Made-with: Cursor --- src/memos/api/handlers/add_handler.py | 6 + src/memos/api/handlers/search_handler.py | 2 + src/memos/api/server_api.py | 5 + src/memos/plugins/__init__.py | 20 ++ src/memos/plugins/base.py | 72 ++++ src/memos/plugins/hook_defs.py | 88 +++++ src/memos/plugins/hooks.py | 124 +++++++ src/memos/plugins/manager.py | 75 ++++ tests/plugins/__init__.py | 0 tests/plugins/conftest.py | 17 + tests/plugins/run_plugin_server.py | 0 tests/plugins/test_plugin_demo.py | 439 +++++++++++++++++++++++ 12 files changed, 848 insertions(+) create mode 100644 src/memos/plugins/__init__.py create mode 100644 src/memos/plugins/base.py create mode 100644 src/memos/plugins/hook_defs.py create mode 100644 src/memos/plugins/hooks.py create mode 100644 src/memos/plugins/manager.py create mode 100644 tests/plugins/__init__.py create mode 100644 tests/plugins/conftest.py create mode 100644 tests/plugins/run_plugin_server.py create mode 100644 tests/plugins/test_plugin_demo.py diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 3cdbedabf..e53ef9393 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,6 +15,8 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hook_defs import H +from memos.plugins.hooks import hookable, trigger_hook from memos.types import MessageList @@ -37,6 +39,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) + @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. @@ -106,6 +109,9 @@ def _check_messages(messages: MessageList) -> None: results = cube_view.add_memories(add_req) + rv = trigger_hook(H.ADD_MEMORIES_POST_PROCESS, request=add_req, result=results) + results = rv if rv is not None else results + self.logger.info(f"[AddHandler] Final add results count={len(results)}") return MemoryResponse( diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 8e7785ad5..2877e5138 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -20,6 +20,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hooks import hookable logger = get_logger(__name__) @@ -44,6 +45,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" ) + @hookable("search") def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ Main handler for search memories endpoint. diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 529a709a4..78185a035 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -9,10 +9,13 @@ from memos.api.exceptions import APIExceptionHandler from memos.api.middleware.request_context import RequestContextMiddleware from memos.api.routers.server_router import router as server_router +from memos.plugins.manager import plugin_manager load_dotenv() +plugin_manager.discover() + # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -38,6 +41,8 @@ # Fallback for unknown errors app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) +plugin_manager.init_app(app) + if __name__ == "__main__": import argparse diff --git a/src/memos/plugins/__init__.py b/src/memos/plugins/__init__.py new file mode 100644 index 000000000..0a0f8cde3 --- /dev/null +++ b/src/memos/plugins/__init__.py @@ -0,0 +1,20 @@ +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H, HookSpec, all_hook_specs, define_hook, get_hook_spec +from memos.plugins.hooks import hookable, register_hook, register_hooks, trigger_hook +from memos.plugins.manager import PluginManager, plugin_manager + + +__all__ = [ + "H", + "HookSpec", + "MemOSPlugin", + "PluginManager", + "all_hook_specs", + "define_hook", + "get_hook_spec", + "hookable", + "plugin_manager", + "register_hook", + "register_hooks", + "trigger_hook", +] diff --git a/src/memos/plugins/base.py b/src/memos/plugins/base.py new file mode 100644 index 000000000..f55d81b75 --- /dev/null +++ b/src/memos/plugins/base.py @@ -0,0 +1,72 @@ +"""MemOS plugin base class — all plugins must inherit from this class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastapi import FastAPI + from starlette.middleware.base import BaseHTTPMiddleware + + +class MemOSPlugin: + """MemOS plugin base class. + + Provides three unified registration methods. Plugin developers need only + inherit from this class and register capabilities via self.register_* + in init_app. + """ + + name: str = "unnamed" + version: str = "0.0.0" + description: str = "" + + _app: FastAPI | None = None + + # ------------------------------------------------------------------ # + # Registration methods — called by plugins in init_app + # ------------------------------------------------------------------ # + + def register_router(self, router, **kwargs) -> None: + """Register a router.""" + self._app.include_router(router, **kwargs) + + def register_middleware(self, middleware_cls: type[BaseHTTPMiddleware], **kwargs) -> None: + """Register middleware.""" + self._app.add_middleware(middleware_cls, **kwargs) + + def register_hook(self, name: str, callback: Callable) -> None: + """Register a single Hook callback.""" + from memos.plugins.hooks import register_hook + + register_hook(name, callback) + + def register_hooks(self, names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple Hook points.""" + from memos.plugins.hooks import register_hooks + + register_hooks(names, callback) + + # ------------------------------------------------------------------ # + # Internal methods — called by PluginManager, plugin developers need not care + # ------------------------------------------------------------------ # + + def _bind_app(self, app: FastAPI) -> None: + """Bind FastAPI instance so that register_* methods are available.""" + self._app = app + + # ------------------------------------------------------------------ # + # Lifecycle methods — override in subclasses + # ------------------------------------------------------------------ # + + def on_load(self) -> None: + """Called after the plugin is discovered. Used for initialization logic, e.g. checking dependencies, reading config.""" + + def init_app(self) -> None: + """Called after FastAPI app is bound. Register routes, middleware, and Hooks via self.register_* here.""" + + def on_shutdown(self) -> None: + """Called when the service shuts down. Used for resource cleanup.""" diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py new file mode 100644 index 000000000..a3b9dbaf8 --- /dev/null +++ b/src/memos/plugins/hook_defs.py @@ -0,0 +1,88 @@ +"""Hook declaration registry — single source of truth for CE repo Hook points. + +The @hookable decorator automatically declares its before/after Hooks; no need to manually define_hook. +Hooks triggered by custom trigger_hook must be explicitly declared in this file. + +Plugin-owned Hooks should be declared within each plugin package, not in this file. +""" + +from __future__ import annotations + +import logging + +from dataclasses import dataclass + + +logger = logging.getLogger(__name__) + +_specs: dict[str, HookSpec] = {} + + +@dataclass(frozen=True) +class HookSpec: + """Hook spec definition.""" + + name: str + description: str + params: list[str] + pipe_key: str | None = None + + +def define_hook( + name: str, + *, + description: str, + params: list[str], + pipe_key: str | None = None, +) -> None: + """Declare a Hook point. Skips if already exists (idempotent).""" + if name in _specs: + return + _specs[name] = HookSpec( + name=name, + description=description, + params=params, + pipe_key=pipe_key, + ) + logger.debug("Hook defined: %s (pipe_key=%s)", name, pipe_key) + + +def get_hook_spec(name: str) -> HookSpec | None: + return _specs.get(name) + + +def all_hook_specs() -> dict[str, HookSpec]: + """Return all declared Hooks (including @hookable auto-declared + plugin-declared).""" + return dict(_specs) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE Hook name constants +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +class H: + """CE Hook name constants. Plugin-owned Hook constants should be defined within the plugin package.""" + + # @hookable("add") — AddHandler.handle_add_memories + ADD_BEFORE = "add.before" + ADD_AFTER = "add.after" + + # @hookable("search") — SearchHandler.handle_search_memories + SEARCH_BEFORE = "search.before" + SEARCH_AFTER = "search.after" + + # Custom Hook (manually triggered via trigger_hook) + ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE custom Hook declarations (@hookable-generated ones need not be declared here) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +define_hook( + H.ADD_MEMORIES_POST_PROCESS, + description="Post-process result after add_memories returns, before constructing Response", + params=["request", "result"], + pipe_key="result", +) diff --git a/src/memos/plugins/hooks.py b/src/memos/plugins/hooks.py new file mode 100644 index 000000000..eda98f98a --- /dev/null +++ b/src/memos/plugins/hooks.py @@ -0,0 +1,124 @@ +"""Hook runtime — registration, triggering, and @hookable decorator.""" + +from __future__ import annotations + +import asyncio +import logging + +from collections import defaultdict +from functools import wraps +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + + +logger = logging.getLogger(__name__) + +_hooks: dict[str, list[Callable]] = defaultdict(list) + + +def register_hook(name: str, callback: Callable) -> None: + """Register a hook callback. Undeclared hook names will log a warning.""" + from memos.plugins.hook_defs import get_hook_spec + + if get_hook_spec(name) is None: + logger.warning( + "Registering callback for undeclared hook: %s (callback=%s)", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + _hooks[name].append(callback) + logger.debug( + "Hook registered: %s -> %s", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + + +def register_hooks(names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple hook points.""" + for name in names: + register_hook(name, callback) + + +def trigger_hook(name: str, **kwargs: Any) -> Any: + """Trigger a hook, invoking all registered callbacks in order. + + - Zero overhead when no callbacks are registered + - Undeclared hook names will log a warning and be skipped + - pipe_key is auto-fetched from HookSpec, supports piped return value passing + """ + from memos.plugins.hook_defs import get_hook_spec + + spec = get_hook_spec(name) + if spec is None: + logger.warning("Undeclared hook triggered: %s — ignored", name) + return None + + pipe_key = spec.pipe_key + + for cb in _hooks.get(name, []): + try: + rv = cb(**kwargs) + if pipe_key is not None and rv is not None: + kwargs[pipe_key] = rv + except Exception: + logger.exception( + "Hook %s callback %s failed", + name, + getattr(cb, "__qualname__", repr(cb)), + ) + + return kwargs.get(pipe_key) if pipe_key else None + + +def hookable(name: str): + """Decorator: automatically triggers name.before / name.after hook before and after the method. + + Auto-declares before/after Hooks (idempotent); no need to manually define_hook in hook_defs.py. + Supports piped return values: before can modify request, after can modify result. + Compatible with both sync and async methods. + """ + from memos.plugins.hook_defs import define_hook + + define_hook( + f"{name}.before", + description=f"Before {name} executes; can modify request", + params=["request"], + pipe_key="request", + ) + define_hook( + f"{name}.after", + description=f"After {name} executes; can modify result", + params=["request", "result"], + pipe_key="result", + ) + + def decorator(func): + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = await func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return async_wrapper + + @wraps(func) + def sync_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return sync_wrapper + + return decorator diff --git a/src/memos/plugins/manager.py b/src/memos/plugins/manager.py new file mode 100644 index 000000000..3706855a9 --- /dev/null +++ b/src/memos/plugins/manager.py @@ -0,0 +1,75 @@ +"""Plugin manager — discover, load, and manage MemOS plugins.""" + +from __future__ import annotations + +import importlib.metadata +import logging + +from typing import TYPE_CHECKING + +from memos.plugins.base import MemOSPlugin + + +if TYPE_CHECKING: + from fastapi import FastAPI + +logger = logging.getLogger(__name__) + +ENTRY_POINT_GROUP = "memos.plugins" + + +class PluginManager: + """Discover, load, and manage MemOS plugins.""" + + def __init__(self): + self._plugins: dict[str, MemOSPlugin] = {} + + @property + def plugins(self) -> dict[str, MemOSPlugin]: + return dict(self._plugins) + + def discover(self) -> None: + """Discover and load all installed plugins via entry_points.""" + try: + eps = importlib.metadata.entry_points() + if hasattr(eps, "select"): + plugin_eps = eps.select(group=ENTRY_POINT_GROUP) + else: + plugin_eps = eps.get(ENTRY_POINT_GROUP, []) + except Exception: + logger.exception("Failed to query entry_points") + return + + for ep in plugin_eps: + try: + plugin_cls = ep.load() + plugin = plugin_cls() + if not isinstance(plugin, MemOSPlugin): + logger.warning("Plugin %s does not extend MemOSPlugin, skipped", ep.name) + continue + plugin.on_load() + self._plugins[plugin.name] = plugin + logger.info("Plugin discovered: %s v%s", plugin.name, plugin.version) + except Exception: + logger.exception("Failed to load plugin: %s", ep.name) + + def init_app(self, app: FastAPI) -> None: + """Bind app and initialize all loaded plugins.""" + for plugin in self._plugins.values(): + try: + plugin._bind_app(app) + plugin.init_app() + logger.info("Plugin initialized: %s", plugin.name) + except Exception: + logger.exception("Failed to init plugin: %s", plugin.name) + + def shutdown(self) -> None: + """Shut down all plugins and release resources.""" + for plugin in self._plugins.values(): + try: + plugin.on_shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) + + +plugin_manager = PluginManager() diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/conftest.py b/tests/plugins/conftest.py new file mode 100644 index 000000000..6a1a16b68 --- /dev/null +++ b/tests/plugins/conftest.py @@ -0,0 +1,17 @@ +"""Ensure @hookable-generated hooks are declared for core framework tests. + +In production, @hookable("add") runs at import time of add_handler.py, +declaring add.before / add.after. Core framework tests don't import handler +modules (to avoid heavy dependencies), so we trigger declarations here. + +Plugin-specific hooks are declared in each plugin's own tests/conftest.py. +""" + +from memos.plugins.hooks import hookable + + +hookable("add") +hookable("search") +hookable("chat") +hookable("feedback") +hookable("memory.get") diff --git a/tests/plugins/run_plugin_server.py b/tests/plugins/run_plugin_server.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/test_plugin_demo.py b/tests/plugins/test_plugin_demo.py new file mode 100644 index 000000000..51997cba4 --- /dev/null +++ b/tests/plugins/test_plugin_demo.py @@ -0,0 +1,439 @@ +""" +Plugin system core framework tests. + +Covers generic capabilities of the memos.plugins package (independent of specific plugin implementations): +1. Hook declaration registry (hook_defs) +2. Hook registration and triggering / pipe_key pipeline return value +3. @hookable decorator (sync + async + auto-declaration + pipeline return value) +4. MemOSPlugin base class register_* methods + +Plugin-specific functional tests are located in each plugin package: + extensions/memos_demo_plugin/tests/ +""" + +import asyncio +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +# ========================================================================= # +# 1. Hook declaration registry (hook_defs) +# ========================================================================= # + + +class TestHookDefs: + def test_define_hook_and_get_spec(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook( + "test.custom.hook", + description="test hook", + params=["request", "result"], + pipe_key="result", + ) + + spec = get_hook_spec("test.custom.hook") + assert spec is not None + assert spec.name == "test.custom.hook" + assert spec.params == ["request", "result"] + assert spec.pipe_key == "result" + + def test_define_hook_is_idempotent(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook("test.idempotent", description="first", params=["a"], pipe_key="a") + define_hook("test.idempotent", description="second", params=["b"], pipe_key="b") + + spec = get_hook_spec("test.idempotent") + assert spec.description == "first" + + def test_get_hook_spec_returns_none_for_unknown(self): + from memos.plugins.hook_defs import get_hook_spec + + assert get_hook_spec("definitely.does.not.exist") is None + + def test_all_hook_specs_includes_custom(self): + from memos.plugins.hook_defs import H, all_hook_specs + + specs = all_hook_specs() + assert H.ADD_MEMORIES_POST_PROCESS in specs + + def test_h_constants(self): + from memos.plugins.hook_defs import H + + assert H.ADD_BEFORE == "add.before" + assert H.ADD_AFTER == "add.after" + assert H.SEARCH_BEFORE == "search.before" + assert H.SEARCH_AFTER == "search.after" + assert H.ADD_MEMORIES_POST_PROCESS == "add.memories.post_process" + + +# ========================================================================= # +# 2. Hook registration and triggering / pipe_key pipeline return value +# ========================================================================= # + + +class TestHookMechanism: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_and_trigger(self): + from memos.plugins.hooks import register_hook, trigger_hook + + captured = {} + + def my_callback(*, request, **kwargs): + captured["request"] = request + + register_hook("add.before", my_callback) + trigger_hook("add.before", request="test_request") + + assert captured["request"] == "test_request" + + def test_register_hooks_batch(self): + from memos.plugins.hooks import register_hooks, trigger_hook + + call_count = 0 + + def my_callback(**kwargs): + nonlocal call_count + call_count += 1 + + register_hooks(["add.before", "search.before"], my_callback) + trigger_hook("add.before") + trigger_hook("search.before") + + assert call_count == 2 + + def test_trigger_undeclared_hook_returns_none(self): + from memos.plugins.hooks import trigger_hook + + result = trigger_hook("nonexistent.undeclared.hook", request="anything") + assert result is None + + def test_hook_exception_does_not_propagate(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook("test.exception", description="test", params=["x"]) + + results = [] + + def bad_callback(**kwargs): + raise ValueError("intentional error") + + def good_callback(**kwargs): + results.append("ok") + + register_hook("test.exception", bad_callback) + register_hook("test.exception", good_callback) + trigger_hook("test.exception", x=1) + + assert results == ["ok"] + + def test_trigger_hook_pipe_key_returns_modified_value(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.pipe", + description="pipe test", + params=["request", "result"], + pipe_key="result", + ) + + def double_result(*, request, result, **kwargs): + return result * 2 + + register_hook("test.pipe", double_result) + rv = trigger_hook("test.pipe", request="req", result=5) + + assert rv == 10 + + def test_trigger_hook_pipe_key_chains_callbacks(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.chain", + description="chain test", + params=["result"], + pipe_key="result", + ) + + def add_one(*, result, **kwargs): + return result + 1 + + def add_ten(*, result, **kwargs): + return result + 10 + + register_hook("test.chain", add_one) + register_hook("test.chain", add_ten) + + rv = trigger_hook("test.chain", result=0) + assert rv == 11 + + def test_trigger_hook_pipe_key_none_callback_no_modify(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.noop", + description="noop test", + params=["result"], + pipe_key="result", + ) + + def noop(*, result, **kwargs): + return None # explicitly return None — should not modify + + register_hook("test.noop", noop) + rv = trigger_hook("test.noop", result="original") + + assert rv == "original" + + def test_trigger_hook_notification_mode(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.notify", + description="notification test", + params=["data"], + pipe_key=None, + ) + + captured = [] + + def observer(*, data, **kwargs): + captured.append(data) + + register_hook("test.notify", observer) + rv = trigger_hook("test.notify", data="hello") + + assert rv is None + assert captured == ["hello"] + + +# ========================================================================= # +# 3. @hookable decorator +# ========================================================================= # + + +class TestHookableDecorator: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_hookable_auto_declares_specs(self): + from memos.plugins.hook_defs import get_hook_spec + from memos.plugins.hooks import hookable + + @hookable("auto_test") + def dummy(self, request): + return request + + before_spec = get_hook_spec("auto_test.before") + after_spec = get_hook_spec("auto_test.after") + + assert before_spec is not None + assert before_spec.pipe_key == "request" + assert after_spec is not None + assert after_spec.pipe_key == "result" + + def test_hookable_sync(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append(("before", request)) + + def on_after(*, request, result, **kwargs): + events.append(("after", result)) + + register_hook("sync_demo.before", on_before) + register_hook("sync_demo.after", on_after) + + class FakeHandler: + @hookable("sync_demo") + def do_work(self, request): + return f"processed:{request}" + + result = FakeHandler().do_work("my_input") + + assert result == "processed:my_input" + assert events == [("before", "my_input"), ("after", "processed:my_input")] + + def test_hookable_async(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append("before") + + def on_after(*, request, result, **kwargs): + events.append("after") + + register_hook("async_demo.before", on_before) + register_hook("async_demo.after", on_after) + + class FakeHandler: + @hookable("async_demo") + async def do_work(self, request): + return "async_result" + + result = asyncio.get_event_loop().run_until_complete(FakeHandler().do_work("req")) + + assert result == "async_result" + assert events == ["before", "after"] + + def test_hookable_before_can_modify_request(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_request(*, request, **kwargs): + return "modified_request" + + register_hook("modify_req.before", rewrite_request) + + class FakeHandler: + @hookable("modify_req") + def do_work(self, request): + return f"got:{request}" + + result = FakeHandler().do_work("original") + assert result == "got:modified_request" + + def test_hookable_after_can_modify_result(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_result(*, request, result, **kwargs): + return f"{result}+modified" + + register_hook("modify_res.after", rewrite_result) + + class FakeHandler: + @hookable("modify_res") + def do_work(self, request): + return "original_result" + + result = FakeHandler().do_work("req") + assert result == "original_result+modified" + + def test_hookable_falsy_return_preserved(self): + """ensure empty list / 0 / empty string are not treated as None""" + from memos.plugins.hooks import hookable, register_hook + + def return_empty_list(*, request, result, **kwargs): + return [] + + register_hook("falsy_test.after", return_empty_list) + + class FakeHandler: + @hookable("falsy_test") + def do_work(self, request): + return [1, 2, 3] + + result = FakeHandler().do_work("req") + assert result == [] + + +# ========================================================================= # +# 4. Base class register_* methods +# ========================================================================= # + + +class TestBaseClassRegisterMethods: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_router(self): + from fastapi import APIRouter + + from memos.plugins.base import MemOSPlugin + + app = FastAPI() + plugin = MemOSPlugin() + plugin._bind_app(app) + + router = APIRouter(prefix="/test") + + @router.get("/ping") + async def ping(): + return {"pong": True} + + plugin.register_router(router) + + paths = [r.path for r in app.routes] + assert "/test/ping" in paths + + def test_register_middleware(self): + from starlette.middleware.base import BaseHTTPMiddleware + + from memos.plugins.base import MemOSPlugin + + class NoopMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + return await call_next(request) + + app = FastAPI() + + @app.get("/x") + async def x(): + return {} + + plugin = MemOSPlugin() + plugin._bind_app(app) + plugin.register_middleware(NoopMiddleware) + + client = TestClient(app) + resp = client.get("/x") + assert resp.status_code == 200 + + def test_register_hook(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("test.reg.event", description="test", params=["x"]) + + called = [] + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hook("test.reg.event", lambda **kw: called.append(True)) + + trigger_hook("test.reg.event", x=1) + assert called == [True] + + def test_register_hooks_batch(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("batch.a", description="a", params=["x"]) + define_hook("batch.b", description="b", params=["x"]) + + count = 0 + + def cb(**kw): + nonlocal count + count += 1 + + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hooks(["batch.a", "batch.b"], cb) + + trigger_hook("batch.a", x=1) + trigger_hook("batch.b", x=2) + assert count == 2 From b3bab30da6d818b628ed50be466d2c86892efd2b Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 22:06:40 +0800 Subject: [PATCH 2/4] feat: add hookable --- src/memos/api/handlers/add_handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 3cdbedabf..1bd52c108 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,6 +15,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hooks import hookable from memos.types import MessageList @@ -37,6 +38,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) + @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. From fb38bd1a58d330860ac3e59e655711f6277a6888 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Thu, 5 Mar 2026 14:28:14 +0800 Subject: [PATCH 3/4] ci: test --- tests/plugins/test_plugin_demo.py | 439 ++++++++++++++++++++++++++++++ 1 file changed, 439 insertions(+) create mode 100644 tests/plugins/test_plugin_demo.py diff --git a/tests/plugins/test_plugin_demo.py b/tests/plugins/test_plugin_demo.py new file mode 100644 index 000000000..77ea8dfce --- /dev/null +++ b/tests/plugins/test_plugin_demo.py @@ -0,0 +1,439 @@ +""" +Plugin system core framework tests. + +Covers generic capabilities of the memos.plugins package (independent of specific plugin implementations): +1. Hook declaration registry (hook_defs) +2. Hook registration and triggering / pipe_key pipeline return value +3. @hookable decorator (sync + async + auto-declaration + pipeline return value) +4. MemOSPlugin base class register_* methods + +Plugin-specific functional tests are located in each plugin package: + extensions/memos_demo_plugin/tests/ +""" + +import asyncio +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +# ========================================================================= # +# 1. Hook declaration registry (hook_defs) +# ========================================================================= # + + +class TestHookDefs: + def test_define_hook_and_get_spec(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook( + "test.custom.hook", + description="test hook", + params=["request", "result"], + pipe_key="result", + ) + + spec = get_hook_spec("test.custom.hook") + assert spec is not None + assert spec.name == "test.custom.hook" + assert spec.params == ["request", "result"] + assert spec.pipe_key == "result" + + def test_define_hook_is_idempotent(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook("test.idempotent", description="first", params=["a"], pipe_key="a") + define_hook("test.idempotent", description="second", params=["b"], pipe_key="b") + + spec = get_hook_spec("test.idempotent") + assert spec.description == "first" + + def test_get_hook_spec_returns_none_for_unknown(self): + from memos.plugins.hook_defs import get_hook_spec + + assert get_hook_spec("definitely.does.not.exist") is None + + def test_all_hook_specs_includes_custom(self): + from memos.plugins.hook_defs import H, all_hook_specs + + specs = all_hook_specs() + assert H.ADD_MEMORIES_POST_PROCESS in specs + + def test_h_constants(self): + from memos.plugins.hook_defs import H + + assert H.ADD_BEFORE == "add.before" + assert H.ADD_AFTER == "add.after" + assert H.SEARCH_BEFORE == "search.before" + assert H.SEARCH_AFTER == "search.after" + assert H.ADD_MEMORIES_POST_PROCESS == "add.memories.post_process" + + +# ========================================================================= # +# 2. Hook registration and triggering / pipe_key pipeline return value +# ========================================================================= # + + +class TestHookMechanism: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_and_trigger(self): + from memos.plugins.hooks import register_hook, trigger_hook + + captured = {} + + def my_callback(*, request, **kwargs): + captured["request"] = request + + register_hook("add.before", my_callback) + trigger_hook("add.before", request="test_request") + + assert captured["request"] == "test_request" + + def test_register_hooks_batch(self): + from memos.plugins.hooks import register_hooks, trigger_hook + + call_count = 0 + + def my_callback(**kwargs): + nonlocal call_count + call_count += 1 + + register_hooks(["add.before", "search.before"], my_callback) + trigger_hook("add.before") + trigger_hook("search.before") + + assert call_count == 2 + + def test_trigger_undeclared_hook_returns_none(self): + from memos.plugins.hooks import trigger_hook + + result = trigger_hook("nonexistent.undeclared.hook", request="anything") + assert result is None + + def test_hook_exception_does_not_propagate(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook("test.exception", description="test", params=["x"]) + + results = [] + + def bad_callback(**kwargs): + raise ValueError("intentional error") + + def good_callback(**kwargs): + results.append("ok") + + register_hook("test.exception", bad_callback) + register_hook("test.exception", good_callback) + trigger_hook("test.exception", x=1) + + assert results == ["ok"] + + def test_trigger_hook_pipe_key_returns_modified_value(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.pipe", + description="pipe test", + params=["request", "result"], + pipe_key="result", + ) + + def double_result(*, request, result, **kwargs): + return result * 2 + + register_hook("test.pipe", double_result) + rv = trigger_hook("test.pipe", request="req", result=5) + + assert rv == 10 + + def test_trigger_hook_pipe_key_chains_callbacks(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.chain", + description="chain test", + params=["result"], + pipe_key="result", + ) + + def add_one(*, result, **kwargs): + return result + 1 + + def add_ten(*, result, **kwargs): + return result + 10 + + register_hook("test.chain", add_one) + register_hook("test.chain", add_ten) + + rv = trigger_hook("test.chain", result=0) + assert rv == 11 + + def test_trigger_hook_pipe_key_none_callback_no_modify(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.noop", + description="noop test", + params=["result"], + pipe_key="result", + ) + + def noop(*, result, **kwargs): + return None # explicitly return None — should not modify + + register_hook("test.noop", noop) + rv = trigger_hook("test.noop", result="original") + + assert rv == "original" + + def test_trigger_hook_notification_mode(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.notify", + description="notification test", + params=["data"], + pipe_key=None, + ) + + captured = [] + + def observer(*, data, **kwargs): + captured.append(data) + + register_hook("test.notify", observer) + rv = trigger_hook("test.notify", data="hello") + + assert rv is None + assert captured == ["hello"] + + +# ========================================================================= # +# 3. @hookable decorator +# ========================================================================= # + + +class TestHookableDecorator: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_hookable_auto_declares_specs(self): + from memos.plugins.hook_defs import get_hook_spec + from memos.plugins.hooks import hookable + + @hookable("auto_test") + def dummy(self, request): + return request + + before_spec = get_hook_spec("auto_test.before") + after_spec = get_hook_spec("auto_test.after") + + assert before_spec is not None + assert before_spec.pipe_key == "request" + assert after_spec is not None + assert after_spec.pipe_key == "result" + + def test_hookable_sync(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append(("before", request)) + + def on_after(*, request, result, **kwargs): + events.append(("after", result)) + + register_hook("sync_demo.before", on_before) + register_hook("sync_demo.after", on_after) + + class FakeHandler: + @hookable("sync_demo") + def do_work(self, request): + return f"processed:{request}" + + result = FakeHandler().do_work("my_input") + + assert result == "processed:my_input" + assert events == [("before", "my_input"), ("after", "processed:my_input")] + + def test_hookable_async(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append("before") + + def on_after(*, request, result, **kwargs): + events.append("after") + + register_hook("async_demo.before", on_before) + register_hook("async_demo.after", on_after) + + class FakeHandler: + @hookable("async_demo") + async def do_work(self, request): + return "async_result" + + result = asyncio.run(FakeHandler().do_work("req")) + + assert result == "async_result" + assert events == ["before", "after"] + + def test_hookable_before_can_modify_request(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_request(*, request, **kwargs): + return "modified_request" + + register_hook("modify_req.before", rewrite_request) + + class FakeHandler: + @hookable("modify_req") + def do_work(self, request): + return f"got:{request}" + + result = FakeHandler().do_work("original") + assert result == "got:modified_request" + + def test_hookable_after_can_modify_result(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_result(*, request, result, **kwargs): + return f"{result}+modified" + + register_hook("modify_res.after", rewrite_result) + + class FakeHandler: + @hookable("modify_res") + def do_work(self, request): + return "original_result" + + result = FakeHandler().do_work("req") + assert result == "original_result+modified" + + def test_hookable_falsy_return_preserved(self): + """ensure empty list / 0 / empty string are not treated as None""" + from memos.plugins.hooks import hookable, register_hook + + def return_empty_list(*, request, result, **kwargs): + return [] + + register_hook("falsy_test.after", return_empty_list) + + class FakeHandler: + @hookable("falsy_test") + def do_work(self, request): + return [1, 2, 3] + + result = FakeHandler().do_work("req") + assert result == [] + + +# ========================================================================= # +# 4. Base class register_* methods +# ========================================================================= # + + +class TestBaseClassRegisterMethods: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_router(self): + from fastapi import APIRouter + + from memos.plugins.base import MemOSPlugin + + app = FastAPI() + plugin = MemOSPlugin() + plugin._bind_app(app) + + router = APIRouter(prefix="/test") + + @router.get("/ping") + async def ping(): + return {"pong": True} + + plugin.register_router(router) + + paths = [r.path for r in app.routes] + assert "/test/ping" in paths + + def test_register_middleware(self): + from starlette.middleware.base import BaseHTTPMiddleware + + from memos.plugins.base import MemOSPlugin + + class NoopMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + return await call_next(request) + + app = FastAPI() + + @app.get("/x") + async def x(): + return {} + + plugin = MemOSPlugin() + plugin._bind_app(app) + plugin.register_middleware(NoopMiddleware) + + client = TestClient(app) + resp = client.get("/x") + assert resp.status_code == 200 + + def test_register_hook(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("test.reg.event", description="test", params=["x"]) + + called = [] + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hook("test.reg.event", lambda **kw: called.append(True)) + + trigger_hook("test.reg.event", x=1) + assert called == [True] + + def test_register_hooks_batch(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("batch.a", description="a", params=["x"]) + define_hook("batch.b", description="b", params=["x"]) + + count = 0 + + def cb(**kw): + nonlocal count + count += 1 + + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hooks(["batch.a", "batch.b"], cb) + + trigger_hook("batch.a", x=1) + trigger_hook("batch.b", x=2) + assert count == 2 From 212aa6a99a0c608f2f13d06ce587e3489e63b7d2 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Thu, 5 Mar 2026 14:34:18 +0800 Subject: [PATCH 4/4] feat: delete add_memories hookable --- src/memos/api/handlers/add_handler.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 3a04bf6a1..3cdbedabf 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,7 +15,6 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView -from memos.plugins.hooks import hookable from memos.types import MessageList @@ -38,7 +37,6 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) - @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. @@ -108,9 +106,6 @@ def _check_messages(messages: MessageList) -> None: results = cube_view.add_memories(add_req) - rv = trigger_hook(H.ADD_MEMORIES_POST_PROCESS, request=add_req, result=results) - results = rv if rv is not None else results - self.logger.info(f"[AddHandler] Final add results count={len(results)}") return MemoryResponse(