Skip to content
15 changes: 7 additions & 8 deletions examples/mem_scheduler/memos_w_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import re
import shutil
import sys

from datetime import datetime
from pathlib import Path
from queue import Queue

from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
from datetime import datetime
import re

from memos.configs.mem_scheduler import AuthConfig
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_os.main import MOS
from memos.mem_scheduler.general_scheduler import GeneralScheduler
from memos.mem_scheduler.schemas.general_schemas import (
QUERY_LABEL,
ANSWER_LABEL,
ADD_LABEL,
ANSWER_LABEL,
MEM_ARCHIVE_LABEL,
MEM_ORGANIZE_LABEL,
MEM_UPDATE_LABEL,
MEM_ARCHIVE_LABEL,
NOT_APPLICABLE_TYPE,
QUERY_LABEL,
)
from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem
from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
from memos.mem_scheduler.general_scheduler import GeneralScheduler


FILE_PATH = Path(__file__).absolute()
Expand Down
446 changes: 386 additions & 60 deletions src/memos/api/handlers/chat_handler.py

Large diffs are not rendered by default.

39 changes: 38 additions & 1 deletion src/memos/api/handlers/component_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from memos.api.config import APIConfig
from memos.api.handlers.config_builders import (
build_chat_llm_config,
build_embedder_config,
build_graph_db_config,
build_internet_retriever_config,
Expand Down Expand Up @@ -77,6 +78,38 @@ def _get_default_memory_size(cube_config: Any) -> dict[str, int]:
}


def _init_chat_llms(chat_llm_configs: list[dict]) -> dict[str, Any]:
"""
Initialize chat language models from configuration.

Args:
chat_llm_configs: List of chat LLM configuration dictionaries

Returns:
Dictionary mapping model names to initialized LLM instances
"""

def _list_models(client):
try:
models = (
[model.id for model in client.models.list().data]
if client.models.list().data
else client.models.list().models
)
except Exception as e:
logger.error(f"Error listing models: {e}")
models = []
return models

model_name_instrance_maping = {}
for cfg in chat_llm_configs:
llm = LLMFactory.from_config(cfg["config_class"])
if cfg["support_models"]:
for model_name in cfg["support_models"]:
model_name_instrance_maping[model_name] = llm
return model_name_instrance_maping


def init_server() -> dict[str, Any]:
"""
Initialize all server components and configurations.
Expand Down Expand Up @@ -104,6 +137,7 @@ def init_server() -> dict[str, Any]:
# Build component configurations
graph_db_config = build_graph_db_config()
llm_config = build_llm_config()
chat_llm_config = build_chat_llm_config()
embedder_config = build_embedder_config()
mem_reader_config = build_mem_reader_config()
reranker_config = build_reranker_config()
Expand All @@ -123,13 +157,16 @@ def init_server() -> dict[str, Any]:
else None
)
llm = LLMFactory.from_config(llm_config)
chat_llms = _init_chat_llms(chat_llm_config)
embedder = EmbedderFactory.from_config(embedder_config)
mem_reader = MemReaderFactory.from_config(mem_reader_config)
reranker = RerankerFactory.from_config(reranker_config)
internet_retriever = InternetRetrieverFactory.from_config(
internet_retriever_config, embedder=embedder
)

# Initialize chat llms

logger.debug("Core components instantiated")

# Initialize memory manager
Expand Down Expand Up @@ -234,7 +271,6 @@ def init_server() -> dict[str, Any]:
tree_mem: TreeTextMemory = naive_mem_cube.text_mem
searcher: Searcher = tree_mem.get_searcher(
manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
moscube=False,
)
logger.debug("Searcher created")

Expand Down Expand Up @@ -276,6 +312,7 @@ def init_server() -> dict[str, Any]:
"graph_db": graph_db,
"mem_reader": mem_reader,
"llm": llm,
"chat_llms": chat_llms,
"embedder": embedder,
"reranker": reranker,
"internet_retriever": internet_retriever,
Expand Down
27 changes: 27 additions & 0 deletions src/memos/api/handlers/config_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
a configuration dictionary using the appropriate ConfigFactory.
"""

import json
import os

from typing import Any
Expand Down Expand Up @@ -81,6 +82,32 @@ def build_llm_config() -> dict[str, Any]:
)


def build_chat_llm_config() -> list[dict[str, Any]]:
    """
    Build chat LLM configurations from the ``CHAT_MODEL_LIST`` env variable.

    ``CHAT_MODEL_LIST`` is expected to hold a JSON array of config objects.
    Each object may carry a ``backend`` (default ``"openai"``), a
    ``support_models`` list, and arbitrary backend-specific keys that are
    forwarded verbatim as the LLM config.

    Returns:
        List of dictionaries, each with:
            - ``"config_class"``: validated ``LLMConfigFactory`` config
            - ``"support_models"``: list of model names served, or ``None``

    Raises:
        json.JSONDecodeError: If ``CHAT_MODEL_LIST`` is set but not valid JSON.
    """
    # Default to an empty JSON list so a missing env var yields no chat LLMs
    # instead of crashing with `TypeError: ... is not str` in json.loads(None).
    configs = json.loads(os.getenv("CHAT_MODEL_LIST", "[]"))

    chat_llm_configs: list[dict[str, Any]] = []
    for cfg in configs:
        if cfg:
            # Everything except routing metadata is backend-specific config.
            llm_config = {k: v for k, v in cfg.items() if k not in ("backend", "support_models")}
        else:
            # Empty entry: fall back to the server's default OpenAI config.
            llm_config = APIConfig.get_openai_config()
        chat_llm_configs.append(
            {
                "config_class": LLMConfigFactory.model_validate(
                    {
                        "backend": cfg.get("backend", "openai"),
                        "config": llm_config,
                    }
                ),
                "support_models": cfg.get("support_models"),
            }
        )
    return chat_llm_configs


def build_embedder_config() -> dict[str, Any]:
"""
Build embedder configuration.
Expand Down
43 changes: 42 additions & 1 deletion src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

from typing import Any, Literal

from memos.api.product_models import MemoryResponse
from memos.api.handlers.formatters_handler import format_memory_item
from memos.api.product_models import (
DeleteMemoryRequest,
DeleteMemoryResponse,
GetMemoryRequest,
GetMemoryResponse,
MemoryResponse,
)
from memos.log import get_logger
from memos.mem_os.utils.format_utils import (
convert_graph_to_tree_forworkmem,
Expand Down Expand Up @@ -149,3 +156,37 @@ def handle_get_subgraph(
except Exception as e:
logger.error(f"Failed to get subgraph: {e}", exc_info=True)
raise


def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse:
    """Fetch all text memories plus filtered preference memories for a cube."""
    # TODO: Implement get memory with filter
    text_nodes = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"]

    # Only include identifiers that were actually supplied in the request.
    pref_filter: dict[str, Any] = {
        key: value
        for key, value in (
            ("user_id", get_mem_req.user_id),
            ("mem_cube_id", get_mem_req.mem_cube_id),
        )
        if value is not None
    }
    pref_items = naive_mem_cube.pref_mem.get_memory_by_filter(pref_filter)

    return GetMemoryResponse(
        message="Memories retrieved successfully",
        data={
            "text_mem": text_nodes,
            "pref_mem": [format_memory_item(item) for item in pref_items],
        },
    )


def handle_delete_memories(
    delete_mem_req: DeleteMemoryRequest, naive_mem_cube: Any
) -> DeleteMemoryResponse:
    """
    Delete the requested memory ids from both the text and preference stores.

    Args:
        delete_mem_req: Request carrying the list of memory ids to delete.
        naive_mem_cube: Memory cube exposing ``text_mem`` and ``pref_mem``.

    Returns:
        DeleteMemoryResponse with a ``{"status": ...}`` payload. Deletion
        errors are logged and reported as a failure response rather than
        raised to the caller.
    """
    try:
        naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids)
        naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
    except Exception as e:
        logger.error(f"Failed to delete memories: {e}", exc_info=True)
        return DeleteMemoryResponse(
            message="Failed to delete memories",
            # Use the same dict shape as the success payload (previously a
            # bare "failure" string) so consumers can always read data["status"].
            data={"status": "failure"},
        )
    return DeleteMemoryResponse(
        message="Memories deleted successfully",
        data={"status": "success"},
    )
8 changes: 4 additions & 4 deletions src/memos/api/handlers/scheduler_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def handle_scheduler_status(
user_name: str | None = None,
mem_cube_id: str | None = None,
mem_scheduler: Any | None = None,
instance_id: str = "",
) -> dict[str, Any]:
Expand All @@ -43,17 +43,17 @@ def handle_scheduler_status(
HTTPException: If status retrieval fails
"""
try:
if user_name:
if mem_cube_id:
running = mem_scheduler.dispatcher.get_running_tasks(
lambda task: getattr(task, "mem_cube_id", None) == user_name
lambda task: getattr(task, "mem_cube_id", None) == mem_cube_id
)
tasks_iter = to_iter(running)
running_count = len(tasks_iter)
return {
"message": "ok",
"data": {
"scope": "user",
"user_name": user_name,
"mem_cube_id": mem_cube_id,
"running_tasks": running_count,
"timestamp": time.time(),
"instance_id": instance_id,
Expand Down
Loading
Loading