Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/memos/mem_scheduler/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,19 +853,21 @@ def _submit_web_logs(
return

for message in messages:
try:
self._web_log_message_queue.put(message)
except Exception as e:
logger.warning(f"Failed to put message to web log queue: {e}", stack_info=True)

message_info = message.debug_info()
logger.debug(f"Submitted Scheduling log for web: {message_info}")
logger.info(f"[DIAGNOSTIC] base_scheduler._submit_web_logs: submitted {message_info}")

# Always call publish; the publisher now caches when offline and flushes after reconnect
logger.info(
f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message_info}"
)
self.rabbitmq_publish_message(message=message.to_dict())
logger.info(
"[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched "
"item_id=%s task_id=%s label=%s",
message.item_id,
message.task_id,
message.label,
)
logger.debug(
f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}"
)
Expand Down
26 changes: 16 additions & 10 deletions src/memos/mem_scheduler/general_modules/scheduler_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def create_autofilled_log_item(
user_id: str,
mem_cube_id: str,
mem_cube: GeneralMemCube,
item_id: str | None = None,
) -> ScheduleLogForWebItem:
if mem_cube is None:
logger.error(
Expand Down Expand Up @@ -94,16 +95,19 @@ def create_autofilled_log_item(
)
memory_capacities["parameter_memory_capacity"] = 1

log_message = ScheduleLogForWebItem(
user_id=user_id,
mem_cube_id=mem_cube_id,
label=label,
from_memory_type=from_memory_type,
to_memory_type=to_memory_type,
log_content=log_content,
current_memory_sizes=current_memory_sizes,
memory_capacities=memory_capacities,
)
log_kwargs = {
"user_id": user_id,
"mem_cube_id": mem_cube_id,
"label": label,
"from_memory_type": from_memory_type,
"to_memory_type": to_memory_type,
"log_content": log_content,
"current_memory_sizes": current_memory_sizes,
"memory_capacities": memory_capacities,
}
if item_id:
log_kwargs["item_id"] = item_id
log_message = ScheduleLogForWebItem(**log_kwargs)
return log_message

@log_exceptions(logger=logger)
Expand All @@ -120,6 +124,7 @@ def create_event_log(
memory_len: int,
memcube_name: str | None = None,
log_content: str | None = None,
item_id: str | None = None,
) -> ScheduleLogForWebItem:
item = self.create_autofilled_log_item(
log_content=log_content or "",
Expand All @@ -129,6 +134,7 @@ def create_event_log(
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
item_id=item_id,
)
item.memcube_log_content = memcube_log_content
item.metadata = metadata
Expand Down
12 changes: 12 additions & 0 deletions src/memos/mem_scheduler/general_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
metadata=[],
memory_len=1,
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.task_id = msg.task_id
self._submit_web_logs([event])
Expand Down Expand Up @@ -322,6 +323,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
metadata=[],
memory_len=1,
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.task_id = msg.task_id
self._submit_web_logs([event])
Expand Down Expand Up @@ -492,6 +494,7 @@ def send_add_log_messages_to_local_env(
metadata=add_meta_legacy,
memory_len=len(add_content_legacy),
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.task_id = msg.task_id
events.append(event)
Expand All @@ -507,6 +510,7 @@ def send_add_log_messages_to_local_env(
metadata=update_meta_legacy,
memory_len=len(update_content_legacy),
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.task_id = msg.task_id
events.append(event)
Expand Down Expand Up @@ -573,6 +577,7 @@ def send_add_log_messages_to_cloud_env(
metadata=None,
memory_len=len(kb_log_content),
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
event.task_id = msg.task_id
Expand Down Expand Up @@ -719,6 +724,7 @@ def _extract_fields(mem_item):
metadata=None,
memory_len=len(kb_log_content),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=message.item_id,
)
event.log_content = (
f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
Expand Down Expand Up @@ -788,6 +794,7 @@ def process_message(message: ScheduleMessageItem):
user_name=user_name,
custom_tags=info.get("custom_tags", None),
task_id=message.task_id,
item_id=message.item_id,
info=info,
)

Expand Down Expand Up @@ -815,6 +822,7 @@ def _process_memories_with_reader(
user_name: str,
custom_tags: list[str] | None = None,
task_id: str | None = None,
item_id: str | None = None,
info: dict | None = None,
) -> None:
logger.info(
Expand Down Expand Up @@ -934,6 +942,7 @@ def _process_memories_with_reader(
metadata=None,
memory_len=len(kb_log_content),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=item_id,
)
event.log_content = (
f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
Expand Down Expand Up @@ -979,6 +988,7 @@ def _process_memories_with_reader(
metadata=add_meta_legacy,
memory_len=len(add_content_legacy),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=item_id,
)
event.task_id = task_id
self._submit_web_logs([event])
Expand Down Expand Up @@ -1045,6 +1055,7 @@ def _process_memories_with_reader(
metadata=None,
memory_len=len(kb_log_content),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=item_id,
)
event.log_content = f"Knowledge Base Memory Update failed: {exc!s}"
event.task_id = task_id
Expand Down Expand Up @@ -1212,6 +1223,7 @@ def process_message(message: ScheduleMessageItem):
metadata=meta,
memory_len=len(keys),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=message.item_id,
)
self._submit_web_logs([event])

Expand Down
2 changes: 1 addition & 1 deletion src/memos/mem_scheduler/schemas/message_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def debug_info(self) -> dict[str, Any]:
"""Return structured debug information for logging purposes."""
return {
"content_preview:": self.log_content[:50],
"log_id": self.item_id,
"item_id": self.item_id,
"user_id": self.user_id,
"mem_cube_id": self.mem_cube_id,
"operation": f"{self.from_memory_type} → {self.to_memory_type}",
Expand Down
6 changes: 6 additions & 0 deletions src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def _maybe_emit_task_completion(
# messages in one batch can belong to different business task_ids; check each
task_ids = set()
task_id_to_doc_id = {}
task_id_to_item_id = {}

for msg in messages:
tid = getattr(msg, "task_id", None)
Expand All @@ -340,6 +341,8 @@ def _maybe_emit_task_completion(
sid = info.get("source_doc_id")
if sid:
task_id_to_doc_id[tid] = sid
if tid not in task_id_to_item_id:
task_id_to_item_id[tid] = msg.item_id

if not task_ids:
return
Expand All @@ -356,6 +359,7 @@ def _maybe_emit_task_completion(

for task_id in task_ids:
source_doc_id = task_id_to_doc_id.get(task_id)
event_item_id = task_id_to_item_id.get(task_id)
status_data = self.status_tracker.get_task_status_by_business_id(
business_task_id=task_id, user_id=user_id
)
Expand All @@ -369,6 +373,7 @@ def _maybe_emit_task_completion(
# (Although if status is 'completed', local error shouldn't happen theoretically,
# unless status update lags or is inconsistent. We trust status_tracker here.)
event = ScheduleLogForWebItem(
item_id=event_item_id,
task_id=task_id,
user_id=user_id,
mem_cube_id=mem_cube_id,
Expand All @@ -393,6 +398,7 @@ def _maybe_emit_task_completion(
error_msg = "Unknown error (check system logs)"

event = ScheduleLogForWebItem(
item_id=event_item_id,
task_id=task_id,
user_id=user_id,
mem_cube_id=mem_cube_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,15 @@ def rabbitmq_publish_message(self, message: dict):
logger.debug(f"Published message: {message}")
return True
except Exception as e:
logger.error(
"[DIAGNOSTIC] RabbitMQ publish error. label=%s item_id=%s exchange=%s "
"routing_key=%s error=%s",
label,
message.get("item_id"),
exchange_name,
routing_key,
e,
)
logger.error(f"Failed to publish message: {e}")
# Cache message for retry on next connection
self.rabbitmq_publish_cache.put(message)
Expand Down
41 changes: 9 additions & 32 deletions tests/mem_scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,44 +139,21 @@ def test_submit_web_logs(self):
},
)

# Empty the queue by consuming all elements
while not self.scheduler._web_log_message_queue.empty():
self.scheduler._web_log_message_queue.get()
self.scheduler.rabbitmq_config = MagicMock()
self.scheduler.rabbitmq_publish_message = MagicMock()

# Submit the log message
self.scheduler._submit_web_logs(messages=log_message)

# Verify the message was added to the queue
self.assertEqual(self.scheduler._web_log_message_queue.qsize(), 1)

# Get the actual message from the queue
actual_message = self.scheduler._web_log_message_queue.get()

# Verify core fields
self.assertEqual(actual_message.user_id, "test_user")
self.assertEqual(actual_message.mem_cube_id, "test_cube")
self.assertEqual(actual_message.label, QUERY_TASK_LABEL)
self.assertEqual(actual_message.from_memory_type, "WorkingMemory")
self.assertEqual(actual_message.to_memory_type, "LongTermMemory")
self.assertEqual(actual_message.log_content, "Test Content")

# Verify memory sizes
self.assertEqual(actual_message.current_memory_sizes["long_term_memory_size"], 0)
self.assertEqual(actual_message.current_memory_sizes["user_memory_size"], 0)
self.assertEqual(actual_message.current_memory_sizes["working_memory_size"], 0)
self.assertEqual(actual_message.current_memory_sizes["transformed_act_memory_size"], 0)

# Verify memory capacities
self.assertEqual(actual_message.memory_capacities["long_term_memory_capacity"], 1000)
self.assertEqual(actual_message.memory_capacities["user_memory_capacity"], 500)
self.assertEqual(actual_message.memory_capacities["working_memory_capacity"], 100)
self.assertEqual(actual_message.memory_capacities["transformed_act_memory_capacity"], 0)
self.scheduler.rabbitmq_publish_message.assert_called_once_with(
message=log_message.to_dict()
)

# Verify auto-generated fields exist
self.assertTrue(hasattr(actual_message, "item_id"))
self.assertTrue(isinstance(actual_message.item_id, str))
self.assertTrue(hasattr(actual_message, "timestamp"))
self.assertTrue(isinstance(actual_message.timestamp, datetime))
self.assertTrue(hasattr(log_message, "item_id"))
self.assertTrue(isinstance(log_message.item_id, str))
self.assertTrue(hasattr(log_message, "timestamp"))
self.assertTrue(isinstance(log_message.timestamp, datetime))

def test_activation_memory_update(self):
"""Test activation memory update functionality with DynamicCache handling."""
Expand Down