Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,25 @@ def _get_user_id_from_context(self) -> str:
Extract user_id from runtime context.

Priority order:
1. user_manager.get_id() - For authenticated sessions (set via SessionManager.session())
2. X-User-ID HTTP header - For testing/simple auth without middleware
3. "default_user" - Fallback for development/testing without authentication
1. Context.user_id - For authenticated sessions (set via SessionManager.session())
2. user_manager.get_id() - Legacy/custom context compatibility
3. X-User-ID HTTP header - For testing/simple auth without middleware
4. "default_user" - Fallback for development/testing without authentication

Returns:
str: The user ID for memory operations
"""
# Priority 1: Get user_id from user_manager (for authenticated sessions)
user_manager = self._context.user_manager
# Priority 1: Get user_id from the runtime context.
try:
user_id = self._context.user_id
if user_id:
logger.debug(f"Using user_id from context: {user_id}")
return user_id
except Exception as e:
logger.debug(f"Failed to get user_id from context: {e}")

# Priority 2: Get user_id from user_manager (legacy/custom context compatibility)
user_manager = getattr(self._context, "user_manager", None)
if user_manager and hasattr(user_manager, 'get_id'):
try:
user_id = user_manager.get_id()
Expand All @@ -90,9 +100,11 @@ def _get_user_id_from_context(self) -> str:
except Exception as e:
logger.debug(f"Failed to get user_id from user_manager: {e}")

# Priority 2: Extract from X-User-ID HTTP header (temporary workaround for testing)
if self._context.metadata and self._context.metadata.headers:
user_id = self._context.metadata.headers.get("x-user-id")
# Priority 3: Extract from X-User-ID HTTP header (temporary workaround for testing)
metadata = getattr(self._context, "metadata", None)
headers = getattr(metadata, "headers", None) if metadata else None
if headers:
user_id = headers.get("x-user-id") or headers.get("X-User-ID")
if user_id:
logger.debug(f"Using user_id from X-User-ID header: {user_id}")
return user_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def _ainvoke(chat_request: ChatRequest):
def fixture_mock_context() -> Mock:
"""Create a mock Context for testing."""
context = Mock(spec=Context)
context.user_manager = None
context.user_id = None
context.metadata = None
return context

Expand Down Expand Up @@ -145,6 +145,14 @@ def test_get_user_id_default(self, wrapper_graph):
user_id = wrapper_graph._get_user_id_from_context()
assert user_id == "default_user"

def test_get_user_id_from_context(self, wrapper_graph, mock_context):
"""Test user ID extraction from Context.user_id."""
mock_context.user_id = "user-from-context"
with patch('nat.plugins.langchain.agent.auto_memory_wrapper.agent.Context.get', return_value=mock_context):
wrapper_graph._context = mock_context
user_id = wrapper_graph._get_user_id_from_context()
assert user_id == "user-from-context"

def test_get_user_id_from_header(self, wrapper_graph, mock_context):
"""Test user ID extraction from X-User-ID header."""
mock_context.metadata = Mock()
Expand Down
Loading