Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/unit/test_reliability/test_null_embedding_protection.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ async def test_lazy_collection_creation_on_new_dimension(self):
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
proc.replication_factor = 1
proc.shard_number = 1

msg = MagicMock()
msg.metadata.collection = "graphs"
Expand Down
87 changes: 87 additions & 0 deletions trustgraph-base/trustgraph/base/qdrant_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@

import os
import argparse
from typing import Optional, Any, Tuple


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int, else *default*.

    An unset, empty, or whitespace-only variable is treated as "not
    configured" so that a blank entry (e.g. ``QDRANT_SHARD_NUMBER=`` in an
    env file) does not crash start-up.

    Raises:
        ValueError: if the variable is set to something that is not an
            integer; the message names the offending variable.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from None


def get_qdrant_defaults() -> dict:
    """Build the Qdrant connection/collection defaults from the environment.

    Returns:
        A dict with the keys:

        * ``'url'`` -- ``QDRANT_URL``, or ``'http://localhost:6333'`` if unset.
        * ``'api_key'`` -- ``QDRANT_API_KEY``, or ``None`` if unset.
        * ``'replication_factor'`` -- ``QDRANT_REPLICATION_FACTOR`` as an
          int, or ``1`` if unset or blank.
        * ``'shard_number'`` -- ``QDRANT_SHARD_NUMBER`` as an int, or ``1``
          if unset or blank.

    Raises:
        ValueError: if one of the integer variables holds a non-integer
            value.
    """
    return {
        'url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
        'api_key': os.getenv('QDRANT_API_KEY'),
        'replication_factor': _env_int('QDRANT_REPLICATION_FACTOR', 1),
        'shard_number': _env_int('QDRANT_SHARD_NUMBER', 1),
    }


def _with_env_note(help_text: str, env_var: str) -> str:
    """Append a ``[from VAR]`` marker to *help_text* when *env_var* is set."""
    if env_var in os.environ:
        return f"{help_text} [from {env_var}]"
    return help_text


def _add_option_with_alias(
    parser: argparse.ArgumentParser,
    short_flag: str,
    long_flag: str,
    **kwargs: Any,
) -> None:
    """Register *long_flag* on *parser*, with *short_flag* as an alias if free."""
    try:
        parser.add_argument(short_flag, long_flag, **kwargs)
    except argparse.ArgumentError:
        # The host parser already uses the short flag for something else;
        # the long option must still be registered. A clash on the long
        # flag itself re-raises here, as it would without the alias.
        parser.add_argument(long_flag, **kwargs)


def add_qdrant_args(parser: argparse.ArgumentParser) -> None:
    """Register the shared Qdrant command-line options on *parser*.

    Options added (defaults come from ``get_qdrant_defaults()``):

    * ``-t`` / ``--store-uri`` -- Qdrant URL.
    * ``-k`` / ``--api-key`` -- Qdrant API key.
    * ``--qdrant-replication-factor`` -- int.
    * ``--qdrant-shard-number`` -- int.

    The ``-t`` and ``-k`` aliases keep existing command lines working for
    services that used to declare these options themselves; an alias is
    skipped if the parser already defines that short flag.

    Each help string shows the effective default and is tagged with
    ``[from VAR]`` when the corresponding environment variable is set.
    The API key value itself is never written into the help text.
    """
    defaults = get_qdrant_defaults()

    # argparse %-formats help strings, so a literal '%' in the URL (for
    # example percent-encoded characters) has to be doubled.
    shown_url = str(defaults['url']).replace('%', '%%')

    api_key_help = "Qdrant API key"
    if defaults['api_key']:
        # Only reveal that a key is configured, not the key.
        api_key_help += " (default: <set>)"

    _add_option_with_alias(
        parser, '-t', '--store-uri',
        default=defaults['url'],
        help=_with_env_note(
            f"Qdrant URL (default: {shown_url})", 'QDRANT_URL',
        ),
    )

    _add_option_with_alias(
        parser, '-k', '--api-key',
        default=defaults['api_key'],
        help=_with_env_note(api_key_help, 'QDRANT_API_KEY'),
    )

    parser.add_argument(
        '--qdrant-replication-factor',
        type=int,
        default=defaults['replication_factor'],
        help=_with_env_note(
            "Qdrant collection replication factor "
            f"(default: {defaults['replication_factor']})",
            'QDRANT_REPLICATION_FACTOR',
        ),
    )

    parser.add_argument(
        '--qdrant-shard-number',
        type=int,
        default=defaults['shard_number'],
        help=_with_env_note(
            "Qdrant collection shard number "
            f"(default: {defaults['shard_number']})",
            'QDRANT_SHARD_NUMBER',
        ),
    )


def resolve_qdrant_config(
    args: Optional[Any] = None,
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    replication_factor: Optional[int] = None,
    shard_number: Optional[int] = None,
) -> Tuple[str, Optional[str], int, int]:
    """Resolve the effective Qdrant settings from several sources.

    For each setting the first truthy value wins, in this order:

    1. the explicit keyword argument,
    2. the matching attribute on *args* (``store_uri``, ``api_key``,
       ``qdrant_replication_factor``, ``qdrant_shard_number``), if *args*
       is given,
    3. the value from ``get_qdrant_defaults()``.

    Because selection uses truthiness, falsy values such as ``None``,
    ``''`` and ``0`` count as "not provided" and fall through to the next
    source.

    Returns:
        ``(url, api_key, replication_factor, shard_number)``.
    """
    # (setting name, attribute name on the parsed-arguments object)
    sources = (
        ('url', 'store_uri'),
        ('api_key', 'api_key'),
        ('replication_factor', 'qdrant_replication_factor'),
        ('shard_number', 'qdrant_shard_number'),
    )

    chosen = {
        'url': url,
        'api_key': api_key,
        'replication_factor': replication_factor,
        'shard_number': shard_number,
    }

    if args is not None:
        for setting, attr in sources:
            chosen[setting] = chosen[setting] or getattr(args, attr, None)

    fallback = get_qdrant_defaults()
    for setting, _ in sources:
        chosen[setting] = chosen[setting] or fallback[setting]

    return (
        chosen['url'],
        chosen['api_key'],
        chosen['replication_factor'],
        chosen['shard_number'],
    )
18 changes: 11 additions & 7 deletions trustgraph-flow/trustgraph/direct/cassandra_kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2
import ssl

from ..tables.cassandra_async import async_execute

Expand Down Expand Up @@ -41,13 +41,15 @@ class KnowledgeGraph:

def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None
keyspace="trustgraph", username=None, password=None,
replication_factor=1,
):

if hosts is None:
hosts = ["localhost"]

self.keyspace = keyspace
self.replication_factor = replication_factor
self.username = username

# 7-table schema for quads with full query pattern support
Expand All @@ -68,7 +70,7 @@ def __init__(
self.collection_metadata_table = "collection_metadata"

if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
Expand All @@ -92,7 +94,7 @@ def init(self):
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
'replication_factor' : {self.replication_factor}
}};
""")

Expand Down Expand Up @@ -539,13 +541,15 @@ class EntityCentricKnowledgeGraph:

def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None
keyspace="trustgraph", username=None, password=None,
replication_factor=1,
):

if hosts is None:
hosts = ["localhost"]

self.keyspace = keyspace
self.replication_factor = replication_factor
self.username = username

# 2-table entity-centric schema
Expand All @@ -556,7 +560,7 @@ def __init__(
self.collection_metadata_table = "collection_metadata"

if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
Expand All @@ -580,7 +584,7 @@ def init(self):
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
'replication_factor' : {self.replication_factor}
}};
""")

Expand Down
28 changes: 9 additions & 19 deletions trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,32 @@
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config

# Module logger
logger = logging.getLogger(__name__)

default_ident = "doc-embeddings-query"

default_store_uri = 'http://localhost:6333'

class Processor(DocumentEmbeddingsQueryService):

def __init__(self, **params):

store_uri = params.get("store_uri", default_store_uri)
store_uri = params.get("store_uri")
api_key = params.get("api_key")

#optional api key
api_key = params.get("api_key", None)
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)

super(Processor, self).__init__(
**params | {
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)

self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)

async def query_document_embeddings(self, workspace, msg):

Expand Down Expand Up @@ -85,18 +86,7 @@ async def query_document_embeddings(self, workspace, msg):
def add_args(parser):

DocumentEmbeddingsQueryService.add_args(parser)

parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)

parser.add_argument(
'-k', '--api-key',
default=None,
help=f'API key for qdrant (default: None)'
)
add_qdrant_args(parser)

def run():

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,32 @@
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config

# Module logger
logger = logging.getLogger(__name__)

default_ident = "graph-embeddings-query"

default_store_uri = 'http://localhost:6333'

class Processor(GraphEmbeddingsQueryService):

def __init__(self, **params):

store_uri = params.get("store_uri", default_store_uri)
store_uri = params.get("store_uri")
api_key = params.get("api_key")

#optional api key
api_key = params.get("api_key", None)
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)

super(Processor, self).__init__(
**params | {
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)

self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)

def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
Expand Down Expand Up @@ -104,18 +105,7 @@ async def query_graph_embeddings(self, workspace, msg):
def add_args(parser):

GraphEmbeddingsQueryService.add_args(parser)

parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)

parser.add_argument(
'-k', '--api-key',
default=None,
help=f'API key for qdrant (default: None)'
)
add_qdrant_args(parser)

def run():

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _create_schema(self):
# Create keyspace
self.session.execute(f"""
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}}
""")

# Create triples table optimized for SPARQL queries
Expand Down
28 changes: 10 additions & 18 deletions trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config

# Module logger
logger = logging.getLogger(__name__)

default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'
default_concurrency = 10


Expand All @@ -35,13 +35,17 @@ def __init__(self, **params):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)

store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
store_uri = params.get("store_uri")
api_key = params.get("api_key")

url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)

super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
Expand All @@ -62,7 +66,7 @@ def __init__(self, **params):
)
)

self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)

def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
Expand Down Expand Up @@ -192,21 +196,9 @@ async def on_message(self, msg, consumer, flow):

@staticmethod
def add_args(parser):
"""Add command-line arguments"""

FlowProcessor.add_args(parser)

parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)

parser.add_argument(
'-k', '--api-key',
default=None,
help='API key for Qdrant (default: None)'
)
add_qdrant_args(parser)

parser.add_argument(
'-c', '--concurrency',
Expand Down
Loading
Loading