diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index 488f59625..443aa1f3d 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -2,8 +2,6 @@ Request context middleware for automatic trace_id injection. """ -import json -import os import time from collections.abc import Callable @@ -19,9 +17,6 @@ logger = memos.log.get_logger(__name__) -# Maximum body size to read for logging (in bytes) - bodies larger than this will be skipped -MAX_BODY_LOG_SIZE = os.getenv("MAX_BODY_LOG_SIZE", 10 * 1024) - def extract_trace_id_from_headers(request: Request) -> str | None: """Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id.""" @@ -31,127 +26,6 @@ def extract_trace_id_from_headers(request: Request) -> str | None: return None -def _is_json_request(request: Request) -> tuple[bool, str]: - """ - Check if request is a JSON request. - - Args: - request: The request object - - Returns: - Tuple of (is_json, content_type) - """ - if request.method not in ("POST", "PUT", "PATCH", "DELETE"): - return False, "" - - content_type = request.headers.get("content-type", "") - if not content_type: - return False, "" - - is_json = "application/json" in content_type.lower() - return is_json, content_type - - -def _should_read_body(content_length: str | None) -> tuple[bool, int | None]: - """ - Check if body should be read based on content-length header. - - Args: - content_length: Content-Length header value - - Returns: - Tuple of (should_read, body_size). body_size is None if header is invalid. - """ - if not content_length: - return True, None - - try: - body_size = int(content_length) - return body_size <= MAX_BODY_LOG_SIZE, body_size - except ValueError: - return True, None - - -def _create_body_info(content_type: str, body_size: int) -> dict: - """Create body_info dict for large bodies that are skipped.""" - return { - "content_type": content_type, - "content_length": body_size, - "note": f"body too large ({body_size} bytes), skipping read", - } - - -def _parse_json_body(body_bytes: bytes) -> dict | str: - """ - Parse JSON body bytes. - - Args: - body_bytes: Raw body bytes - - Returns: - Parsed JSON dict, or error message string if parsing fails - """ - try: - return json.loads(body_bytes) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - return f"" - - -async def get_request_params(request: Request) -> tuple[dict, bytes | None]: - """ - Extract request parameters (query params and body) for logging. - - Only reads body for application/json requests that are within size limits. - - This function is wrapped with exception handling to ensure logging failures - don't affect the actual request processing. - - Args: - request: The incoming request object - - Returns: - Tuple of (params_dict, body_bytes). body_bytes is None if body was not read. - Returns empty dict and None on any error. - """ - try: - params_log = {} - - # Check if this is a JSON request - is_json, content_type = _is_json_request(request) - if not is_json: - return params_log, None - - # Pre-check body size using content-length header - content_length = request.headers.get("content-length") - should_read, body_size = _should_read_body(content_length) - - if not should_read and body_size is not None: - params_log["body_info"] = _create_body_info(content_type, body_size) - return params_log, None - - # Read body - body_bytes = await request.body() - - if not body_bytes: - return params_log, None - - # Post-check: verify actual size (content-length might be missing or wrong) - actual_size = len(body_bytes) - if actual_size > MAX_BODY_LOG_SIZE: - params_log["body_info"] = _create_body_info(content_type, actual_size) - return params_log, None - - # Parse JSON body - params_log["body"] = _parse_json_body(body_bytes) - return params_log, body_bytes - - except Exception as e: - # Catch-all for any unexpected errors - logger.error(f"Unexpected error in get_request_params: {e}", exc_info=True) - # Return empty dict to ensure request can continue - return {}, None - - class RequestContextMiddleware(BaseHTTPMiddleware): """ Middleware to automatically inject request context for every HTTP request. @@ -193,26 +67,9 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: ) set_request_context(context) - # Get request parameters for logging - # Wrap in try-catch to ensure logging failures don't break the request - params_log, body_bytes = await get_request_params(request) - - # Re-create the request receive function if body was read - # This ensures downstream handlers can still read the body - if body_bytes is not None: - try: - - async def receive(): - return {"type": "http.request", "body": body_bytes, "more_body": False} - - request._receive = receive - except Exception as e: - logger.error(f"Failed to recreate request receive function: {e}") - # Continue without restoring body, downstream handlers will handle it - logger.info( f"Request started, source: {self.source}, method: {request.method}, path: {request.url.path}, " - f"request params: {params_log}, headers: {request.headers}" + f"headers: {request.headers}" ) # Process the request