135 changes: 34 additions & 101 deletions src/memos/graph_dbs/polardb.py
@@ -1703,32 +1703,19 @@ def search_by_fulltext(
return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
"""
Full-text search functionality using PostgreSQL's full-text search capabilities.

Args:
query_text: query text
top_k: maximum number of results to return
scope: memory type filter (memory_type)
status: status filter, defaults to "activated"
threshold: similarity threshold filter
search_filter: additional property filter conditions
user_name: username filter
knowledgebase_ids: knowledgebase ids filter
filter: filter conditions with 'and' or 'or' logic for search results.
tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1
tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation)
return_fields: additional node fields to include in results
**kwargs: other parameters (e.g. cube_name)

Returns:
list[dict]: result list containing id and score.
If return_fields is specified, each dict also includes the requested fields.
"""
start_time = time.perf_counter()
logger.info(
f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}"
" search_by_fulltext query_words=%s top_k=%s scope=%s status=%s threshold=%s search_filter=%s user_name=%s knowledgebase_ids=%s filter=%s",
query_words,
top_k,
scope,
status,
threshold,
search_filter,
user_name,
knowledgebase_ids,
filter,
)
start_time = time.time()
where_clauses = []

if scope:
@@ -1744,22 +1731,18 @@ def search_by_fulltext(
"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype"
)

# Build user_name filter with knowledgebase_ids support (OR relationship) using common method
user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
user_name=user_name,
knowledgebase_ids=knowledgebase_ids,
default_user_name=self.config.user_name,
)
logger.info(f"[search_by_fulltext] user_name_conditions: {user_name_conditions}")

# Add OR condition if we have any user_name conditions
if user_name_conditions:
if len(user_name_conditions) == 1:
where_clauses.append(user_name_conditions[0])
else:
where_clauses.append(f"({' OR '.join(user_name_conditions)})")

# Add search_filter conditions
if search_filter:
for key, value in search_filter.items():
if isinstance(value, str):
@@ -1772,17 +1755,12 @@ def search_by_fulltext(
)

filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}")

where_clauses.extend(filter_conditions)
tsquery_string = " | ".join(query_words)

where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")

where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

logger.info(f"[search_by_fulltext] where_clause: {where_clause}")

select_cols = f"""ag_catalog.agtype_access_operator(m.properties, '"id"'::agtype) AS old_id,
ts_rank(m.{tsvector_field}, q.fq) AS rank"""
if return_fields:
@@ -1808,7 +1786,8 @@ def search_by_fulltext(
LIMIT {top_k};
"""
params = [tsquery_string]
logger.info(f"[search_by_fulltext] query: {query}, params: {params}")
logger.info("search_by_fulltext query=%s params=%s", query, params)

with self._get_connection() as conn, conn.cursor() as cursor:
cursor.execute(query, params)
results = cursor.fetchall()
@@ -1829,8 +1808,8 @@ def search_by_fulltext(
properties = row[2] # properties column
item.update(self._extract_fields_from_properties(properties, return_fields))
output.append(item)
elapsed_time = time.time() - start_time
logger.info(f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s")
elapsed = (time.perf_counter() - start_time) * 1000
logger.info("search_by_fulltext internal took %.1f ms", elapsed)
return output[:top_k]

@timed
@@ -1849,9 +1828,18 @@ def search_by_embedding(
**kwargs,
) -> list[dict]:
logger.info(
f"search_by_embedding user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},scope:{scope},status:{status},search_filter:{search_filter},filter:{filter},knowledgebase_ids:{knowledgebase_ids},return_fields:{return_fields}"
"search_by_embedding user_name:%s,filter: %s, knowledgebase_ids: %s,scope:%s,status:%s,search_filter:%s,filter:%s,knowledgebase_ids:%s,return_fields:%s",
user_name,
filter,
knowledgebase_ids,
scope,
status,
search_filter,
filter,
knowledgebase_ids,
return_fields,
)
start_time = time.time()
start_time = time.perf_counter()
where_clauses = []
if scope:
where_clauses.append(
@@ -1890,7 +1878,6 @@ def search_by_embedding(
)

filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}")
where_clauses.extend(filter_conditions)

where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
@@ -1926,7 +1913,7 @@ def search_by_embedding(
else:
pass

logger.info(f"[search_by_embedding] query: {query}, params: {params}")
logger.info(" search_by_embedding query: %s", query)

with self._get_connection() as conn, conn.cursor() as cursor:
if params:
@@ -1952,9 +1939,9 @@ def search_by_embedding(
properties = row[1] # properties column
item.update(self._extract_fields_from_properties(properties, return_fields))
output.append(item)
elapsed_time = time.time() - start_time
elapsed_time = (time.perf_counter() - start_time) * 1000
logger.info(
f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s"
"search_by_embedding query embedding completed time took %.1f ms", elapsed_time
)
return output[:top_k]

@@ -3169,27 +3156,15 @@ def add_nodes_batch(
nodes: list[dict[str, Any]],
user_name: str | None = None,
) -> None:
"""
Batch add multiple memory nodes to the graph.
logger.info(f" add_nodes_batch Processing only first node (total nodes: {len(nodes)})")

Args:
nodes: List of node dictionaries, each containing:
- id: str - Node ID
- memory: str - Memory content
- metadata: dict[str, Any] - Node metadata
user_name: Optional user name (will use config default if not provided)
"""
batch_start_time = time.time()
batch_start_time = time.perf_counter()
if not nodes:
logger.warning("[add_nodes_batch] Empty nodes list, skipping")
return

logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})")

# user_name comes from parameter; fallback to config if missing
effective_user_name = user_name if user_name else self.config.user_name

# Prepare all nodes
prepared_nodes = []
for node_data in nodes:
try:
@@ -3199,16 +3174,13 @@ def add_nodes_batch(

logger.debug(f"[add_nodes_batch] Processing node id: {id}")

# Set user_name in metadata
metadata["user_name"] = effective_user_name

metadata = _prepare_node_metadata(metadata)

# Merge node and set metadata
created_at = metadata.pop("created_at", datetime.utcnow().isoformat())
updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat())

# Prepare properties
properties = {
"id": id,
"memory": memory,
@@ -3219,32 +3191,26 @@ def add_nodes_batch(
**metadata,
}

# Generate embedding if not provided
if "embedding" not in properties or not properties["embedding"]:
properties["embedding"] = generate_vector(
self._get_config_value("embedding_dimension", 1024)
)

# Serialization - JSON-serialize sources and usage fields
for field_name in ["sources", "usage"]:
if properties.get(field_name):
if isinstance(properties[field_name], list):
for idx in range(len(properties[field_name])):
# Serialize only when element is not a string
if not isinstance(properties[field_name][idx], str):
properties[field_name][idx] = json.dumps(
properties[field_name][idx]
)
elif isinstance(properties[field_name], str):
# If already a string, leave as-is
pass

# Extract embedding for separate column
embedding_vector = properties.pop("embedding", [])
if not isinstance(embedding_vector, list):
embedding_vector = []

# Select column name based on embedding dimension
embedding_column = "embedding" # default column
if len(embedding_vector) == 3072:
embedding_column = "embedding_3072"
@@ -3267,14 +3233,12 @@ def add_nodes_batch(
f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}",
exc_info=True,
)
# Continue with other nodes
continue

if not prepared_nodes:
logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
return

# Group nodes by embedding column to optimize batch inserts
nodes_by_embedding_column = {}
for node in prepared_nodes:
col = node["embedding_column"]
@@ -3284,9 +3248,7 @@ def add_nodes_batch(

try:
with self._get_connection() as conn, conn.cursor() as cursor:
# Process each group separately
for embedding_column, nodes_group in nodes_by_embedding_column.items():
# Batch delete existing records using IN clause
ids_to_delete = [node["id"] for node in nodes_group]
if ids_to_delete:
delete_query = f"""
@@ -3297,7 +3259,6 @@ def add_nodes_batch(
"""
cursor.execute(delete_query, (ids_to_delete,))

# Batch get graph_ids for all nodes
get_graph_ids_query = f"""
SELECT
id_val,
@@ -3307,21 +3268,16 @@ def add_nodes_batch(
cursor.execute(get_graph_ids_query, (ids_to_delete,))
graph_id_map = {row[0]: row[1] for row in cursor.fetchall()}

# Add graph_id to properties
for node in nodes_group:
graph_id = graph_id_map.get(node["id"])
if graph_id:
node["properties"]["graph_id"] = str(graph_id)

# Use PREPARE/EXECUTE for efficient batch insert
# Generate unique prepare statement name to avoid conflicts
prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}"

try:
if embedding_column and any(
node["embedding_vector"] for node in nodes_group
):
# PREPARE statement for insert with embedding
prepare_query = f"""
PREPARE {prepare_name} AS
INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
@@ -3331,16 +3287,9 @@ def add_nodes_batch(
$3::vector
)
"""
logger.info(
f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}"
)
logger.info(
f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}"
)

cursor.execute(prepare_query)

# Execute prepared statement for each node
for node in nodes_group:
properties_json = json.dumps(node["properties"])
embedding_json = (
@@ -3354,7 +3303,6 @@ def add_nodes_batch(
(node["id"], properties_json, embedding_json),
)
else:
# PREPARE statement for insert without embedding
prepare_query = f"""
PREPARE {prepare_name} AS
INSERT INTO {self.db_name}_graph."Memory"(id, properties)
@@ -3363,40 +3311,25 @@ def add_nodes_batch(
$2::text::agtype
)
"""
logger.info(
f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}"
)
logger.info(
f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}"
)
cursor.execute(prepare_query)

# Execute prepared statement for each node
for node in nodes_group:
properties_json = json.dumps(node["properties"])

cursor.execute(
f"EXECUTE {prepare_name}(%s, %s)",
(node["id"], properties_json),
)
finally:
# DEALLOCATE prepared statement (always execute, even on error)
try:
cursor.execute(f"DEALLOCATE {prepare_name}")
logger.info(
f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}"
)
except Exception as dealloc_error:
logger.warning(
f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}"
)

logger.info(
f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}"
)
elapsed_time = time.time() - batch_start_time
elapsed_time = (time.perf_counter() - batch_start_time) * 1000
logger.info(
f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s"
"add_nodes_batch batch insert completed successfully in took %.1f ms",
elapsed_time,
)

except Exception as e:
Expand Down