diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index 2922ab3eb..74865b66f 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -2,6 +2,8 @@ Request context middleware for automatic trace_id injection. """ +import json +import os import time from collections.abc import Callable @@ -17,6 +19,9 @@ logger = memos.log.get_logger(__name__) +# Maximum body size to read for logging (in bytes) - bodies larger than this will be skipped +MAX_BODY_LOG_SIZE = os.getenv("MAX_BODY_LOG_SIZE", 10 * 1024) + def extract_trace_id_from_headers(request: Request) -> str | None: """Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id.""" @@ -26,6 +31,127 @@ def extract_trace_id_from_headers(request: Request) -> str | None: return None +def _is_json_request(request: Request) -> tuple[bool, str]: + """ + Check if request is a JSON request. + + Args: + request: The request object + + Returns: + Tuple of (is_json, content_type) + """ + if request.method not in ("POST", "PUT", "PATCH", "DELETE"): + return False, "" + + content_type = request.headers.get("content-type", "") + if not content_type: + return False, "" + + is_json = "application/json" in content_type.lower() + return is_json, content_type + + +def _should_read_body(content_length: str | None) -> tuple[bool, int | None]: + """ + Check if body should be read based on content-length header. + + Args: + content_length: Content-Length header value + + Returns: + Tuple of (should_read, body_size). body_size is None if header is invalid. + """ + if not content_length: + return True, None + + try: + body_size = int(content_length) + return body_size <= MAX_BODY_LOG_SIZE, body_size + except ValueError: + return True, None + + +def _create_body_info(content_type: str, body_size: int) -> dict: + """Create body_info dict for large bodies that are skipped.""" + return { + "content_type": content_type, + "content_length": body_size, + "note": f"body too large ({body_size} bytes), skipping read", + } + + +def _parse_json_body(body_bytes: bytes) -> dict | str: + """ + Parse JSON body bytes. + + Args: + body_bytes: Raw body bytes + + Returns: + Parsed JSON dict, or error message string if parsing fails + """ + try: + return json.loads(body_bytes) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + return f"" + + +async def get_request_params(request: Request) -> tuple[dict, bytes | None]: + """ + Extract request parameters (query params and body) for logging. + + Only reads body for application/json requests that are within size limits. + + This function is wrapped with exception handling to ensure logging failures + don't affect the actual request processing. + + Args: + request: The incoming request object + + Returns: + Tuple of (params_dict, body_bytes). body_bytes is None if body was not read. + Returns empty dict and None on any error. + """ + try: + params_log = {} + + # Check if this is a JSON request + is_json, content_type = _is_json_request(request) + if not is_json: + return params_log, None + + # Pre-check body size using content-length header + content_length = request.headers.get("content-length") + should_read, body_size = _should_read_body(content_length) + + if not should_read and body_size is not None: + params_log["body_info"] = _create_body_info(content_type, body_size) + return params_log, None + + # Read body + body_bytes = await request.body() + + if not body_bytes: + return params_log, None + + # Post-check: verify actual size (content-length might be missing or wrong) + actual_size = len(body_bytes) + if actual_size > MAX_BODY_LOG_SIZE: + params_log["body_info"] = _create_body_info(content_type, actual_size) + return params_log, None + + # Parse JSON body + params_log["body"] = _parse_json_body(body_bytes) + return params_log, body_bytes + + except Exception as e: + # Catch-all for any unexpected errors + logger.error(f"Unexpected error in get_request_params: {e}", exc_info=True) + # Return empty dict to ensure request can continue + return {}, None + + class RequestContextMiddleware(BaseHTTPMiddleware): """ Middleware to automatically inject request context for every HTTP request. @@ -55,14 +181,27 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: ) set_request_context(context) - # Log request start with parameters - params_log = {} + # Get request parameters for logging + # Wrap in try-catch to ensure logging failures don't break the request + params_log, body_bytes = await get_request_params(request) + + # Re-create the request receive function if body was read + # This ensures downstream handlers can still read the body + if body_bytes is not None: + try: - # Get query parameters - if request.query_params: - params_log["query_params"] = dict(request.query_params) + async def receive(): + return {"type": "http.request", "body": body_bytes, "more_body": False} - logger.info(f"Request started, params: {params_log}, headers: {request.headers}") + request._receive = receive + except Exception as e: + logger.error(f"Failed to recreate request receive function: {e}") + # Continue without restoring body, downstream handlers will handle it + + logger.info( + f"Request started, method: {request.method}, path: {request.url.path}, " + f"request params: {params_log}, headers: {request.headers}" + ) # Process the request try: diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 2b481d5c6..684e02a0c 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -412,6 +412,8 @@ def _search_pref(): search_req.include_preference, ) + logger.info(f"Search memories result: {memories_result}") + return SearchResponse( message="Search completed successfully", data=memories_result, @@ -618,6 +620,9 @@ def _process_pref_mem() -> list[dict[str, str]]: text_response_data = text_future.result() pref_response_data = pref_future.result() + logger.info(f"add_memories Text response data: {text_response_data}") + logger.info(f"add_memories Pref response data: {pref_response_data}") + return MemoryResponse( message="Memory added successfully", data=text_response_data + pref_response_data,