Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions trustgraph-cli/trustgraph/cli/load_structured_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def load_structured_data(
logger.info("Step 1: Analyzing data to discover best matching schema...")

# Step 1: Auto-discover schema (reuse discover_schema logic)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not discovered_schema:
logger.error("Failed to discover suitable schema automatically")
print("❌ Could not automatically determine the best schema for your data.")
Expand All @@ -90,7 +90,7 @@ def load_structured_data(

# Step 2: Auto-generate descriptor
logger.info("Step 2: Generating descriptor configuration...")
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
if not auto_descriptor:
logger.error("Failed to generate descriptor automatically")
print("❌ Could not automatically generate descriptor configuration.")
Expand Down Expand Up @@ -172,7 +172,7 @@ def load_structured_data(
logger.info(f"Sample chars: {sample_chars} characters")

# Use the helper function to discover schema (get raw response for display)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)

if response:
# Debug: print response type and content
Expand Down Expand Up @@ -203,7 +203,7 @@ def load_structured_data(
# If no schema specified, discover it first
if not schema_name:
logger.info("No schema specified, auto-discovering...")
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not schema_name:
print("Error: Could not determine schema automatically.")
print("Please specify a schema using --schema-name or run --discover-schema first.")
Expand All @@ -213,7 +213,7 @@ def load_structured_data(
logger.info(f"Target schema: {schema_name}")

# Generate descriptor using helper function
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)

if descriptor:
# Output the generated descriptor
Expand Down Expand Up @@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp


# Helper functions for auto mode
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
"""Auto-discover the best matching schema for the input data

Args:
Expand All @@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
api = Api(api_url, workspace=workspace)
api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()

# Get available schemas
Expand Down Expand Up @@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
return None


def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
"""Auto-generate descriptor configuration for the discovered schema"""
try:
# Read sample data
Expand All @@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
api = Api(api_url, workspace=workspace)
api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()

# Get schema definition
Expand Down
53 changes: 34 additions & 19 deletions trustgraph-flow/trustgraph/query/rows/cassandra/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan

from ... graphql import GraphQLSchemaBuilder, SortDirection

Expand Down Expand Up @@ -180,7 +180,7 @@ async def _apply_schema_config(self, workspace, config):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
indexed=field_def.get("indexed", False),
)
fields.append(field)

Expand Down Expand Up @@ -232,6 +232,8 @@ def find_matching_index(
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
if value == "" or value is None:
continue
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
Expand Down Expand Up @@ -282,11 +284,13 @@ async def query_cassandra(
query += f" LIMIT {limit}"

try:
rows = await async_execute(self.session, query, params)
for row in rows:
# Convert data map to dict with proper field names
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
pages = await async_execute_paged(
self.session, query, params
)
for page in pages:
for row in page:
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True)
raise
Expand All @@ -308,8 +312,6 @@ async def query_cassandra(
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]

# We need to scan all values for this index
# This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
Expand All @@ -320,17 +322,18 @@ async def query_cassandra(
params = [collection, schema_name, primary_index]

try:
rows = await async_execute(self.session, query, params)

for row in rows:
def row_filter(row):
row_dict = dict(row.data) if row.data else {}
return self._matches_filters(row_dict, filters, row_schema)

# Apply post-filters
if self._matches_filters(row_dict, filters, row_schema):
results.append(row_dict)

if limit and len(results) >= limit:
break
matched_rows = await async_scan(
self.session, query, params,
row_filter=row_filter,
limit=limit,
)
for row in matched_rows:
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)

except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
Expand Down Expand Up @@ -363,7 +366,7 @@ def _matches_filters(
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
field_name = parts[0]
operator = parts[1]
else:
Expand Down Expand Up @@ -400,6 +403,18 @@ def _matches_filters(
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
elif operator == 'not':
if str(row_value) == str(filter_value):
return False
elif operator == 'startsWith':
if not str(row_value).startswith(str(filter_value)):
return False
elif operator == 'endsWith':
if not str(row_value).endswith(str(filter_value)):
return False
elif operator == 'not_in':
if str(row_value) in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError):
return False

Expand Down
2 changes: 1 addition & 1 deletion trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ async def _apply_schema_config(self, workspace, config, version):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
indexed=field_def.get("indexed", False),
)
fields.append(field)

Expand Down
51 changes: 49 additions & 2 deletions trustgraph-flow/trustgraph/tables/cassandra_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def _set_exception_if_pending(fut, exc):
fut.set_exception(exc)


async def async_execute_paged(session, query, parameters=None, fetch_size=100):
async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
"""Execute a CQL query with page-by-page iteration.

Uses synchronous session.execute() inside run_in_executor so that
the driver's ResultSet paging works correctly without materialising
the entire result set in memory.

Yields one page of rows at a time (as a list).
Returns all pages as a list of lists.
"""
loop = asyncio.get_running_loop()

Expand All @@ -111,3 +111,50 @@ def _fetch_all_pages():
return await loop.run_in_executor(
None, _fetch_all_pages
)


async def async_scan(
    session, query, parameters=None, row_filter=None,
    limit=None, fetch_size=5000,
):
    """Run a paged CQL query and collect the rows that pass a filter.

    The query is executed synchronously on the event loop's default
    executor. Pages are walked one at a time; only rows accepted by
    ``row_filter`` are kept, and the walk stops as soon as ``limit``
    rows have been collected, so later pages are never fetched.

    Args:
        session: object exposing ``execute(statement, parameters)``
            whose result offers ``current_rows``, ``has_more_pages``
            and ``fetch_next_page()``.
        query: CQL text, or an already-built statement object. Text is
            wrapped in a ``SimpleStatement``; a statement object is
            used directly and has its ``fetch_size`` overwritten.
        parameters: bind parameters passed through to ``execute``.
        row_filter: predicate ``row -> bool``; ``None`` keeps every row.
        limit: maximum number of rows to return. Any falsy value
            (``None`` or ``0``) means no cap.
        fetch_size: page size requested on the statement.

    Returns:
        List of accepted rows, in the order they were read.
    """
    event_loop = asyncio.get_running_loop()

    if not isinstance(query, str):
        # Caller-supplied statement: reuse it, adjusting its page size
        # in place (the caller's object is modified).
        statement = query
        statement.fetch_size = fetch_size
    else:
        statement = SimpleStatement(query, fetch_size=fetch_size)

    def _collect_matches():
        matched = []
        cursor = session.execute(statement, parameters)
        keep_going = True
        while keep_going:
            for candidate in cursor.current_rows:
                if row_filter is not None and not row_filter(candidate):
                    continue
                matched.append(candidate)
                # Checked only after an append, so a falsy limit never
                # triggers the early exit.
                if limit and len(matched) >= limit:
                    return matched
            keep_going = cursor.has_more_pages
            if keep_going:
                cursor.fetch_next_page()
        return matched

    return await event_loop.run_in_executor(None, _collect_matches)
Loading