Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion tests/unit/test_base/test_cassandra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,4 +409,57 @@ def test_mixed_none_and_values(self):

assert hosts == ['mixed-host']
assert username is None # Stays None
assert password == 'mixed-pass'
assert password == 'mixed-pass'


class TestReplicationFactorParamPath:

def test_explicit_kwarg(self):
with patch.dict(os.environ, {}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=3,
)
assert rf == 3

def test_kwarg_overrides_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=3,
)
assert rf == 3

def test_env_fallback_when_kwarg_none(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=None,
)
assert rf == 5

def test_default_when_no_kwarg_no_env(self):
with patch.dict(os.environ, {}, clear=True):
_, _, _, _, rf = resolve_cassandra_config()
assert rf == 1

def test_params_dict_path(self):
with patch.dict(os.environ, {}, clear=True):
params = {'cassandra_replication_factor': 3}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 3

def test_params_dict_overrides_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
params = {'cassandra_replication_factor': 3}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 3

def test_params_dict_missing_falls_to_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
params = {}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 5
136 changes: 136 additions & 0 deletions tests/unit/test_base/test_qdrant_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@

import os
import pytest
from unittest.mock import patch

from trustgraph.base.qdrant_config import (
get_qdrant_defaults,
resolve_qdrant_config,
)


class TestGetQdrantDefaults:

def test_defaults_with_no_env_vars(self):
with patch.dict(os.environ, {}, clear=True):
defaults = get_qdrant_defaults()
assert defaults['url'] == 'http://localhost:6333'
assert defaults['api_key'] is None
assert defaults['replication_factor'] == 1
assert defaults['shard_number'] == 1

def test_defaults_from_env(self):
env = {
'QDRANT_URL': 'http://qdrant:6333',
'QDRANT_API_KEY': 'secret',
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
defaults = get_qdrant_defaults()
assert defaults['url'] == 'http://qdrant:6333'
assert defaults['api_key'] == 'secret'
assert defaults['replication_factor'] == 3
assert defaults['shard_number'] == 5


class TestResolveQdrantConfig:

def test_defaults(self):
with patch.dict(os.environ, {}, clear=True):
url, api_key, rf, sn = resolve_qdrant_config()
assert url == 'http://localhost:6333'
assert api_key is None
assert rf == 1
assert sn == 1

def test_explicit_kwargs(self):
with patch.dict(os.environ, {}, clear=True):
url, api_key, rf, sn = resolve_qdrant_config(
url='http://custom:6333',
api_key='key',
replication_factor=3,
shard_number=5,
)
assert url == 'http://custom:6333'
assert api_key == 'key'
assert rf == 3
assert sn == 5

def test_kwargs_override_env(self):
env = {
'QDRANT_URL': 'http://env:6333',
'QDRANT_REPLICATION_FACTOR': '10',
'QDRANT_SHARD_NUMBER': '10',
}
with patch.dict(os.environ, env, clear=True):
url, _, rf, sn = resolve_qdrant_config(
url='http://explicit:6333',
replication_factor=3,
shard_number=5,
)
assert url == 'http://explicit:6333'
assert rf == 3
assert sn == 5

def test_env_fallback_when_kwargs_none(self):
env = {
'QDRANT_URL': 'http://env:6333',
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
url, _, rf, sn = resolve_qdrant_config()
assert url == 'http://env:6333'
assert rf == 3
assert sn == 5

def test_params_dict_path(self):
with patch.dict(os.environ, {}, clear=True):
params = {
'store_uri': 'http://params:6333',
'api_key': 'pkey',
'qdrant_replication_factor': 3,
'qdrant_shard_number': 5,
}
url, api_key, rf, sn = resolve_qdrant_config(
url=params.get('store_uri'),
api_key=params.get('api_key'),
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert url == 'http://params:6333'
assert api_key == 'pkey'
assert rf == 3
assert sn == 5

def test_params_dict_overrides_env(self):
env = {
'QDRANT_REPLICATION_FACTOR': '10',
'QDRANT_SHARD_NUMBER': '10',
}
with patch.dict(os.environ, env, clear=True):
params = {
'qdrant_replication_factor': 3,
'qdrant_shard_number': 5,
}
_, _, rf, sn = resolve_qdrant_config(
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert rf == 3
assert sn == 5

def test_params_dict_missing_falls_to_env(self):
env = {
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
params = {}
_, _, rf, sn = resolve_qdrant_config(
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert rf == 3
assert sn == 5
26 changes: 5 additions & 21 deletions trustgraph-base/trustgraph/base/cassandra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,35 +103,19 @@ def resolve_cassandra_config(
host: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
default_keyspace: Optional[str] = None
default_keyspace: Optional[str] = None,
replication_factor: Optional[int] = None,
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
"""
Resolve Cassandra configuration from various sources.

Can accept either argparse args object or explicit parameters.
Converts host string to list format for Cassandra driver.

Args:
args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor
host: Optional explicit host parameter (overrides args)
username: Optional explicit username parameter (overrides args)
password: Optional explicit password parameter (overrides args)
default_keyspace: Optional default keyspace if not specified elsewhere

Returns:
tuple: (hosts_list, username, password, keyspace, replication_factor)
"""
# If args provided, extract values
keyspace = None
replication_factor = 1
if args is not None:
host = host or getattr(args, 'cassandra_host', None)
username = username or getattr(args, 'cassandra_username', None)
password = password or getattr(args, 'cassandra_password', None)
keyspace = getattr(args, 'cassandra_keyspace', None)
replication_factor = getattr(args, 'cassandra_replication_factor', 1)
replication_factor = replication_factor or getattr(
args, 'cassandra_replication_factor', None
)

# Apply defaults if still None
defaults = get_cassandra_defaults()
host = host or defaults['host']
username = username or defaults['username']
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/config/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def __init__(self, **params):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="config"
default_keyspace="config",
replication_factor=params.get("cassandra_replication_factor"),
)

# Store resolved configuration
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/cores/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def __init__(self, **params):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="knowledge"
default_keyspace="knowledge",
replication_factor=params.get("cassandra_replication_factor"),
)

self.cassandra_host = hosts
Expand Down
1 change: 1 addition & 0 deletions trustgraph-flow/trustgraph/iam/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(self, **params):
username=cassandra_username,
password=cassandra_password,
default_keyspace="iam",
replication_factor=params.get("cassandra_replication_factor"),
)

self.cassandra_host = hosts
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/librarian/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def __init__(self, **params):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="librarian"
default_keyspace="librarian",
replication_factor=params.get("cassandra_replication_factor"),
)

# Store resolved configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __init__(self, **params):
api_key = params.get("api_key")

url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
url=store_uri,
api_key=api_key,
)

super(Processor, self).__init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(self, **params):

url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
)

super(Processor, self).__init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def __init__(self, **params):

url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
)

super(Processor, self).__init__(
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/storage/knowledge/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __init__(self, **params):
host=params.get("cassandra_host"),
username=params.get("cassandra_username"),
password=params.get("cassandra_password"),
default_keyspace='knowledge'
default_keyspace='knowledge',
replication_factor=params.get("cassandra_replication_factor"),
)

super(Processor, self).__init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__(self, **params):

url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
)

super(Processor, self).__init__(
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def __init__(self, **params):
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
password=cassandra_password,
replication_factor=params.get("cassandra_replication_factor"),
)

# Store resolved configuration with proper names
Expand Down
Loading