From 8d8fb5766d39651f40157b0eae181f1a5466522a Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Tue, 2 Jun 2026 14:18:06 +0100 Subject: [PATCH] fix: route workspace through bulk WebSocket clients and merge query params Bulk clients (sync and async) were not forwarding the workspace parameter, causing all bulk operations to hit the default workspace regardless of the Api instance's workspace setting. Also fixes the gateway socket endpoint to pass query parameters (including workspace) to the dispatcher, and prevents the auth handshake from overwriting an explicitly set workspace. Updates knowledge table store tests for paged query interface. --- .../test_tables/test_knowledge_table_store.py | 26 +++++----- trustgraph-base/trustgraph/api/api.py | 4 +- .../trustgraph/api/async_bulk_client.py | 51 ++++++++---------- trustgraph-base/trustgraph/api/bulk_client.py | 52 +++++++++---------- .../trustgraph/api/socket_client.py | 3 +- .../trustgraph/gateway/endpoint/socket.py | 4 +- 6 files changed, 67 insertions(+), 73 deletions(-) diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py index 59d15b45..9a0b55c4 100644 --- a/tests/unit/test_tables/test_knowledge_table_store.py +++ b/tests/unit/test_tables/test_knowledge_table_store.py @@ -35,9 +35,9 @@ def _make_store(): class TestGetGraphEmbeddings: @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) async def test_row_converts_to_entity_embeddings_with_singular_vector( - self, mock_async_execute + self, mock_async_execute_paged ): """ Cassandra rows return entities as a list of [entity_tuple, vector] @@ -57,7 +57,7 @@ async def test_row_converts_to_entity_embeddings_with_singular_vector( store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -66,7 +66,7 @@ async def receiver(msg): await store.get_graph_embeddings("alice", "doc-1", receiver) - mock_async_execute.assert_called_once_with( + mock_async_execute_paged.assert_called_once_with( store.cassandra, store.get_graph_embeddings_stmt, ("alice", "doc-1"), @@ -96,8 +96,8 @@ async def receiver(msg): assert ge.entities[2].entity.value == "a literal entity" @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged): """row[3] being None / empty must produce a GraphEmbeddings with no entities, not raise.""" fake_row = (None, None, None, None) @@ -105,7 +105,7 @@ async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute): store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -118,8 +118,8 @@ async def receiver(msg): assert received[0].entities == [] @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_multiple_rows_each_emit_one_message(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged): fake_rows = [ (None, None, None, [ (("http://example.org/a", True), [1.0]), @@ -132,7 +132,7 @@ async def test_multiple_rows_each_emit_one_message(self, mock_async_execute): store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = fake_rows + mock_async_execute_paged.return_value = [fake_rows] received = [] @@ -153,8 +153,8 @@ class TestGetTriples: the same Metadata construction. Cover it for parity.""" @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_row_converts_to_triples(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_row_converts_to_triples(self, mock_async_execute_paged): # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri) fake_row = ( None, None, None, @@ -170,7 +170,7 @@ async def test_row_converts_to_triples(self, mock_async_execute): store = _make_store() store.cassandra = Mock() store.get_triples_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index 9074bac1..0190d3f5 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -337,7 +337,7 @@ def triple_generator(): from . bulk_client import BulkClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._bulk_client = BulkClient(base_url, self.timeout, self.token) + self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace) return self._bulk_client def metrics(self): @@ -462,7 +462,7 @@ async def triple_gen(): from . async_bulk_client import AsyncBulkClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token) + self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace) return self._async_bulk_client def async_metrics(self): diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py index 9a6a49c3..f93ab667 100644 --- a/trustgraph-base/trustgraph/api/async_bulk_client.py +++ b/trustgraph-base/trustgraph/api/async_bulk_client.py @@ -9,10 +9,11 @@ class AsyncBulkClient: """Asynchronous bulk operations client""" - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None: self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -25,11 +26,21 @@ def _convert_to_ws_url(self, url: str) -> str: else: return f"ws://{url}" + def _build_ws_url(self, path: str) -> str: + """Build a WebSocket URL with token and workspace query params.""" + ws_url = f"{self.url}{path}" + params = [] + if self.token: + params.append(f"token={self.token}") + if self.workspace: + params.append(f"workspace={self.workspace}") + if params: + ws_url = f"{ws_url}?{'&'.join(params)}" + return ws_url + async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None: """Bulk import triples via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for triple in triples: @@ -42,9 +53,7 @@ async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwar async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]: """Bulk export triples via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -57,9 +66,7 @@ async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import graph embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for embedding in embeddings: @@ -67,9 +74,7 @@ async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dic async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export graph embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -77,9 +82,7 @@ async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterat async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import document embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for embedding in embeddings: @@ -87,9 +90,7 @@ async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[ async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export document embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -97,9 +98,7 @@ async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIte async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import entity contexts via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for context in contexts: @@ -107,9 +106,7 @@ async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[s async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export entity contexts via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -117,9 +114,7 @@ async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterato async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import rows via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for row in rows: diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 0e49fc4e..ae185240 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -34,7 +34,7 @@ class BulkClient: Note: For true async support, use AsyncBulkClient instead. """ - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None: """ Initialize synchronous bulk client. @@ -42,10 +42,12 @@ def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS) timeout: WebSocket timeout in seconds token: Optional bearer token for authentication + workspace: Workspace for data isolation """ self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -58,6 +60,18 @@ def _convert_to_ws_url(self, url: str) -> str: else: return f"ws://{url}" + def _build_ws_url(self, path: str) -> str: + """Build a WebSocket URL with token and workspace query params.""" + ws_url = f"{self.url}{path}" + params = [] + if self.token: + params.append(f"token={self.token}") + if self.workspace: + params.append(f"workspace={self.workspace}") + if params: + ws_url = f"{ws_url}?{'&'.join(params)}" + return ws_url + def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any: """Run async coroutine synchronously""" try: @@ -116,9 +130,7 @@ async def _import_triples_async( metadata: Optional[Dict[str, Any]], batch_size: int ) -> None: """Async implementation of triple import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples") if metadata is None: metadata = {"id": "", "metadata": [], "collection": "default"} @@ -194,9 +206,7 @@ def export_triples(self, flow: str, **kwargs: Any) -> Iterator[Triple]: async def _export_triples_async(self, flow: str) -> Iterator[Triple]: """Async implementation of triple export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -238,9 +248,7 @@ def embedding_generator(): async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: """Async implementation of graph embeddings import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for embedding in embeddings: @@ -296,9 +304,7 @@ def export_graph_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of graph embeddings export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -336,9 +342,7 @@ def doc_embedding_generator(): async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: """Async implementation of document embeddings import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for embedding in embeddings: @@ -394,9 +398,7 @@ def export_document_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[ async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of document embeddings export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -446,9 +448,7 @@ async def _import_entity_contexts_async( metadata: Optional[Dict[str, Any]], batch_size: int ) -> None: """Async implementation of entity contexts import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts") if metadata is None: metadata = {"id": "", "metadata": [], "collection": "default"} @@ -522,9 +522,7 @@ def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of entity contexts export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -562,9 +560,7 @@ def row_generator(): async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None: """Async implementation of rows import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for row in rows: diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 6eeb95ff..b88d0c78 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -167,7 +167,8 @@ async def _ensure_connected(self): ) if resp.get("type") == "auth-ok": - self.workspace = resp.get("workspace", self.workspace) + if self.workspace == "default": + self.workspace = resp.get("workspace", self.workspace) elif resp.get("type") == "auth-failed": await self._socket.close() raise ProtocolException( diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index f53ad73b..af6183db 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -117,8 +117,10 @@ async def handle(self, request): running = Running() + params = dict(request.query) + params.update(request.match_info) dispatcher = await self.dispatcher( - ws, running, request.match_info + ws, running, params ) worker_task = tg.create_task(