diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 7f61d54ac..1a3c328f1 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -15,6 +15,7 @@ from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig +from memos.context.context import ContextThread from memos.mem_cube.general import GeneralMemCube @@ -178,7 +179,6 @@ def start_watch_if_enabled(cls) -> None: if not enable: return interval = int(os.getenv("NACOS_WATCH_INTERVAL", "60")) - import threading def _loop() -> None: while True: @@ -188,7 +188,7 @@ def _loop() -> None: logger.error(f"❌ Nacos watch loop error: {e}") time.sleep(interval) - threading.Thread(target=_loop, daemon=True).start() + ContextThread(target=_loop, daemon=True).start() logger.info(f"Nacos watch thread started (interval={interval}s).") @classmethod diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index b3b457c36..028fe8e3f 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -12,6 +12,7 @@ from sqlalchemy.engine import Engine from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig +from memos.context.context import ContextThread from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube @@ -689,7 +690,7 @@ def start(self) -> None: logger.info("Message consumer process started") else: # Default to thread mode - self._consumer_thread = threading.Thread( + self._consumer_thread = ContextThread( target=self._message_consumer, daemon=True, name="MessageConsumerThread", diff --git a/src/memos/mem_scheduler/general_modules/task_threads.py b/src/memos/mem_scheduler/general_modules/task_threads.py index 551e8b726..73b570a8b 100644 --- a/src/memos/mem_scheduler/general_modules/task_threads.py +++ b/src/memos/mem_scheduler/general_modules/task_threads.py @@ -5,6 +5,7 @@ from concurrent.futures import as_completed from typing import Any, TypeVar +from memos.context.context import ContextThread from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule @@ -138,7 +139,7 @@ def worker(task_name: str, func: Callable, args: tuple): # Start all threads for task_name, (func, args) in tasks.items(): - thread = threading.Thread( + thread = ContextThread( target=worker, args=(task_name, func, args), name=f"task-{task_name}" ) threads[task_name] = thread @@ -283,7 +284,7 @@ def run_race( # Create and start threads for each task for task_name, task_func in tasks.items(): - thread = threading.Thread( + thread = ContextThread( target=self.worker, args=(task_func, task_name), name=f"race-{task_name}" ) self.threads[task_name] = thread diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 0ebb7da4f..46c4e2d49 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -4,7 +4,7 @@ from time import perf_counter from memos.configs.mem_scheduler import BaseSchedulerConfig -from memos.context.context import ContextThreadPoolExecutor +from memos.context.context import ContextThread, ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher @@ -340,7 +340,7 @@ def start(self) -> bool: return False self._running = True - self._monitor_thread = threading.Thread( + self._monitor_thread = ContextThread( target=self._monitor_loop, name="threadpool_monitor", daemon=True ) self._monitor_thread.start() diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index b240f4369..3c0dff907 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -6,6 +6,7 @@ from pathlib import Path from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig +from memos.context.context import ContextThread from memos.dependency import require_python_package from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule @@ -96,7 +97,7 @@ def initialize_rabbitmq( ) # Start IOLoop in dedicated thread - self._io_loop_thread = threading.Thread( + self._io_loop_thread = ContextThread( target=self.rabbitmq_connection.ioloop.start, daemon=True ) self._io_loop_thread.start() diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py index d86911e82..5439af9c6 100644 --- a/src/memos/mem_scheduler/webservice_modules/redis_service.py +++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py @@ -1,12 +1,12 @@ import asyncio import os import subprocess -import threading import time from collections.abc import Callable from typing import Any +from memos.context.context import ContextThread from memos.dependency import require_python_package from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule @@ -41,7 +41,7 @@ def __init__(self): self.query_list_capacity = 1000 self._redis_listener_running = False - self._redis_listener_thread: threading.Thread | None = None + self._redis_listener_thread: ContextThread | None = None self._redis_listener_loop: asyncio.AbstractEventLoop | None = None @property @@ -336,7 +336,7 @@ def redis_start_listening(self, handler: Callable | None = None): if handler is None: handler = self.redis_consume_message_stream - self._redis_listener_thread = threading.Thread( + self._redis_listener_thread = ContextThread( target=self._redis_run_listener_async, args=(handler,), daemon=True, diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py index 0337225d1..ea06a7c60 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py +++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py @@ -1,5 +1,4 @@ import json -import threading import time import traceback @@ -10,7 +9,7 @@ import numpy as np -from memos.context.context import ContextThreadPoolExecutor +from memos.context.context import ContextThread, ContextThreadPoolExecutor from memos.dependency import require_python_package from memos.embedders.factory import OllamaEmbedder from memos.graph_dbs.item import GraphDBEdge, GraphDBNode @@ -94,12 +93,12 @@ def __init__( self._reorganize_needed = True if self.is_reorganize: # ____ 1. For queue message driven thread ___________ - self.thread = threading.Thread(target=self._run_message_consumer_loop) + self.thread = ContextThread(target=self._run_message_consumer_loop) self.thread.start() # ____ 2. For periodic structure optimization _______ self._stop_scheduler = False self._is_optimizing = {"LongTermMemory": False, "UserMemory": False} - self.structure_optimizer_thread = threading.Thread( + self.structure_optimizer_thread = ContextThread( target=self._run_structure_organizer_loop ) self.structure_optimizer_thread.start()