diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index e255b1a48..2b481d5c6 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -653,6 +653,11 @@ def scheduler_status(user_name: str | None = None): cube = getattr(task, "mem_cube_id", "unknown") task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1 + try: + metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot() + except Exception: + metrics_snapshot = {} + return { "message": "ok", "data": { @@ -661,6 +666,7 @@ def scheduler_status(user_name: str | None = None): "task_count_per_user": task_count_per_user, "timestamp": time.time(), "instance_id": INSTANCE_ID, + "metrics": metrics_snapshot, }, } except Exception as err: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index c2f606146..b3b457c36 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -49,7 +49,6 @@ from memos.memories.activation.kv import KVCacheMemory from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -from memos.memos_tools.notification_utils import send_online_bot_notification from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE @@ -127,21 +126,6 @@ def __init__(self, config: BaseSchedulerConfig): "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS ) - # queue monitor (optional) - self._queue_monitor_thread: threading.Thread | None = None - self._queue_monitor_running: bool = False - self.queue_monitor_interval_seconds: float = self.config.get( - "queue_monitor_interval_seconds", 60.0 - ) - self.queue_monitor_warn_utilization: float = self.config.get( - "queue_monitor_warn_utilization", 0.7 - ) - self.queue_monitor_crit_utilization: float = self.config.get( - "queue_monitor_crit_utilization", 0.9 - ) - 
self.enable_queue_monitor: bool = self.config.get("enable_queue_monitor", False) - self._online_bot_callable = None # type: ignore[var-annotated] - # other attributes self._context_lock = threading.Lock() self.current_user_id: UserID | str | None = None @@ -541,6 +525,10 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt logger.error(error_msg) raise TypeError(error_msg) + if getattr(message, "timestamp", None) is None: + with contextlib.suppress(Exception): + message.timestamp = datetime.utcnow() + if self.disable_handlers and message.label in self.disable_handlers: logger.info(f"Skipping disabled handler: {message.label} - {message.content}") continue @@ -555,6 +543,9 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt logger.info( f"Submitted message to local queue: {message.label} - {message.content}" ) + with contextlib.suppress(Exception): + if messages: + self.dispatcher.on_messages_enqueued(messages) def _submit_web_logs( self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem] @@ -706,13 +697,6 @@ def start(self) -> None: self._consumer_thread.start() logger.info("Message consumer thread started") - # optionally start queue monitor if enabled and bot callable present - if self.enable_queue_monitor and self._online_bot_callable is not None: - try: - self.start_queue_monitor(self._online_bot_callable) - except Exception as e: - logger.warning(f"Failed to start queue monitor: {e}") - def stop(self) -> None: """Stop all scheduler components gracefully. 
@@ -762,9 +746,6 @@ def stop(self) -> None: self._cleanup_queues() logger.info("Memory Scheduler stopped completely") - # Stop queue monitor - self.stop_queue_monitor() - @property def handlers(self) -> dict[str, Callable]: """ @@ -997,16 +978,6 @@ def _fmt_eta(seconds: float | None) -> str: return True - # ---------------- Queue monitor & notifications ---------------- - def set_notification_bots(self, online_bot=None): - """ - Set external notification callables. - - Args: - online_bot: a callable matching dinding_report_bot.online_bot signature - """ - self._online_bot_callable = online_bot - def _gather_queue_stats(self) -> dict: """Collect queue/dispatcher stats for reporting.""" stats: dict[str, int | float | str] = {} @@ -1044,71 +1015,3 @@ def _gather_queue_stats(self) -> dict: except Exception: stats.update({"running": 0, "inflight": 0, "handlers": 0}) return stats - - def _queue_monitor_loop(self, online_bot) -> None: - logger.info(f"Queue monitor started (interval={self.queue_monitor_interval_seconds}s)") - self._queue_monitor_running = True - while self._queue_monitor_running: - time.sleep(self.queue_monitor_interval_seconds) - try: - stats = self._gather_queue_stats() - # decide severity based on utilization if local queue - title_color = "#00956D" - subtitle = "Scheduler" - if not stats.get("use_redis_queue"): - util = float(stats.get("utilization", 0.0)) - if util >= self.queue_monitor_crit_utilization: - title_color = "#C62828" # red - subtitle = "Scheduler (CRITICAL)" - elif util >= self.queue_monitor_warn_utilization: - title_color = "#E65100" # orange - subtitle = "Scheduler (WARNING)" - - other_data1 = { - "use_redis_queue": stats.get("use_redis_queue"), - "handlers": stats.get("handlers"), - "running": stats.get("running"), - "inflight": stats.get("inflight"), - } - if not stats.get("use_redis_queue"): - other_data2 = { - "qsize": stats.get("qsize"), - "unfinished_tasks": stats.get("unfinished_tasks"), - "maxsize": stats.get("maxsize"), - 
"utilization": f"{float(stats.get('utilization', 0.0)):.2%}", - } - else: - other_data2 = { - "redis_mode": True, - } - - send_online_bot_notification( - online_bot=online_bot, - header_name="Scheduler Queue", - sub_title_name=subtitle, - title_color=title_color, - other_data1=other_data1, - other_data2=other_data2, - emoji={"Runtime": "🧠", "Queue": "πŸ“¬"}, - ) - except Exception as e: - logger.warning(f"Queue monitor iteration failed: {e}") - logger.info("Queue monitor stopped") - - def start_queue_monitor(self, online_bot) -> None: - if self._queue_monitor_thread and self._queue_monitor_thread.is_alive(): - return - self._online_bot_callable = online_bot - self._queue_monitor_thread = threading.Thread( - target=self._queue_monitor_loop, - args=(online_bot,), - daemon=True, - name="QueueMonitorThread", - ) - self._queue_monitor_thread.start() - - def stop_queue_monitor(self) -> None: - self._queue_monitor_running = False - if self._queue_monitor_thread and self._queue_monitor_thread.is_alive(): - with contextlib.suppress(Exception): - self._queue_monitor_thread.join(timeout=2.0) diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index 997b01302..c2407b9e6 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -1,8 +1,10 @@ import concurrent import threading +import time from collections import defaultdict from collections.abc import Callable +from datetime import timezone from typing import Any from memos.context.context import ContextThreadPoolExecutor @@ -11,6 +13,7 @@ from memos.mem_scheduler.general_modules.task_threads import ThreadManager from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem +from memos.mem_scheduler.utils.metrics import MetricsRegistry logger = get_logger(__name__) @@ -70,6 +73,19 @@ def __init__(self, 
max_workers=30, enable_parallel_dispatch=True, config=None): self._completed_tasks = [] self.completed_tasks_max_show_size = 10 + self.metrics = MetricsRegistry( + topk_per_label=(self.config or {}).get("metrics_topk_per_label", 50) + ) + + def on_messages_enqueued(self, msgs: list[ScheduleMessageItem]) -> None: + if not msgs: + return + now = time.time() + for m in msgs: + self.metrics.on_enqueue( + label=m.label, mem_cube_id=m.mem_cube_id, inst_rate=1.0, now=now + ) + def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): """ Create a wrapper around the handler to track task execution and capture results. @@ -84,9 +100,37 @@ def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): def wrapped_handler(messages: list[ScheduleMessageItem]): try: + # --- mark start: record queuing time(now - enqueue_ts)--- + now = time.time() + for m in messages: + enq_ts = getattr(m, "timestamp", None) + + # Path 1: epoch seconds (preferred) + if isinstance(enq_ts, int | float): + enq_epoch = float(enq_ts) + + # Path 2: datetime -> normalize to UTC epoch + elif hasattr(enq_ts, "timestamp"): + dt = enq_ts + if dt.tzinfo is None: + # treat naive as UTC to neutralize +8h skew + dt = dt.replace(tzinfo=timezone.utc) + enq_epoch = dt.timestamp() + else: + # fallback: treat as "just now" + enq_epoch = now + + wait_sec = max(0.0, now - enq_epoch) + self.metrics.on_start( + label=m.label, mem_cube_id=m.mem_cube_id, wait_sec=wait_sec, now=now + ) + # Execute the original handler result = handler(messages) + # --- mark done --- + for m in messages: + self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time()) # Mark task as completed and remove from tracking with self._task_lock: if task_item.item_id in self._running_tasks: @@ -100,6 +144,9 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): except Exception as e: # Mark task as failed and remove from tracking + for m in messages: + self.metrics.on_done(label=m.label, 
mem_cube_id=m.mem_cube_id, now=time.time()) + # Record completion metrics even though the handler failed with self._task_lock: if task_item.item_id in self._running_tasks: task_item.mark_failed(str(e)) diff --git a/src/memos/mem_scheduler/utils/metrics.py b/src/memos/mem_scheduler/utils/metrics.py new file mode 100644 index 000000000..5155c98b3 --- /dev/null +++ b/src/memos/mem_scheduler/utils/metrics.py @@ -0,0 +1,250 @@ +# metrics.py +from __future__ import annotations + +import threading +import time + +from dataclasses import dataclass, field + + +# ==== global window config ==== +WINDOW_SEC = 120 # 2 minutes sliding window + + +# ---------- O(1) EWMA ---------- +class Ewma: + """ + Time-decayed EWMA; the stored value decays exponentially (time constant tau) between updates. + """ + + __slots__ = ("alpha", "last_ts", "tau", "value") + + def __init__(self, alpha: float = 0.3, tau: float = WINDOW_SEC): + self.alpha = alpha + self.value = 0.0 + self.last_ts: float = time.time() + self.tau = max(1e-6, float(tau)) + + def _decay_to(self, now: float | None = None): + now = time.time() if now is None else now + dt = max(0.0, now - self.last_ts) + if dt <= 0: + return + from math import exp + + self.value *= exp(-dt / self.tau) + self.last_ts = now + + def update(self, instant: float, now: float | None = None): + self._decay_to(now) + self.value = self.alpha * instant + (1 - self.alpha) * self.value + + def value_at(self, now: float | None = None) -> float: + now = time.time() if now is None else now + dt = max(0.0, now - self.last_ts) + if dt <= 0: + return self.value + from math import exp + + return self.value * exp(-dt / self.tau) + + +# ---------- approximate P95 (reservoir sample) ---------- +class ReservoirP95: + __slots__ = ("_i", "buf", "k", "n", "window") + + def __init__(self, k: int = 512, window: float = WINDOW_SEC): + self.k = k + self.buf: list[tuple[float, float]] = [] # (value, ts) + self.n = 0 + self._i = 0 + self.window = float(window) + + def _gc(self, now: float): + win_start = now - self.window + self.buf = [p for p in self.buf if
p[1] >= win_start] + if self.buf: + self._i %= len(self.buf) + else: + self._i = 0 + + def add(self, x: float, now: float | None = None): + now = time.time() if now is None else now + self._gc(now) + self.n += 1 + if len(self.buf) < self.k: + self.buf.append((x, now)) + return + self.buf[self._i] = (x, now) + self._i = (self._i + 1) % self.k + + def p95(self, now: float | None = None) -> float: + now = time.time() if now is None else now + self._gc(now) + if not self.buf: + return 0.0 + arr = sorted(v for v, _ in self.buf) + idx = int(0.95 * (len(arr) - 1)) + return arr[idx] + + +# ---------- Space-Saving Top-K ---------- +class SpaceSaving: + """Space-Saving top-K counter: add(key) is O(1); topk() is O(K log K).""" + + def __init__(self, k: int = 100): + self.k = k + self.cnt: dict[str, int] = {} + + def add(self, key: str): + if key in self.cnt: + self.cnt[key] += 1 + return + if len(self.cnt) < self.k: + self.cnt[key] = 1 + return + victim = min(self.cnt, key=self.cnt.get) + self.cnt[key] = self.cnt.pop(victim) + 1 + + def topk(self) -> list[tuple[str, int]]: + return sorted(self.cnt.items(), key=lambda kv: kv[1], reverse=True) + + +@dataclass +class KeyStats: + backlog: int = 0 + lambda_ewma: Ewma = field(default_factory=lambda: Ewma(0.3, WINDOW_SEC)) + mu_ewma: Ewma = field(default_factory=lambda: Ewma(0.3, WINDOW_SEC)) + wait_p95: ReservoirP95 = field(default_factory=lambda: ReservoirP95(512, WINDOW_SEC)) + last_ts: float = field(default_factory=time.time) + # last event timestamps for rate estimation + last_enqueue_ts: float | None = None + last_done_ts: float | None = None + + def snapshot(self, now: float | None = None) -> dict: + now = time.time() if now is None else now + lam = self.lambda_ewma.value_at(now) + mu = self.mu_ewma.value_at(now) + delta = mu - lam + eta = float("inf") if delta <= 1e-9 else self.backlog / delta + return { + "backlog": self.backlog, + "lambda": round(lam, 3), + "mu": round(mu, 3), + "delta": round(delta, 3), + "eta_sec": None if eta == float("inf")
else round(eta, 1), + "wait_p95_sec": round(self.wait_p95.p95(now), 3), + } + + +class MetricsRegistry: + """ + Queue metrics registry: + - phase 1: keyed by label (always tracked) + - phase 2: keyed by label x mem_cube_id (Top-K cubes only) + - on_enqueue(label, mem_cube_id) + - on_start(label, mem_cube_id, wait_sec) + - on_done(label, mem_cube_id) + """ + + def __init__(self, topk_per_label: int = 50): + self._lock = threading.RLock() + self._label_stats: dict[str, KeyStats] = {} + self._label_topk: dict[str, SpaceSaving] = {} + self._detail_stats: dict[tuple[str, str], KeyStats] = {} + self._topk_per_label = topk_per_label + + # ---------- helpers ---------- + def _get_label(self, label: str) -> KeyStats: + if label not in self._label_stats: + self._label_stats[label] = KeyStats() + self._label_topk[label] = SpaceSaving(self._topk_per_label) + return self._label_stats[label] + + def _get_detail(self, label: str, mem_cube_id: str) -> KeyStats | None: + # Only create fine-grained keys for mem_cube_ids in the label's Top-K + ss = self._label_topk[label] + if mem_cube_id in ss.cnt or len(ss.cnt) < ss.k: + key = (label, mem_cube_id) + if key not in self._detail_stats: + self._detail_stats[key] = KeyStats() + return self._detail_stats[key] + return None + + # ---------- events ---------- + def on_enqueue( + self, label: str, mem_cube_id: str, inst_rate: float = 1.0, now: float | None = None + ): + with self._lock: + now = time.time() if now is None else now + ls = self._get_label(label) + # derive instantaneous arrival rate from inter-arrival time (events/sec) + prev_ts = ls.last_enqueue_ts + dt = (now - prev_ts) if prev_ts is not None else None + inst_rate = (1.0 / max(1e-3, dt)) if dt is not None else 0.0 # first sample: no spike + ls.last_enqueue_ts = now + ls.backlog += 1 + old_lam = ls.lambda_ewma.value_at(now) + ls.lambda_ewma.update(inst_rate, now) + new_lam = ls.lambda_ewma.value_at(now) + print( + f"[DEBUG enqueue] {label} backlog={ls.backlog} dt={dt if dt is not None else 'β€”'}s inst={inst_rate:.3f} Ξ» {old_lam:.3f}β†’{new_lam:.3f}" +
) + self._label_topk[label].add(mem_cube_id) + ds = self._get_detail(label, mem_cube_id) + if ds: + prev_ts_d = ds.last_enqueue_ts + dt_d = (now - prev_ts_d) if prev_ts_d is not None else None + inst_rate_d = (1.0 / max(1e-3, dt_d)) if dt_d is not None else 0.0 + ds.last_enqueue_ts = now + ds.backlog += 1 + ds.lambda_ewma.update(inst_rate_d, now) + + def on_start(self, label: str, mem_cube_id: str, wait_sec: float, now: float | None = None): + with self._lock: + now = time.time() if now is None else now + ls = self._get_label(label) + ls.wait_p95.add(wait_sec, now) + ds = self._detail_stats.get((label, mem_cube_id)) + if ds: + ds.wait_p95.add(wait_sec, now) + + def on_done( + self, label: str, mem_cube_id: str, inst_rate: float = 1.0, now: float | None = None + ): + with self._lock: + now = time.time() if now is None else now + ls = self._get_label(label) + # derive instantaneous service rate from inter-completion time (events/sec) + prev_ts = ls.last_done_ts + dt = (now - prev_ts) if prev_ts is not None else None + inst_rate = (1.0 / max(1e-3, dt)) if dt is not None else 0.0 + ls.last_done_ts = now + if ls.backlog > 0: + ls.backlog -= 1 + old_mu = ls.mu_ewma.value_at(now) + ls.mu_ewma.update(inst_rate, now) + new_mu = ls.mu_ewma.value_at(now) + print( + f"[DEBUG done] {label} backlog={ls.backlog} dt={dt if dt is not None else 'β€”'}s inst={inst_rate:.3f} ΞΌ {old_mu:.3f}β†’{new_mu:.3f}" + ) + ds = self._detail_stats.get((label, mem_cube_id)) + if ds: + prev_ts_d = ds.last_done_ts + dt_d = (now - prev_ts_d) if prev_ts_d is not None else None + inst_rate_d = (1.0 / max(1e-3, dt_d)) if dt_d is not None else 0.0 + ds.last_done_ts = now + if ds.backlog > 0: + ds.backlog -= 1 + ds.mu_ewma.update(inst_rate_d, now) + + # ---------- snapshots ---------- + def snapshot(self) -> dict: + with self._lock: + now = time.time() + by_label = {lbl: ks.snapshot(now) for lbl, ks in self._label_stats.items()} + heavy = {lbl: self._label_topk[lbl].topk() for lbl in self._label_topk} + 
details = {} + for (lbl, cube), ks in self._detail_stats.items(): + details.setdefault(lbl, {})[cube] = ks.snapshot(now) + return {"by_label": by_label, "heavy": heavy, "details": details}