Skip to content
Merged
7 changes: 7 additions & 0 deletions src/memos/configs/mem_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ class RabbitMQConfig(
ge=1, # Port must be >= 1
le=65535, # Port must be <= 65535
)
exchange_name: str = Field(
default="memos-fanout",
description="Exchange name for RabbitMQ (e.g., memos-fanout, memos-memory-change)",
)
exchange_type: str = Field(
default="fanout", description="Exchange type for RabbitMQ (fanout or direct)"
)


class GraphDBAuthConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
Expand Down
2 changes: 1 addition & 1 deletion src/memos/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def close(self):
},
"handlers": {
"console": {
"level": selected_log_level,
"level": "DEBUG",
"class": "logging.StreamHandler",
"stream": stdout,
"formatter": "no_datetime",
Expand Down
56 changes: 30 additions & 26 deletions src/memos/mem_scheduler/general_modules/scheduler_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,21 @@ def log_working_memory_replacement(
or getattr(itm.metadata, "update_at", None),
}
)
ev = self.create_event_log(
label="scheduleMemory",
from_memory_type=TEXT_MEMORY_TYPE,
to_memory_type=WORKING_MEMORY_TYPE,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
memcube_log_content=memcube_content,
metadata=meta,
memory_len=len(memcube_content),
memcube_name=self._map_memcube_name(mem_cube_id),
)
log_func_callback([ev])
# Only create log if there are actual memory changes
if memcube_content:
ev = self.create_event_log(
label="scheduleMemory",
from_memory_type=TEXT_MEMORY_TYPE,
to_memory_type=WORKING_MEMORY_TYPE,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
memcube_log_content=memcube_content,
metadata=meta,
memory_len=len(memcube_content),
memcube_name=self._map_memcube_name(mem_cube_id),
)
log_func_callback([ev])

@log_exceptions(logger=logger)
def log_activation_memory_update(
Expand Down Expand Up @@ -235,19 +237,21 @@ def log_activation_memory_update(
"updated_at": None,
}
)
ev = self.create_event_log(
label="scheduleMemory",
from_memory_type=ACTIVATION_MEMORY_TYPE,
to_memory_type=PARAMETER_MEMORY_TYPE,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
memcube_log_content=memcube_content,
metadata=meta,
memory_len=len(added_memories),
memcube_name=self._map_memcube_name(mem_cube_id),
)
log_func_callback([ev])
# Only create log if there are actual memory changes
if memcube_content:
ev = self.create_event_log(
label="scheduleMemory",
from_memory_type=ACTIVATION_MEMORY_TYPE,
to_memory_type=PARAMETER_MEMORY_TYPE,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
memcube_log_content=memcube_content,
metadata=meta,
memory_len=len(added_memories),
memcube_name=self._map_memcube_name(mem_cube_id),
)
log_func_callback([ev])

@log_exceptions(logger=logger)
def log_adding_memory(
Expand Down
18 changes: 14 additions & 4 deletions src/memos/mem_scheduler/general_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
is_all_english,
transform_name_to_key,
)
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.memories.textual.item import TextualMemoryItem
from memos.memories.textual.preference import PreferenceTextMemory
from memos.memories.textual.tree import TreeTextMemory
Expand Down Expand Up @@ -157,7 +158,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
"""
logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.")

grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages)
grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)

self.validate_schedule_messages(messages=messages, label=QUERY_LABEL)

Expand Down Expand Up @@ -201,7 +202,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
messages: List of answer messages to process
"""
logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.")
grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages)
grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)

self.validate_schedule_messages(messages=messages, label=ANSWER_LABEL)

Expand Down Expand Up @@ -237,7 +238,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.")
# Process the query in a session turn
grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages)
grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)

self.validate_schedule_messages(messages=messages, label=ADD_LABEL)
try:
Expand Down Expand Up @@ -758,8 +759,17 @@ def process_message(message: ScheduleMessageItem):

# Get the preference memory from the mem_cube
pref_mem = mem_cube.pref_mem
if pref_mem is None:
logger.warning(
f"Preference memory not initialized for mem_cube_id={mem_cube_id}, "
f"skipping pref_add processing"
)
return
if not isinstance(pref_mem, PreferenceTextMemory):
logger.error(f"Expected PreferenceTextMemory but got {type(pref_mem).__name__}")
logger.error(
f"Expected PreferenceTextMemory but got {type(pref_mem).__name__} "
f"for mem_cube_id={mem_cube_id}"
)
return

# Use pref_mem.get_memory to process the memories
Expand Down
19 changes: 17 additions & 2 deletions src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def __init__(self):
# RabbitMQ settings
self.rabbitmq_config: RabbitMQConfig | None = None
self.rabbit_queue_name = "memos-scheduler"
self.rabbitmq_exchange_name = "memos-fanout"
self.rabbitmq_exchange_type = FANOUT_EXCHANGE_TYPE
self.rabbitmq_exchange_name = "memos-fanout" # Default, will be overridden by config
self.rabbitmq_exchange_type = FANOUT_EXCHANGE_TYPE # Default, will be overridden by config
self.rabbitmq_connection = None
self.rabbitmq_channel = None

Expand Down Expand Up @@ -87,6 +87,21 @@ def initialize_rabbitmq(
else:
logger.error("Not implemented")

# Load exchange configuration from config
if self.rabbitmq_config:
if (
hasattr(self.rabbitmq_config, "exchange_name")
and self.rabbitmq_config.exchange_name
):
self.rabbitmq_exchange_name = self.rabbitmq_config.exchange_name
logger.info(f"Using configured exchange name: {self.rabbitmq_exchange_name}")
if (
hasattr(self.rabbitmq_config, "exchange_type")
and self.rabbitmq_config.exchange_type
):
self.rabbitmq_exchange_type = self.rabbitmq_config.exchange_type
logger.info(f"Using configured exchange type: {self.rabbitmq_exchange_type}")

# Start connection process
parameters = self.get_rabbitmq_connection_param()
self.rabbitmq_connection = SelectConnection(
Expand Down
Loading