diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/auto_memory_wrapper/agent.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/auto_memory_wrapper/agent.py index 2f1e8e1353..652c44d6e6 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/auto_memory_wrapper/agent.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/auto_memory_wrapper/agent.py @@ -72,15 +72,25 @@ def _get_user_id_from_context(self) -> str: Extract user_id from runtime context. Priority order: - 1. user_manager.get_id() - For authenticated sessions (set via SessionManager.session()) - 2. X-User-ID HTTP header - For testing/simple auth without middleware - 3. "default_user" - Fallback for development/testing without authentication + 1. Context.user_id - For authenticated sessions (set via SessionManager.session()) + 2. user_manager.get_id() - Legacy/custom context compatibility + 3. X-User-ID HTTP header - For testing/simple auth without middleware + 4. "default_user" - Fallback for development/testing without authentication Returns: str: The user ID for memory operations """ - # Priority 1: Get user_id from user_manager (for authenticated sessions) - user_manager = self._context.user_manager + # Priority 1: Get user_id from the runtime context. + try: + user_id = self._context.user_id + if user_id: + logger.debug(f"Using user_id from context: {user_id}") + return user_id + except Exception as e: + logger.debug(f"Failed to get user_id from context: {e}") + + # Priority 2: Get user_id from user_manager (legacy/custom context compatibility) + user_manager = getattr(self._context, "user_manager", None) if user_manager and hasattr(user_manager, 'get_id'): try: user_id = user_manager.get_id() @@ -90,9 +100,11 @@ def _get_user_id_from_context(self) -> str: except Exception as e: logger.debug(f"Failed to get user_id from user_manager: {e}") - # Priority 2: Extract from X-User-ID HTTP header (temporary workaround for testing) - if self._context.metadata and self._context.metadata.headers: - user_id = self._context.metadata.headers.get("x-user-id") + # Priority 3: Extract from X-User-ID HTTP header (temporary workaround for testing) + metadata = getattr(self._context, "metadata", None) + headers = getattr(metadata, "headers", None) if metadata else None + if headers: + user_id = headers.get("x-user-id") or headers.get("X-User-ID") if user_id: logger.debug(f"Using user_id from X-User-ID header: {user_id}") return user_id diff --git a/packages/nvidia_nat_langchain/tests/agent/test_auto_memory_wrapper.py b/packages/nvidia_nat_langchain/tests/agent/test_auto_memory_wrapper.py index 7498d4b08c..3aad079c60 100644 --- a/packages/nvidia_nat_langchain/tests/agent/test_auto_memory_wrapper.py +++ b/packages/nvidia_nat_langchain/tests/agent/test_auto_memory_wrapper.py @@ -69,7 +69,7 @@ async def _ainvoke(chat_request: ChatRequest): def fixture_mock_context() -> Mock: """Create a mock Context for testing.""" context = Mock(spec=Context) - context.user_manager = None + context.user_id = None context.metadata = None return context @@ -145,6 +145,14 @@ def test_get_user_id_default(self, wrapper_graph): user_id = wrapper_graph._get_user_id_from_context() assert user_id == "default_user" + def test_get_user_id_from_context(self, wrapper_graph, mock_context): + """Test user ID extraction from Context.user_id.""" + mock_context.user_id = "user-from-context" + with patch('nat.plugins.langchain.agent.auto_memory_wrapper.agent.Context.get', return_value=mock_context): + wrapper_graph._context = mock_context + user_id = wrapper_graph._get_user_id_from_context() + assert user_id == "user-from-context" + def test_get_user_id_from_header(self, wrapper_graph, mock_context): """Test user ID extraction from X-User-ID header.""" mock_context.metadata = Mock()