Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 90 additions & 2 deletions core/wren-core-py/src/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ use datafusion_common::config::Dialect;
use pyo3::{pyclass, pymethods};
use std::collections::hash_map::Entry;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::ops::ControlFlow;
use std::sync::Arc;
use wren_core::ast::{visit_relations, ObjectName};
use wren_core::dialect::GenericDialect;
use wren_core::mdl::manifest::{Model, Relationship, View};
use wren_core::mdl::WrenMDL;
use wren_core::parser::Parser;
use wren_core_base::mdl::Manifest;

#[pyclass]
Expand Down Expand Up @@ -67,6 +71,49 @@ fn resolve_used_table_names(mdl: &WrenMDL, sql: &str) -> Result<Vec<String>, Cor
})
}

/// Parse a RLAC condition expression and return the model names referenced by
/// subqueries inside it. The condition is parsed with the same `GenericDialect`
/// used by wren-core so session-property placeholders (`@session_id`) are
/// tolerated. Tables qualified with a non-matching catalog/schema are ignored,
/// mirroring `resolve_used_table_names`.
fn resolve_condition_models(mdl: &WrenMDL, condition: &str) -> Vec<String> {
let dialect = GenericDialect {};
let expr = match Parser::new(&dialect)
.try_with_sql(condition)
.and_then(|mut parser| parser.parse_expr())
{
Ok(expr) => expr,
Err(_) => return vec![],
};
let mut tables = Vec::new();
let _ = visit_relations(&expr, |name: &ObjectName| {
if let Some(table) = matched_table_name(mdl, name) {
tables.push(table);
}
ControlFlow::<()>::Continue(())
});
tables
}

/// Resolve an `ObjectName` against the manifest's catalog/schema, returning the
/// bare table name when it belongs to this manifest.
fn matched_table_name(mdl: &WrenMDL, name: &ObjectName) -> Option<String> {
let parts: Vec<&str> = name
.0
.iter()
.filter_map(|part| part.as_ident().map(|ident| ident.value.as_str()))
.collect();
let (catalog, schema, table) = match parts.as_slice() {
[table] => (None, None, *table),
[schema, table] => (None, Some(*schema), *table),
[catalog, schema, table] => (Some(*catalog), Some(*schema), *table),
_ => return None,
};
let catalog_matches = catalog.is_none_or(|c| c == mdl.catalog());
let schema_matches = schema.is_none_or(|s| s == mdl.schema());
(catalog_matches && schema_matches).then(|| table.to_string())
}

fn extract_manifest(
mdl: &WrenMDL,
used_datasets: &[String],
Expand Down Expand Up @@ -98,15 +145,25 @@ fn extract_models(mdl: &WrenMDL, used_datasets: &[String]) -> Vec<Arc<Model>> {
let mut stack: Vec<String> = used_datasets.to_vec();
while let Some(dataset_name) = stack.pop() {
if let Some(model) = mdl.get_model(&dataset_name) {
model
let related_via_relationship = model
.columns
.iter()
.filter_map(|col| {
col.relationship
.as_ref()
.and_then(|rel_name| mdl.get_relationship(rel_name))
})
.flat_map(|rel| rel.models.clone())
.flat_map(|rel| rel.models.clone());
// A RLAC condition may reference other models through subqueries
// (e.g. `id IN (SELECT id FROM other_model)`). Those models must be
// kept even when the outer SQL doesn't reference them directly,
// otherwise the RLAC subquery analysis in wren-core fails.
let related_via_rlac = model
.row_level_access_controls()
.iter()
.flat_map(|rlac| resolve_condition_models(mdl, &rlac.condition));
related_via_relationship
.chain(related_via_rlac)
.for_each(|related| {
if let Entry::Vacant(vacant) = used_set.entry(related) {
let key = vacant.key().clone();
Expand Down Expand Up @@ -219,13 +276,40 @@ mod tests {
let p_view = ViewBuilder::new("part_view")
.statement("SELECT * FROM my_catalog.my_schema.part")
.build();
// A 3-hop RLAC subquery chain: level1 -> level2 -> level3. None of these
// are reachable through relationships, only through the RLAC conditions.
let level3 = ModelBuilder::new("level3")
.table_reference("main.level3")
.column(ColumnBuilder::new("id", "integer").build())
.build();
let level2 = ModelBuilder::new("level2")
.table_reference("main.level2")
.column(ColumnBuilder::new("id", "integer").build())
.add_row_level_access_control(
"level2_rlac",
vec![],
"id IN (SELECT id FROM my_catalog.my_schema.level3)",
)
.build();
let level1 = ModelBuilder::new("level1")
.table_reference("main.level1")
.column(ColumnBuilder::new("id", "integer").build())
.add_row_level_access_control(
"level1_rlac",
vec![],
"id IN (SELECT id FROM level2)",
)
.build();
let manifest = ManifestBuilder::new()
.catalog("my_catalog")
.schema("my_schema")
.model(customer)
.model(orders)
.model(lineitem)
.model(part)
.model(level1)
.model(level2)
.model(level3)
.relationship(c_o_relationship)
.relationship(o_l_relationship)
.view(c_view)
Expand Down Expand Up @@ -288,6 +372,10 @@ mod tests {
#[case(&["orders"], &["lineitem", "orders"])]
#[case(&["lineitem"], &["lineitem"])]
#[case(&["part_view", "part"], &["part"])]
// A model referenced only by a RLAC subquery is kept, recursively (3-hop chain).
#[case(&["level1"], &["level1", "level2", "level3"])]
#[case(&["level2"], &["level2", "level3"])]
#[case(&["level3"], &["level3"])]
fn test_extract_manifest_for_models(
extractor: PyManifestExtractor,
#[case] dataset: &[&str],
Expand Down
Loading
Loading