Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 136 additions & 134 deletions src/agentic_layer/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@
}


# Lookup table from a MemoryType to the Milvus repository class that serves it.
# Vector search resolves each requested memory type through this mapping and
# skips (with an info log) any type that has no entry here.
MILVUS_REPO_MAP = {
MemoryType.FORESIGHT: ForesightMilvusRepository,
MemoryType.EVENT_LOG: EventLogMilvusRepository,
MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository,
}


def _memory_types_label(memory_types) -> str:
    """Return a canonical, comma-joined label for a set of memory types.

    The result is used as a metric label, so equivalent requests must map to
    the same string: values are de-duplicated and sorted, which keeps
    ``[A, B]``, ``[B, A]`` and ``[A, A, B]`` from creating separate series.

    Args:
        memory_types: Iterable of memory type members (objects exposing a
            ``.value`` attribute), or ``None``. Items without ``.value`` are
            used as-is via ``str()``.

    Returns:
        ``'unknown'`` when ``memory_types`` is ``None`` or yields no items;
        otherwise the sorted unique values joined by ``','``. A single memory
        type yields just its value.
    """
    if not memory_types:
        return 'unknown'

    # Fall back to the item itself when it has no ``.value`` so an
    # already-serialized value cannot crash metric recording.
    unique_values = {str(getattr(mt, 'value', mt)) for mt in memory_types}

    # A truthy but empty iterable (e.g. an exhausted generator) ends up here.
    if not unique_values:
        return 'unknown'

    return ','.join(sorted(unique_values))


@dataclass
class EventLogCandidate:
"""Event Log candidate object (used for retrieval from atomic_fact)"""
Expand Down Expand Up @@ -298,11 +312,7 @@ async def retrieve_mem_keyword(
) -> RetrieveMemResponse:
"""Keyword-based memory retrieval"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
hits = await self.get_keyword_search_results(
Expand Down Expand Up @@ -339,11 +349,7 @@ async def get_keyword_search_results(
) -> List[Dict[str, Any]]:
"""Keyword search with stage-level metrics"""
stage_start = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
# Get parameters from Request
Expand Down Expand Up @@ -375,32 +381,41 @@ async def get_keyword_search_results(
if end_time is not None:
date_range["lte"] = end_time

mem_type = memory_types[0]

repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.warning(f"Unsupported memory_type: {mem_type}")
return []
# Iterate over ALL memory_types and merge results
all_results = []
seen_ids = set()
for mem_type in memory_types:
repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.info(f"Skipping unsupported memory type for keyword search: {mem_type}")
continue

es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")
es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")

results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)
results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)

# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
# Deduplicate by id
rid = r.get('id', '')
if rid and rid in seen_ids:
continue
if rid:
seen_ids.add(rid)
all_results.append(r)

# Record stage metrics
record_retrieve_stage(
Expand All @@ -410,7 +425,7 @@ async def get_keyword_search_results(
duration_seconds=time.perf_counter() - stage_start,
)

return results or []
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand All @@ -433,11 +448,7 @@ async def retrieve_mem_vector(
) -> RetrieveMemResponse:
"""Vector-based memory retrieval"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
hits = await self.get_vector_search_results(
Expand Down Expand Up @@ -473,11 +484,7 @@ async def get_vector_search_results(
retrieve_method: str = RetrieveMethod.VECTOR.value,
) -> List[Dict[str, Any]]:
"""Vector search with stage-level metrics (embedding + milvus_search)"""
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
# Get parameters from Request
Expand All @@ -497,7 +504,6 @@ async def get_vector_search_results(
top_k = retrieve_mem_request.top_k
start_time = retrieve_mem_request.start_time
end_time = retrieve_mem_request.end_time
mem_type = retrieve_mem_request.memory_types[0]

logger.debug(
f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}"
Expand All @@ -506,7 +512,7 @@ async def get_vector_search_results(
# Get vectorization service
vectorize_service = get_vectorize_service()

# Convert query text to vector (embedding stage)
# Convert query text to vector (embedding stage) — only once
logger.debug(f"Starting to vectorize query text: {query}")
embedding_start = time.perf_counter()
query_vector = await vectorize_service.get_embedding(query)
Expand All @@ -521,87 +527,95 @@ async def get_vector_search_results(
f"Query text vectorization completed, vector dimension: {len(query_vector_list)}"
)

# Select Milvus repository based on memory type
match mem_type:
case MemoryType.FORESIGHT:
milvus_repo = get_bean_by_type(ForesightMilvusRepository)
case MemoryType.EVENT_LOG:
milvus_repo = get_bean_by_type(EventLogMilvusRepository)
case MemoryType.EPISODIC_MEMORY:
milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository)
case _:
raise ValueError(f"Unsupported memory type: {mem_type}")
# Iterate over ALL memory_types and merge results
all_results = []
seen_ids = set()
milvus_start = time.perf_counter()

# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None
for mem_type in retrieve_mem_request.memory_types:
milvus_repo_class = MILVUS_REPO_MAP.get(mem_type)
if not milvus_repo_class:
logger.info(f"Skipping unsupported memory type for vector search: {mem_type}")
continue

if start_time is not None:
start_time_dt = (
from_iso_format(start_time)
if isinstance(start_time, str)
else start_time
)
milvus_repo = get_bean_by_type(milvus_repo_class)

if end_time is not None:
if isinstance(end_time, str):
end_time_dt = from_iso_format(end_time)
# If date only format, set to end of day
if len(end_time) == 10:
end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59)
# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None

if start_time is not None:
start_time_dt = (
from_iso_format(start_time)
if isinstance(start_time, str)
else start_time
)

if end_time is not None:
if isinstance(end_time, str):
end_time_dt = from_iso_format(end_time)
# If date only format, set to end of day
if len(end_time) == 10:
end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59)
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)

for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Deduplicate by id
rid = r.get('id', '')
if rid and rid in seen_ids:
continue
if rid:
seen_ids.add(rid)
all_results.append(r)

record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
)

for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Milvus already uses 'score', no need to rename

return search_results
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand All @@ -624,11 +638,7 @@ async def retrieve_mem_hybrid(
) -> RetrieveMemResponse:
"""Hybrid memory retrieval: keyword + vector + rerank"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
hits = await self._search_hybrid(
Expand Down Expand Up @@ -699,9 +709,7 @@ async def _search_hybrid(
retrieve_method: str = RetrieveMethod.HYBRID.value,
) -> List[Dict]:
"""Core hybrid search: keyword + vector + rerank, returns flat list"""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
)
memory_type = _memory_types_label(request.memory_types)
# Run keyword and vector search concurrently
kw_results, vec_results = await asyncio.gather(
self.get_keyword_search_results(request, retrieve_method=retrieve_method),
Expand All @@ -722,9 +730,7 @@ async def _search_rrf(
retrieve_method: str = RetrieveMethod.RRF.value,
) -> List[Dict]:
"""Core RRF search: keyword + vector + RRF fusion, returns flat list"""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
)
memory_type = _memory_types_label(request.memory_types)

# Run keyword and vector search concurrently
kw, vec = await asyncio.gather(
Expand Down Expand Up @@ -766,7 +772,7 @@ async def _to_response(
"""Convert flat hits list to grouped RetrieveMemResponse"""
user_id = req.user_id if req else ""
source_type = req.retrieve_method.value
memory_type = req.memory_types[0].value
memory_type = _memory_types_label(req.memory_types)

if not hits:
return RetrieveMemResponse(
Expand Down Expand Up @@ -808,11 +814,7 @@ async def retrieve_mem_rrf(
) -> RetrieveMemResponse:
"""RRF-based memory retrieval: keyword + vector + RRF fusion"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
else 'unknown'
)
memory_type = _memory_types_label(retrieve_mem_request.memory_types)

try:
hits = await self._search_rrf(
Expand Down Expand Up @@ -855,7 +857,7 @@ async def retrieve_mem_agentic(
req = retrieve_mem_request # alias
top_k = req.top_k
config = AgenticConfig()
memory_type = req.memory_types[0].value if req.memory_types else 'unknown'
memory_type = _memory_types_label(req.memory_types)

try:
llm_provider = LLMProvider(
Expand Down