diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index 74865b66f..488f59625 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -162,6 +162,17 @@ class RequestContextMiddleware(BaseHTTPMiddleware): 3. Ensures the context is available throughout the request lifecycle """ + def __init__(self, app, source: str | None = None): + """ + Initialize the middleware. + + Args: + app: The ASGI application + source: Source identifier (e.g., 'product' or 'server') to distinguish request origin + """ + super().__init__(app) + self.source = source or "api" + async def dispatch(self, request: Request, call_next: Callable) -> Response: # Extract or generate trace_id trace_id = extract_trace_id_from_headers(request) or generate_trace_id() @@ -178,6 +189,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: env=env, user_type=user_type, user_name=user_name, + source=self.source, ) set_request_context(context) @@ -199,7 +211,7 @@ async def receive(): # Continue without restoring body, downstream handlers will handle it logger.info( - f"Request started, method: {request.method}, path: {request.url.path}, " + f"Request started, source: {self.source}, method: {request.method}, path: {request.url.path}, " f"request params: {params_log}, headers: {request.headers}" ) @@ -209,16 +221,16 @@ async def receive(): end_time = time.time() if response.status_code == 200: logger.info( - f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + f"Request completed: source: {self.source}, path: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" ) else: logger.error( - f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + f"Request Failed: source: {self.source}, path: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" ) except Exception as e: end_time = time.time() logger.error( - f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms" + f"Request Exception Error: source: {self.source}, path: {request.url.path}, error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms" ) raise e diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py index 709ad74fb..ec5cccae1 100644 --- a/src/memos/api/product_api.py +++ b/src/memos/api/product_api.py @@ -17,7 +17,7 @@ version="1.0.1", ) -app.add_middleware(RequestContextMiddleware) +app.add_middleware(RequestContextMiddleware, source="product_api") # Include routers app.include_router(product_router) diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 24c67de48..0dfef99d9 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -18,7 +18,7 @@ version="1.0.1", ) -app.add_middleware(RequestContextMiddleware) +app.add_middleware(RequestContextMiddleware, source="server_api") # Include routers app.include_router(server_router) diff --git a/src/memos/context/context.py b/src/memos/context/context.py index d6a0f3bf1..b5d4c24fe 100644 --- a/src/memos/context/context.py +++ b/src/memos/context/context.py @@ -36,12 +36,14 @@ def __init__( env: str | None = None, user_type: str | None = None, user_name: str | None = None, + source: str | None = None, ): self.trace_id = trace_id or "trace-id" self.api_path = api_path self.env = env self.user_type = user_type self.user_name = user_name + self.source = source self._data: dict[str, Any] = {} def set(self, key: str, value: Any) -> None: @@ -59,6 +61,7 @@ def __setattr__(self, name: str, value: Any) -> None: "env", "user_type", "user_name", + "source", ): super().__setattr__(name, value) else: @@ -80,6 +83,7 @@ def to_dict(self) -> dict[str, Any]: "env": self.env, "user_type": self.user_type, "user_name": self.user_name, + "source": self.source, "data": self._data.copy(), } @@ -146,6 +150,16 @@ def get_current_user_name() -> str | None: return "memos" +def get_current_source() -> str | None: + """ + Get the current request's source (e.g., 'product_api' or 'server_api'). + """ + context = _request_context.get() + if context: + return context.get("source") + return None + + def get_current_context() -> RequestContext | None: """ Get the current request context. @@ -161,6 +175,7 @@ def get_current_context() -> RequestContext | None: env=context_dict.get("env"), user_type=context_dict.get("user_type"), user_name=context_dict.get("user_name"), + source=context_dict.get("source"), ) ctx._data = context_dict.get("data", {}).copy() return ctx