From cdce0d5e4e8ee19ae356821722f00168ab404319 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 29 May 2026 17:51:48 +0800 Subject: [PATCH 1/2] fix(wren-core): support to-many calculated field joined on non-PK column Materialize the source model's primary key when building a TO-MANY calculated field, so the calculated subquery can be joined back even when the relationship is keyed on a non-PK column (e.g. `customer.mock_id = whitelist.mock_id`). Also use physical columns (instead of visible columns) in `Lineage` and the `WrenMDL` symbol table so hidden calculated columns participate in lineage and qualified-name resolution. This lets hidden RLAC helper fields be referenced inside RLAC expressions while remaining non-selectable by users. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../core/src/logical_plan/analyze/plan.rs | 25 ++ core/wren-core/core/src/mdl/lineage.rs | 3 +- core/wren-core/core/src/mdl/mod.rs | 224 +++++++++++++++++- 3 files changed, 249 insertions(+), 3 deletions(-) diff --git a/core/wren-core/core/src/logical_plan/analyze/plan.rs b/core/wren-core/core/src/logical_plan/analyze/plan.rs index f6aa53fbd9..a0beb9f359 100644 --- a/core/wren-core/core/src/logical_plan/analyze/plan.rs +++ b/core/wren-core/core/src/logical_plan/analyze/plan.rs @@ -506,6 +506,31 @@ impl ModelPlanNodeBuilder { &mut partial_model_required_fields, )?; + // Default use the primary key of the model as the required field for the partial model if there is no any required field. It's always used when the `TO_MANY` calculation is used, + // because the result of `TO_MANY` calculation is always grouped by the primary key of the source model. + let Some(model) = self + .analyzed_wren_mdl + .wren_mdl() + .get_model(model_ref.table()) + else { + return plan_err!("Model not found for {}", model_ref); + }; + + let primary_key_column = model.primary_key().and_then(|pk| model.get_column(pk)); + + if let Some(primary_key_column) = primary_key_column { + let expr = create_wren_expr_for_model( + &primary_key_column.name, + Arc::clone(&model), + Arc::clone(&self.session_state), + )?; + + partial_model_required_fields + .entry(model_ref.clone()) + .or_default() + .insert(OrdExpr::with_column(expr, Arc::clone(&column))); + } + let mut iter = column_graph.node_indices(); let start = iter.next().unwrap(); diff --git a/core/wren-core/core/src/mdl/lineage.rs b/core/wren-core/core/src/mdl/lineage.rs index 8e5e7712e4..9402c04a2c 100644 --- a/core/wren-core/core/src/mdl/lineage.rs +++ b/core/wren-core/core/src/mdl/lineage.rs @@ -40,7 +40,8 @@ impl Lineage { let mut source_columns_map = HashMap::new(); for model in mdl.manifest.models.iter() { - for column in model.get_visible_columns() { + // Use physical columns for both the normal column and the calculated column + for column in model.get_physical_columns(false) { if column.is_calculated { let expr: &String = match column.expression { Some(ref exp) => exp, diff --git a/core/wren-core/core/src/mdl/mod.rs b/core/wren-core/core/src/mdl/mod.rs index 157f45189e..ae3bace2dc 100644 --- a/core/wren-core/core/src/mdl/mod.rs +++ b/core/wren-core/core/src/mdl/mod.rs @@ -199,7 +199,7 @@ impl WrenMDL { pub fn new(manifest: Manifest) -> Self { let mut qualifed_references = HashMap::new(); manifest.models.iter().for_each(|model| { - model.get_visible_columns().for_each(|column| { + model.columns.iter().for_each(|column| { qualifed_references.insert( from_qualified_name_str( &manifest.catalog, @@ -209,7 +209,7 @@ impl WrenMDL { ), ColumnReference::new( Dataset::Model(Arc::clone(model)), - Arc::clone(&column), + Arc::clone(column), ), ); }); @@ -1978,6 +1978,64 @@ mod test { Ok(()) } + /// Test the calculated column with to-many relationship. The join condtion of the to-many relationship is based on a normal column (not the primary key column). + #[tokio::test] + async fn test_to_many_calculate_join_with_normal_column() -> Result<()> { + let ctx = create_wren_ctx(None, None); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column(ColumnBuilder::new("mock_id", "int").build()) + .column( + ColumnBuilder::new_relationship( + "whitelist", + "whitelist", + "customer_whitelist", + ) + .build(), + ) + .column( + ColumnBuilder::new_calculated("whitelist_name", "array") + .expression("array_agg(whitelist.allow_name)") + .build(), + ) + .primary_key("c_custkey") + .build(), + ) + .model( + ModelBuilder::new("whitelist") + .table_reference("whitelist") + .column(ColumnBuilder::new("allow_name", "string").build()) + .column(ColumnBuilder::new("mock_id", "int").build()) + .build(), + ) + .relationship( + RelationshipBuilder::new("customer_whitelist") + .model("customer") + .model("whitelist") + .join_type(JoinType::OneToMany) + .condition("customer.mock_id = whitelist.mock_id") + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + + let sql = "SELECT c_custkey, whitelist_name FROM customer"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), sql).await?, + @"SELECT customer.c_custkey, customer.whitelist_name FROM (SELECT __relation__1.c_custkey, __relation__1.whitelist_name FROM (SELECT whitelist_name.c_custkey, whitelist_name.whitelist_name FROM (SELECT __relation__1.c_custkey AS c_custkey, array_agg(__relation__1.allow_name) AS whitelist_name FROM (SELECT whitelist.allow_name, customer.c_custkey, whitelist.mock_id FROM (SELECT whitelist.allow_name, whitelist.mock_id FROM (SELECT whitelist.allow_name, whitelist.mock_id FROM (SELECT __source.allow_name AS allow_name, __source.mock_id AS mock_id FROM whitelist AS __source) AS whitelist) AS whitelist) AS whitelist RIGHT OUTER JOIN (SELECT __source.c_custkey AS c_custkey, __source.mock_id AS mock_id FROM customer AS __source) AS customer ON whitelist.mock_id = customer.mock_id) AS __relation__1 GROUP BY __relation__1.c_custkey) AS whitelist_name RIGHT OUTER JOIN (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer ON whitelist_name.c_custkey = customer.c_custkey) AS __relation__1) AS customer" + ); + Ok(()) + } + #[tokio::test] async fn test_rlac_with_requried_properties() -> Result<()> { let ctx = create_wren_ctx(None, None); @@ -2552,6 +2610,168 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_rlac_on_to_many_calculated_field() -> Result<()> { + let ctx = create_wren_ctx(None, None); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .column(ColumnBuilder::new("mock_id", "int").build()) + .column( + ColumnBuilder::new_relationship( + "whitelist", + "whitelist", + "customer_whitelist", + ) + .build(), + ) + .column( + ColumnBuilder::new_calculated("whitelist_name", "array") + .expression("array_agg(whitelist.allow_name)") + .build(), + ) + .add_row_level_access_control( + "allow_user_name", + vec![SessionProperty::new_required("session_user")], + "array_contains(whitelist_name, @session_user)", + ) + .primary_key("c_custkey") + .build(), + ) + .model( + ModelBuilder::new("whitelist") + .table_reference("whitelist") + .column(ColumnBuilder::new("allow_name", "string").build()) + .column(ColumnBuilder::new("mock_id", "int").build()) + .primary_key("mock_id") + .build(), + ) + .relationship( + RelationshipBuilder::new("customer_whitelist") + .model("customer") + .model("whitelist") + .join_type(JoinType::OneToMany) + .condition("customer.mock_id = whitelist.mock_id") + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("'Gura'".to_string()), + )])); + let sql = "SELECT c_custkey FROM customer"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers.clone(), sql).await?, + @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey, customer.whitelist_name FROM (SELECT __relation__1.c_custkey, __relation__1.whitelist_name FROM (SELECT whitelist_name.c_custkey, whitelist_name.whitelist_name FROM (SELECT __relation__1.c_custkey AS c_custkey, array_agg(__relation__1.allow_name) AS whitelist_name FROM (SELECT whitelist.allow_name, customer.c_custkey, whitelist.mock_id FROM (SELECT whitelist.allow_name, whitelist.mock_id FROM (SELECT whitelist.allow_name, whitelist.mock_id FROM (SELECT __source.allow_name AS allow_name, __source.mock_id AS mock_id FROM whitelist AS __source) AS whitelist) AS whitelist) AS whitelist RIGHT OUTER JOIN (SELECT __source.c_custkey AS c_custkey, __source.mock_id AS mock_id FROM customer AS __source) AS customer ON whitelist.mock_id = customer.mock_id) AS __relation__1 GROUP BY __relation__1.c_custkey) AS whitelist_name RIGHT OUTER JOIN (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer ON whitelist_name.c_custkey = customer.c_custkey) AS __relation__1) AS customer WHERE array_has(customer.whitelist_name, 'Gura')) AS customer" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_rlac_on_to_many_hidden_calculated_field() -> Result<()> { + let ctx = create_wren_ctx(None, None); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column(ColumnBuilder::new("mock_id", "int").build()) + .column( + ColumnBuilder::new_relationship( + "whitelist", + "whitelist", + "customer_whitelist", + ) + .build(), + ) + .column( + ColumnBuilder::new_calculated("whitelist_name", "array") + .expression("array_agg(whitelist.allow_name)") + .hidden(true) + .build(), + ) + .add_row_level_access_control( + "allow_user_name", + vec![SessionProperty::new_required("session_user")], + "array_contains(whitelist_name, @session_user)", + ) + .primary_key("c_custkey") + .build(), + ) + .model( + ModelBuilder::new("whitelist") + .table_reference("whitelist") + .column(ColumnBuilder::new("allow_name", "string").build()) + .column(ColumnBuilder::new("mock_id", "int").build()) + .primary_key("mock_id") + .build(), + ) + .relationship( + RelationshipBuilder::new("customer_whitelist") + .model("customer") + .model("whitelist") + .join_type(JoinType::OneToMany) + .condition("customer.mock_id = whitelist.mock_id") + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("'Gura'".to_string()), + )])); + + let sql = "SELECT * FROM customer"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers.clone(), sql).await?, + @"SELECT customer.c_custkey, customer.mock_id FROM (SELECT customer.c_custkey, customer.mock_id, customer.whitelist_name FROM (SELECT __relation__1.c_custkey, __relation__1.mock_id, __relation__1.whitelist_name FROM (SELECT whitelist_name.c_custkey, customer.mock_id, whitelist_name.whitelist_name FROM (SELECT __relation__1.c_custkey AS c_custkey, array_agg(__relation__1.allow_name) AS whitelist_name FROM (SELECT whitelist.allow_name, customer.c_custkey, whitelist.mock_id FROM (SELECT whitelist.allow_name, whitelist.mock_id FROM (SELECT whitelist.allow_name, whitelist.mock_id FROM (SELECT __source.allow_name AS allow_name, __source.mock_id AS mock_id FROM whitelist AS __source) AS whitelist) AS whitelist) AS whitelist RIGHT OUTER JOIN (SELECT __source.c_custkey AS c_custkey, __source.mock_id AS mock_id FROM customer AS __source) AS customer ON whitelist.mock_id = customer.mock_id) AS __relation__1 GROUP BY __relation__1.c_custkey) AS whitelist_name RIGHT OUTER JOIN (SELECT __source.c_custkey AS c_custkey, __source.mock_id AS mock_id FROM customer AS __source) AS customer ON whitelist_name.c_custkey = customer.c_custkey) AS __relation__1) AS customer WHERE array_has(customer.whitelist_name, 'Gura')) AS customer" + ); + + let sql = "SELECT c_custkey FROM customer"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers.clone(), sql).await?, + @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey, customer.whitelist_name FROM (SELECT __relation__1.c_custkey, __relation__1.whitelist_name FROM (SELECT whitelist_name.c_custkey, whitelist_name.whitelist_name FROM (SELECT __relation__1.c_custkey AS c_custkey, array_agg(__relation__1.allow_name) AS whitelist_name FROM (SELECT whitelist.allow_name, customer.c_custkey, whitelist.mock_id FROM (SELECT whitelist.allow_name, whitelist.mock_id FROM (SELECT whitelist.allow_name, whitelist.mock_id FROM (SELECT __source.allow_name AS allow_name, __source.mock_id AS mock_id FROM whitelist AS __source) AS whitelist) AS whitelist) AS whitelist RIGHT OUTER JOIN (SELECT __source.c_custkey AS c_custkey, __source.mock_id AS mock_id FROM customer AS __source) AS customer ON whitelist.mock_id = customer.mock_id) AS __relation__1 GROUP BY __relation__1.c_custkey) AS whitelist_name RIGHT OUTER JOIN (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer ON whitelist_name.c_custkey = customer.c_custkey) AS __relation__1) AS customer WHERE array_has(customer.whitelist_name, 'Gura')) AS customer" + ); + + let sql = "SELECT whitelist_name FROM customer"; + match transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + headers.clone(), + sql, + ) + .await + { + Ok(_) => { + panic!("whitelist_name is hidden, it should not be selected directly") + } + Err(e) => { + assert_snapshot!(e.to_string(), @"Schema error: No field named whitelist_name. Valid fields are customer.c_custkey, customer.mock_id.") + } + } + + Ok(()) + } + #[tokio::test] async fn test_rlac_alias_model() -> Result<()> { let ctx = create_wren_ctx(None, None); From cfcbf45584e54b0171a88ae136bb2fd728b1c0a0 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 29 May 2026 17:52:00 +0800 Subject: [PATCH 2/2] feat(wren-core): allow RLAC condition to use subqueries on MDL models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit RLAC conditions can now reference other Wren models via subqueries, e.g. `c_custkey IN (SELECT allowed_id FROM allowed WHERE allowed_user = @session_user)`. - A new `RlacContextProvider` resolves table references to MDL models during expression parsing. The default `SessionState::create_logical_expr` builds a context provider with an empty table map and never consults the catalog, so previously any table reference inside an RLAC condition failed with "table not found". - RLAC parsing moves from `ModelGenerationRule` to `ModelAnalyzeRule`. The pre-parsed filter is carried on `ModelPlanNode::rlac_filter` and exposed via `expressions()`, so the analyzer's subquery traversal rewrites inner `TableScan`s into `ModelPlanNode`s — the referenced models' own RLAC/CLAC and remote table mapping apply transitively. - Cycle detection via a shared `building_models` stack on `ModelAnalyzeRule`: A↔B and self-referential RLAC now report a planning error instead of looping. - `collect_condition` skips bare identifiers inside subqueries (they belong to a different scope) while still collecting `@property` references everywhere. - `extract_models` keeps models referenced solely by another model's RLAC subquery condition, recursively, so manifest extraction doesn't trim them. Co-Authored-By: Claude Opus 4.8 (1M context) --- core/wren-core-py/src/extractor.rs | 92 ++++- .../logical_plan/analyze/access_control.rs | 350 +++++++++++++++--- .../src/logical_plan/analyze/model_anlayze.rs | 174 ++++++++- .../logical_plan/analyze/model_generation.rs | 50 +-- .../core/src/logical_plan/analyze/plan.rs | 86 ++++- core/wren-core/core/src/mdl/mod.rs | 205 ++++++++++ 6 files changed, 842 insertions(+), 115 deletions(-) diff --git a/core/wren-core-py/src/extractor.rs b/core/wren-core-py/src/extractor.rs index d2d22649cb..8f13069d18 100644 --- a/core/wren-core-py/src/extractor.rs +++ b/core/wren-core-py/src/extractor.rs @@ -4,9 +4,13 @@ use datafusion_common::config::Dialect; use pyo3::{pyclass, pymethods}; use std::collections::hash_map::Entry; use std::collections::{BTreeSet, HashMap, HashSet}; +use std::ops::ControlFlow; use std::sync::Arc; +use wren_core::ast::{visit_relations, ObjectName}; +use wren_core::dialect::GenericDialect; use wren_core::mdl::manifest::{Model, Relationship, View}; use wren_core::mdl::WrenMDL; +use wren_core::parser::Parser; use wren_core_base::mdl::Manifest; #[pyclass] @@ -67,6 +71,49 @@ fn resolve_used_table_names(mdl: &WrenMDL, sql: &str) -> Result, Cor }) } +/// Parse a RLAC condition expression and return the model names referenced by +/// subqueries inside it. The condition is parsed with the same `GenericDialect` +/// used by wren-core so session-property placeholders (`@session_id`) are +/// tolerated. Tables qualified with a non-matching catalog/schema are ignored, +/// mirroring `resolve_used_table_names`. +fn resolve_condition_models(mdl: &WrenMDL, condition: &str) -> Vec { + let dialect = GenericDialect {}; + let expr = match Parser::new(&dialect) + .try_with_sql(condition) + .and_then(|mut parser| parser.parse_expr()) + { + Ok(expr) => expr, + Err(_) => return vec![], + }; + let mut tables = Vec::new(); + let _ = visit_relations(&expr, |name: &ObjectName| { + if let Some(table) = matched_table_name(mdl, name) { + tables.push(table); + } + ControlFlow::<()>::Continue(()) + }); + tables +} + +/// Resolve an `ObjectName` against the manifest's catalog/schema, returning the +/// bare table name when it belongs to this manifest. +fn matched_table_name(mdl: &WrenMDL, name: &ObjectName) -> Option { + let parts: Vec<&str> = name + .0 + .iter() + .filter_map(|part| part.as_ident().map(|ident| ident.value.as_str())) + .collect(); + let (catalog, schema, table) = match parts.as_slice() { + [table] => (None, None, *table), + [schema, table] => (None, Some(*schema), *table), + [catalog, schema, table] => (Some(*catalog), Some(*schema), *table), + _ => return None, + }; + let catalog_matches = catalog.is_none_or(|c| c == mdl.catalog()); + let schema_matches = schema.is_none_or(|s| s == mdl.schema()); + (catalog_matches && schema_matches).then(|| table.to_string()) +} + fn extract_manifest( mdl: &WrenMDL, used_datasets: &[String], @@ -98,7 +145,7 @@ fn extract_models(mdl: &WrenMDL, used_datasets: &[String]) -> Vec> { let mut stack: Vec = used_datasets.to_vec(); while let Some(dataset_name) = stack.pop() { if let Some(model) = mdl.get_model(&dataset_name) { - model + let related_via_relationship = model .columns .iter() .filter_map(|col| { @@ -106,7 +153,17 @@ fn extract_models(mdl: &WrenMDL, used_datasets: &[String]) -> Vec> { .as_ref() .and_then(|rel_name| mdl.get_relationship(rel_name)) }) - .flat_map(|rel| rel.models.clone()) + .flat_map(|rel| rel.models.clone()); + // A RLAC condition may reference other models through subqueries + // (e.g. `id IN (SELECT id FROM other_model)`). Those models must be + // kept even when the outer SQL doesn't reference them directly, + // otherwise the RLAC subquery analysis in wren-core fails. + let related_via_rlac = model + .row_level_access_controls() + .iter() + .flat_map(|rlac| resolve_condition_models(mdl, &rlac.condition)); + related_via_relationship + .chain(related_via_rlac) .for_each(|related| { if let Entry::Vacant(vacant) = used_set.entry(related) { let key = vacant.key().clone(); @@ -219,6 +276,30 @@ mod tests { let p_view = ViewBuilder::new("part_view") .statement("SELECT * FROM my_catalog.my_schema.part") .build(); + // A 3-hop RLAC subquery chain: level1 -> level2 -> level3. None of these + // are reachable through relationships, only through the RLAC conditions. + let level3 = ModelBuilder::new("level3") + .table_reference("main.level3") + .column(ColumnBuilder::new("id", "integer").build()) + .build(); + let level2 = ModelBuilder::new("level2") + .table_reference("main.level2") + .column(ColumnBuilder::new("id", "integer").build()) + .add_row_level_access_control( + "level2_rlac", + vec![], + "id IN (SELECT id FROM my_catalog.my_schema.level3)", + ) + .build(); + let level1 = ModelBuilder::new("level1") + .table_reference("main.level1") + .column(ColumnBuilder::new("id", "integer").build()) + .add_row_level_access_control( + "level1_rlac", + vec![], + "id IN (SELECT id FROM level2)", + ) + .build(); let manifest = ManifestBuilder::new() .catalog("my_catalog") .schema("my_schema") @@ -226,6 +307,9 @@ mod tests { .model(orders) .model(lineitem) .model(part) + .model(level1) + .model(level2) + .model(level3) .relationship(c_o_relationship) .relationship(o_l_relationship) .view(c_view) @@ -288,6 +372,10 @@ mod tests { #[case(&["orders"], &["lineitem", "orders"])] #[case(&["lineitem"], &["lineitem"])] #[case(&["part_view", "part"], &["part"])] + // A model referenced only by a RLAC subquery is kept, recursively (3-hop chain). + #[case(&["level1"], &["level1", "level2", "level3"])] + #[case(&["level2"], &["level2", "level3"])] + #[case(&["level3"], &["level3"])] fn test_extract_manifest_for_models( extractor: PyManifestExtractor, #[case] dataset: &[&str], diff --git a/core/wren-core/core/src/logical_plan/analyze/access_control.rs b/core/wren-core/core/src/logical_plan/analyze/access_control.rs index f936bc9cb0..6631fbcd69 100644 --- a/core/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/core/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -5,15 +5,25 @@ use std::{ }; use datafusion::{ - common::{plan_err, Result, Spans}, + arrow::datatypes::{DataType, SchemaRef}, + common::{ + config::ConfigOptions, file_options::file_type::FileType, plan_datafusion_err, + plan_err, Result, Spans, + }, error::DataFusionError, + logical_expr::{ + builder::LogicalTableSource, + planner::{ContextProvider, ExprPlanner, TypePlanner}, + AggregateUDF, ScalarUDF, TableSource, WindowUDF, + }, prelude::Expr, sql::{ parser::DFParserBuilder, + planner::{ParserOptions, PlannerContext, SqlToRel}, sqlparser::{ ast::{ - self, visit_expressions, visit_expressions_mut, Array, ExprWithAlias, - Map, MapEntry, + self, visit_expressions_mut, Array, ExprWithAlias, Map, MapEntry, Query, + Visit, Visitor, }, dialect::GenericDialect, }, @@ -24,62 +34,102 @@ use wren_core_base::mdl::RowLevelAccessControl; use wren_core_base::mdl::{Column, Model, SessionProperty}; use crate::{ - logical_plan::utils::from_qualified_name, - mdl::{context::SessionPropertiesRef, Dataset, SessionStateRef}, + logical_plan::utils::{create_schema, from_qualified_name}, + mdl::{ + context::SessionPropertiesRef, type_planner::WrenTypePlanner, Dataset, + SessionStateRef, + }, AnalyzedWrenMDL, }; /// Collect the required field from the condition of row level access control rules. +/// +/// Two outputs: +/// 1. Bare-identifier column references **at the top level** of the condition. These are +/// treated as columns of the outer model and are pre-marked as required so they are +/// not pruned. Identifiers inside subqueries are skipped — they belong to a different +/// scope and will be validated/resolved when the subquery is analyzed. +/// 2. Session properties (`@name`) referenced **anywhere** in the condition, including +/// inside subqueries. These need to be substituted at RLAC parse time regardless of +/// where they appear. pub fn collect_condition( model: &Model, condition: &str, ) -> Result<(Vec, Vec)> { - let mut conditions = HashSet::new(); - let mut session_properties: HashSet = HashSet::new(); - let mut error: Option> = None; let dialect = GenericDialect {}; let mut parser = DFParserBuilder::new(condition) .with_dialect(&dialect) .build()?; let expr = parser.parse_expr()?; - let _ = visit_expressions(&expr, |expr| { + + let mut visitor = ConditionVisitor { + model, + conditions: HashSet::new(), + session_properties: HashSet::new(), + subquery_depth: 0, + error: None, + }; + let _ = expr.visit(&mut visitor); + + if let Some(err) = visitor.error { + err?; + } + + Ok(( + visitor.conditions.into_iter().collect(), + visitor.session_properties.into_iter().collect::>(), + )) +} + +struct ConditionVisitor<'a> { + model: &'a Model, + conditions: HashSet, + session_properties: HashSet, + subquery_depth: usize, + error: Option>, +} + +impl Visitor for ConditionVisitor<'_> { + type Break = (); + + fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow { + self.subquery_depth += 1; + ControlFlow::Continue(()) + } + + fn post_visit_query(&mut self, _query: &Query) -> ControlFlow { + self.subquery_depth -= 1; + ControlFlow::Continue(()) + } + + fn pre_visit_expr(&mut self, expr: &ast::Expr) -> ControlFlow { // TODO: consider CompoundIdentifier and CompoundFieldAccess if let ast::Expr::Identifier(ast::Ident { value, .. }) = expr { - if !value.starts_with("@") { - if model.get_column(value).is_none() { - error = Some(plan_err!( + if let Some(session_property) = value.strip_prefix("@") { + self.session_properties + .insert(session_property.to_ascii_lowercase()); + } else if self.subquery_depth == 0 { + // Validate and collect only at the outer scope. Inside subqueries the + // identifier refers to a different model's column; that resolution is + // deferred to ModelAnalyzeRule's recursive subquery analysis. + if self.model.get_column(value).is_none() { + self.error = Some(plan_err!( "The column {} is not in the model {}", value, - model.name() + self.model.name() )); return ControlFlow::Break(()); } - conditions.insert(Expr::Column(datafusion::common::Column { - relation: Some(TableReference::bare(model.name())), - name: value.to_string(), - spans: Spans::new(), - })); - } else { - let session_property = value - .trim_start_matches("@") - .to_string() - .to_ascii_lowercase(); - if !session_properties.contains(&session_property) { - session_properties.insert(session_property); - } + self.conditions + .insert(Expr::Column(datafusion::common::Column { + relation: Some(TableReference::bare(self.model.name())), + name: value.to_string(), + spans: Spans::new(), + })); } } ControlFlow::Continue(()) - }); - - if let Some(err) = error { - return err; } - - Ok(( - conditions.into_iter().collect(), - session_properties.into_iter().collect::>(), - )) } /// Validate the definition of row level access control rules. @@ -118,8 +168,13 @@ pub fn validate_rlac_rule(rule: &RowLevelAccessControl, model: &Model) -> Result } /// Build the filter expression for the row level access control rule. +/// +/// `analyzed_mdl` is optional: when provided, the parser can resolve table references +/// inside the condition against MDL models, enabling subqueries that select from other +/// models. When `None`, only the current model's columns are visible to the parser. pub fn build_filter_expression( session_state: &SessionStateRef, + analyzed_mdl: Option>, model: Arc, properties: &SessionPropertiesRef, rule: &RowLevelAccessControl, @@ -198,9 +253,198 @@ pub fn build_filter_expression( } // The condition could contains the hidden columns, so we need to build the shcmea with hidden columns let df_schema = Dataset::Model(Arc::clone(&model)).to_qualified_schema(false)?; - session_state - .read() - .create_logical_expr(&expr.to_string(), &df_schema) + + // Build a ContextProvider that can resolve table references (in subqueries) against + // the MDL's models, in addition to delegating function/type/options lookups to the + // session state. This is required because DataFusion's default + // `SessionState::create_logical_expr` constructs a `SessionContextProvider` with an + // empty `tables` map and does not consult the session catalog, so any table reference + // inside the RLAC condition (e.g. in a scalar/IN/EXISTS subquery) fails to resolve. + let provider = RlacContextProvider::new(Arc::clone(session_state), analyzed_mdl); + let parser_options = build_parser_options(session_state); + let sql_to_rel = SqlToRel::new_with_options(&provider, parser_options); + sql_to_rel.sql_to_expr_with_alias(expr, &df_schema, &mut PlannerContext::new()) +} + +fn build_parser_options(session_state: &SessionStateRef) -> ParserOptions { + let state = session_state.read(); + let sql_parser = &state.config_options().sql_parser; + ParserOptions { + parse_float_as_decimal: sql_parser.parse_float_as_decimal, + enable_ident_normalization: sql_parser.enable_ident_normalization, + enable_options_value_normalization: sql_parser.enable_options_value_normalization, + support_varchar_with_length: sql_parser.support_varchar_with_length, + map_string_types_to_utf8view: sql_parser.map_string_types_to_utf8view, + collect_spans: sql_parser.collect_spans, + default_null_ordering: sql_parser.default_null_ordering.as_str().into(), + } +} + +/// A ContextProvider used while parsing RLAC condition expressions. +/// +/// It resolves table references against the [`AnalyzedWrenMDL`]'s models so that +/// subqueries inside an RLAC condition (e.g. `SELECT id FROM other_model WHERE ...`) +/// can reference other models in the manifest. All other lookups (functions, type +/// planner, options) delegate to the underlying [`SessionStateRef`]. +/// +/// The TableSource returned for an MDL model is a [`LogicalTableSource`] holding the +/// model's physical schema. The resulting `TableScan` is later rewritten to a +/// [`ModelPlanNode`](crate::logical_plan::analyze::plan::ModelPlanNode) by +/// `ModelAnalyzeRule` via natural subquery traversal, so the embedded model +/// participates in the same RLAC/CLAC/relationship machinery as the outer query. +pub struct RlacContextProvider { + session_state: SessionStateRef, + analyzed_mdl: Option>, + options: ConfigOptions, + type_planner: Option>, + expr_planners: Vec>, +} + +impl RlacContextProvider { + pub fn new( + session_state: SessionStateRef, + analyzed_mdl: Option>, + ) -> Self { + // Snapshot fields that the trait methods would otherwise need to borrow through + // a per-call RwLock guard. ExprPlanners are required for parsing array/map + // literals (e.g. RLAC property values like `[1,2,3]`); without them the parser + // emits `Could not plan array literal`. + let (options, expr_planners) = { + let guard = session_state.read(); + ( + guard.config_options().as_ref().clone(), + guard.expr_planners().to_vec(), + ) + }; + // WrenTypePlanner is stateless; the session-attached instance is not exposed + // through SessionState's public API on this DataFusion version, so we build a + // fresh one to mirror what `apply_wren_on_ctx` registers. + let type_planner: Option> = + Some(Arc::new(WrenTypePlanner::default())); + Self { + session_state, + analyzed_mdl, + options, + type_planner, + expr_planners, + } + } + + fn model_table_source(&self, name: &TableReference) -> Option> { + let mdl = self.analyzed_mdl.as_ref()?.wren_mdl(); + // Only resolve references that point at this MDL's catalog/schema (or are bare). + match name { + TableReference::Bare { .. } => {} + TableReference::Partial { schema, .. } => { + if schema.as_ref() != mdl.schema() { + return None; + } + } + TableReference::Full { + catalog, schema, .. + } => { + if catalog.as_ref() != mdl.catalog() || schema.as_ref() != mdl.schema() { + return None; + } + } + } + let model = mdl.get_model(name.table())?; + let columns: Vec> = model.get_physical_columns(false); + let schema: SchemaRef = create_schema(columns).ok()?; + Some(Arc::new(LogicalTableSource::new(schema))) + } +} + +impl ContextProvider for RlacContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + if let Some(source) = self.model_table_source(&name) { + return Ok(source); + } + plan_err!("table '{name}' not found") + } + + fn get_file_type(&self, _ext: &str) -> Result> { + Err(plan_datafusion_err!( + "Registered file types are not supported in RLAC conditions" + )) + } + + fn get_table_function_source( + &self, + _name: &str, + _args: Vec, + ) -> Result> { + Err(plan_datafusion_err!( + "Table functions are not supported in RLAC conditions" + )) + } + + fn get_expr_planners(&self) -> &[Arc] { + &self.expr_planners + } + + fn get_type_planner(&self) -> Option> { + self.type_planner.clone() + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.session_state + .read() + .scalar_functions() + .get(name) + .cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.session_state + .read() + .aggregate_functions() + .get(name) + .cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.session_state + .read() + .window_functions() + .get(name) + .cloned() + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn options(&self) -> &ConfigOptions { + &self.options + } + + fn udf_names(&self) -> Vec { + self.session_state + .read() + .scalar_functions() + .keys() + .cloned() + .collect() + } + + fn udaf_names(&self) -> Vec { + self.session_state + .read() + .aggregate_functions() + .keys() + .cloned() + .collect() + } + + fn udwf_names(&self) -> Vec { + self.session_state + .read() + .window_functions() + .keys() + .cloned() + .collect() + } } fn parse_expr(expr: &str) -> Result { @@ -644,7 +888,8 @@ mod test { name: "test".to_string(), }; - let expr = build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + let expr = + build_filter_expression(&state, None, Arc::clone(&model), &headers, &rule)?; assert_snapshot!(expr_to_sql(&expr)?, @"m1.id = 1 AND m1.\"name\" = 'test'"); let rule = RowLevelAccessControl { @@ -656,7 +901,7 @@ mod test { name: "test".to_string(), }; - match build_filter_expression(&state, Arc::clone(&model), &headers, &rule) { + match build_filter_expression(&state, None, Arc::clone(&model), &headers, &rule) { Err(error) => { assert_snapshot!(error.to_string(), @"Error during planning: The session property not_found is required for `test` rule but not found in the session properties"); } @@ -676,7 +921,7 @@ mod test { "session_id".to_string(), Some("1".to_string()), )])); - match build_filter_expression(&state, Arc::clone(&model), &headers, &rule) { + match build_filter_expression(&state, None, Arc::clone(&model), &headers, &rule) { Err(error) => { assert_snapshot!(error.to_string(), @"Error during planning: The session property session_name is required for `test` rule but not found in the session properties"); } @@ -697,7 +942,8 @@ mod test { Some("1".to_string()), )])); - let expr = build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + let expr = + build_filter_expression(&state, None, Arc::clone(&model), &headers, &rule)?; assert_snapshot!(expr_to_sql(&expr)?, @"m1.id = 1 AND m1.\"name\" = 'test'"); Ok(()) @@ -731,7 +977,8 @@ mod test { name: "test".to_string(), }; - let expr = build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + let expr = + build_filter_expression(&state, None, Arc::clone(&model), &headers, &rule)?; assert_snapshot!(expr_to_sql(&expr)?, @"m1.id = 1 AND m1.\"name\" = 'test'"); Ok(()) } @@ -768,8 +1015,13 @@ mod test { &[("session_id".to_string(), Some(value.to_string()))], )); - let expr = - build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + let expr = build_filter_expression( + &state, + None, + Arc::clone(&model), + &headers, + &rule, + )?; expr_to_sql(&expr)?; } @@ -792,7 +1044,13 @@ mod test { &[("session_id".to_string(), Some(value.to_string()))], )); - match build_filter_expression(&state, Arc::clone(&model), &headers, &rule) { + match build_filter_expression( + &state, + None, + Arc::clone(&model), + &headers, + &rule, + ) { Err(_) => {} _ => panic!( "should be error: {}", diff --git a/core/wren-core/core/src/logical_plan/analyze/model_anlayze.rs b/core/wren-core/core/src/logical_plan/analyze/model_anlayze.rs index 50bd94c03c..790df3995a 100644 --- a/core/wren-core/core/src/logical_plan/analyze/model_anlayze.rs +++ b/core/wren-core/core/src/logical_plan/analyze/model_anlayze.rs @@ -8,7 +8,7 @@ use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion::common::{internal_err, plan_err, Column, DFSchemaRef, Result, Spans}; use datafusion::config::ConfigOptions; use datafusion::error::DataFusionError; -use datafusion::logical_expr::expr::Alias; +use datafusion::logical_expr::expr::{Alias, Exists, InSubquery}; use datafusion::logical_expr::{ col, ident, Aggregate, Distinct, DistinctOn, Expr, Extension, Filter, Join, LogicalPlan, LogicalPlanBuilder, Projection, Subquery, SubqueryAlias, TableScan, @@ -16,6 +16,7 @@ use datafusion::logical_expr::{ }; use datafusion::optimizer::AnalyzerRule; use datafusion::sql::TableReference; +use parking_lot::Mutex; use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; @@ -35,6 +36,24 @@ pub struct ModelAnalyzeRule { analyzed_wren_mdl: Arc, session_state: SessionStateRef, properties: SessionPropertiesRef, + /// Stack of model names currently being resolved through RLAC. Shared across + /// recursive `analyze_*` calls (including those triggered by subqueries inside an + /// RLAC condition) so we can detect cycles like A's RLAC referencing B whose RLAC + /// references A. + building_models: Arc>>, +} + +/// RAII guard that removes a model name from the `building_models` stack on drop, +/// regardless of how the surrounding function exits. +struct ModelStackGuard { + stack: Arc>>, + name: String, +} + +impl Drop for ModelStackGuard { + fn drop(&mut self) { + self.stack.lock().remove(&self.name); + } } impl Debug for ModelAnalyzeRule { @@ -45,6 +64,10 @@ impl Debug for ModelAnalyzeRule { impl AnalyzerRule for ModelAnalyzeRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + // Each top-level invocation starts with a clean cycle-detection stack so the + // rule instance can be reused across queries. + self.building_models.lock().clear(); + let mut scope_manager = ScopeManager::new(); let root_scope_id = scope_manager.create_root_scope(); @@ -80,6 +103,7 @@ impl ModelAnalyzeRule { analyzed_wren_mdl, session_state, properties, + building_models: Arc::new(Mutex::new(HashSet::new())), } } @@ -448,6 +472,129 @@ impl ModelAnalyzeRule { } } + /// Construct a [`ModelPlanNode`] with cycle-aware RLAC handling. + /// + /// Steps: + /// 1. Push the model's name onto the shared `building_models` stack; error out if + /// it is already present (which means an RLAC condition transitively re-entered + /// this model). An RAII guard removes the name when this function exits. + /// 2. Build the [`ModelPlanNode`] (the builder parses each matching RLAC condition + /// via [`crate::logical_plan::analyze::access_control::RlacContextProvider`], so + /// table references inside subqueries are resolved against MDL models). + /// 3. If the resulting `rlac_filter` contains subqueries, recursively run the same + /// analyzer pipeline on each subquery's inner plan with a fresh scope manager — + /// that rewrites their `TableScan`s into `ModelPlanNode`s so the referenced + /// models' own RLAC/CLAC and relationship handling apply transitively. + fn build_model_plan_node( + &self, + model: Arc, + required_fields: Vec, + original_table_scan: Option, + ) -> Result { + let model_name = model.name().to_string(); + { + let mut stack = self.building_models.lock(); + if stack.contains(&model_name) { + return plan_err!( + "Detected a cycle in row level access control conditions for model `{}`", + model_name + ); + } + stack.insert(model_name.clone()); + } + let _guard = ModelStackGuard { + stack: Arc::clone(&self.building_models), + name: model_name, + }; + + let mut plan_node = ModelPlanNode::new( + model, + required_fields, + original_table_scan, + Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), + Arc::clone(&self.properties), + )?; + + if let Some(filter) = plan_node.rlac_filter.take() { + plan_node.rlac_filter = Some(self.analyze_rlac_subqueries(filter)?); + } + Ok(plan_node) + } + + /// Walk `expr` and analyze the inner plan of every embedded subquery + /// (`ScalarSubquery`, `InSubquery`, `Exists`). Each inner plan is processed with a + /// fresh `ScopeManager`/scope id — RLAC subqueries are introduced after the outer + /// scope analysis runs, so they don't have entries in the outer `scope_manager`. + fn analyze_rlac_subqueries(&self, expr: Expr) -> Result { + expr.transform_down(|expr| -> Result> { + match expr { + Expr::ScalarSubquery(sq) => { + let plan = + self.analyze_subquery_plan(Arc::unwrap_or_clone(sq.subquery))?; + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns: sq.outer_ref_columns, + spans: sq.spans, + }))) + } + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => { + let plan = self + .analyze_subquery_plan(Arc::unwrap_or_clone(subquery.subquery))?; + Ok(Transformed::yes(Expr::InSubquery(InSubquery { + expr, + subquery: Subquery { + subquery: Arc::new(plan), + outer_ref_columns: subquery.outer_ref_columns, + spans: subquery.spans, + }, + negated, + }))) + } + Expr::Exists(Exists { subquery, negated }) => { + let plan = self + .analyze_subquery_plan(Arc::unwrap_or_clone(subquery.subquery))?; + Ok(Transformed::yes(Expr::Exists(Exists { + subquery: Subquery { + subquery: Arc::new(plan), + outer_ref_columns: subquery.outer_ref_columns, + spans: subquery.spans, + }, + negated, + }))) + } + other => Ok(Transformed::no(other)), + } + }) + .data() + } + + /// Re-run the analyzer pipeline (scope analysis → model rewriting → schema cleanup) + /// on an inner subquery plan with a fresh `ScopeManager`. Used to make `TableScan`s + /// introduced by RLAC condition parsing go through the same transformation as + /// regular query plans. + fn analyze_subquery_plan(&self, plan: LogicalPlan) -> Result { + let mut scope_manager = ScopeManager::new(); + let root_scope_id = scope_manager.create_root_scope(); + self.analyze_scope(plan, &mut scope_manager, root_scope_id)? + .map_data(|p| { + self.analyze_model(p, &mut scope_manager, root_scope_id) + .data() + })? + .map_data(|p| { + p.transform_up_with_subqueries(&|p| -> Result> { + self.remove_wren_catalog_schema_prefix_and_refresh_schema(p) + }) + .data() + })? + .map_data(|p| p.recompute_schema()) + .data() + } + fn analyze_table_scan( &self, analyzed_wren_mdl: Arc, @@ -483,15 +630,13 @@ impl ModelAnalyzeRule { }; vec![] }; + let model_plan_node = self.build_model_plan_node( + Arc::clone(&model), + field, + Some(LogicalPlan::TableScan(table_scan.clone())), + )?; let model_plan = LogicalPlan::Extension(Extension { - node: Arc::new(ModelPlanNode::new( - Arc::clone(&model), - field, - Some(LogicalPlan::TableScan(table_scan.clone())), - Arc::clone(&self.analyzed_wren_mdl), - Arc::clone(&self.session_state), - Arc::clone(&self.properties), - )?), + node: Arc::new(model_plan_node), }); let subquery = LogicalPlanBuilder::from(model_plan) .alias(quoted(model.name()))? @@ -543,15 +688,10 @@ impl ModelAnalyzeRule { }; vec![] }; + let model_plan_node = + self.build_model_plan_node(Arc::clone(&model), field, None)?; let model_plan = LogicalPlan::Extension(Extension { - node: Arc::new(ModelPlanNode::new( - Arc::clone(&model), - field, - None, - Arc::clone(&self.analyzed_wren_mdl), - Arc::clone(&self.session_state), - Arc::clone(&self.properties), - )?), + node: Arc::new(model_plan_node), }); let subquery = LogicalPlanBuilder::from(model_plan).alias(alias)?.build()?; diff --git a/core/wren-core/core/src/logical_plan/analyze/model_generation.rs b/core/wren-core/core/src/logical_plan/analyze/model_generation.rs index 4a33b424aa..728d7e56c0 100644 --- a/core/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/core/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -23,9 +23,6 @@ use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion::optimizer::analyzer::AnalyzerRule; use datafusion::physical_plan::internal_err; use datafusion::sql::TableReference; -use wren_core_base::mdl::RowLevelAccessControl; - -use super::access_control::{build_filter_expression, validate_rule}; pub const SOURCE_ALIAS: &str = "__source"; @@ -85,31 +82,10 @@ impl ModelGenerationRule { return plan_err!("Failed to generate source plan"); }; - let filters: Vec> = model_plan - .model - .row_level_access_controls() - .iter() - .map(|rule| { - self.generate_row_level_access_control_filter( - Arc::clone(&model_plan.model), - rule, - ) - }) - .collect::>()?; - let rls_filter = filters - .into_iter() - .reduce(|acc, filter| { - if let Some(acc) = acc { - if let Some(filter) = filter { - Some(acc.and(filter)) - } else { - Some(acc) - } - } else { - filter - } - }) - .flatten(); + // RLAC condition parsing has already happened during ModelAnalyzeRule — + // the combined filter (with subqueries already rewritten into + // ModelPlanNodes) lives on `ModelPlanNode::rlac_filter`. + let rls_filter = model_plan.rlac_filter.clone(); if !model_plan.required_exprs.is_empty() { builder = builder.project(projections)? @@ -290,24 +266,6 @@ impl ModelGenerationRule { _ => Ok(Transformed::yes(plan.recompute_schema()?)), } } - - fn generate_row_level_access_control_filter( - &self, - model: Arc, - rule: &RowLevelAccessControl, - ) -> Result> { - if validate_rule(&rule.name, &rule.required_properties, &self.properties)? { - let filter = build_filter_expression( - &self.session_state, - model, - &self.properties, - rule, - )?; - Ok(Some(filter)) - } else { - Ok(None) - } - } } impl Debug for ModelGenerationRule { diff --git a/core/wren-core/core/src/logical_plan/analyze/plan.rs b/core/wren-core/core/src/logical_plan/analyze/plan.rs index a0beb9f359..8eeca9d18d 100644 --- a/core/wren-core/core/src/logical_plan/analyze/plan.rs +++ b/core/wren-core/core/src/logical_plan/analyze/plan.rs @@ -35,7 +35,7 @@ use crate::mdl::utils::{ use crate::mdl::Dataset; use crate::mdl::{AnalyzedWrenMDL, ColumnReference, SessionStateRef}; -use super::access_control::{collect_condition, validate_rule}; +use super::access_control::{build_filter_expression, collect_condition, validate_rule}; #[derive(Debug)] pub(crate) enum WrenPlan { @@ -59,6 +59,13 @@ impl WrenPlan { /// [ModelPlanNode] is a logical plan node that represents a model. It contains the model name, /// required fields, and the relation chain that connects the model with other models. /// It only generates the top plan for the model, and the relation chain will generate the source plan. +/// +/// `rlac_filter` carries the pre-parsed Row Level Access Control filter for this model. +/// Parsing happens during [`ModelAnalyzeRule`] (rather than during model generation) so that +/// any subqueries embedded in the condition — e.g. `SELECT id FROM other_model WHERE ...` — +/// are visited by the same analyzer pass. That visitation rewrites their internal +/// `TableScan` references to `ModelPlanNode`s, allowing the referenced model's own RLAC +/// to apply transitively. #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub(crate) struct ModelPlanNode { pub(crate) model: Arc, @@ -66,6 +73,7 @@ pub(crate) struct ModelPlanNode { pub(crate) relation_chain: Box, schema_ref: DFSchemaRef, pub(crate) original_table_scan: Option, + pub(crate) rlac_filter: Option, } impl ModelPlanNode { @@ -411,6 +419,8 @@ impl ModelPlanNodeBuilder { ); } + let rlac_filter = self.build_rlac_filter(&model)?; + Ok(ModelPlanNode { model, required_exprs: self @@ -422,9 +432,41 @@ impl ModelPlanNodeBuilder { relation_chain: Box::new(relation_chain), schema_ref, original_table_scan, + rlac_filter, }) } + /// Build the combined RLAC filter expression for this model. + /// + /// Each rule whose required session properties are present contributes one `Expr`; all + /// matching rules are AND-combined. Returns `None` when no rule matches (e.g. all rules + /// are optional and the corresponding session properties are unset). + /// + /// Note: subqueries inside the parsed expression may contain `TableScan`s that point at + /// other Wren models — they remain unanalyzed here. [`ModelAnalyzeRule`] is responsible + /// for recursively analyzing those subqueries (and for detecting cycles between RLAC + /// rules that reference each other). + fn build_rlac_filter(&self, model: &Arc) -> Result> { + let mut combined: Option = None; + for rule in model.row_level_access_controls().iter() { + if !validate_rule(&rule.name, &rule.required_properties, &self.properties)? { + continue; + } + let expr = build_filter_expression( + &self.session_state, + Some(Arc::clone(&self.analyzed_wren_mdl)), + Arc::clone(model), + &self.properties, + rule, + )?; + combined = Some(match combined { + Some(acc) => acc.and(expr), + None => expr, + }); + } + Ok(combined) + } + fn add_required_columns_from_session_properties( &self, model: &Model, @@ -878,11 +920,21 @@ impl UserDefinedLogicalNodeCore for ModelPlanNode { } fn expressions(&self) -> Vec { - self.schema_ref + // First emit one column reference per output field so the node's schema width + // matches `expressions()` length (DataFusion uses this for validation). + // Then append the RLAC filter — exposing it here lets `map_subqueries` walk into + // any subqueries inside the filter so they get the same analyzer treatment as + // subqueries in the outer plan. + let mut exprs: Vec = self + .schema_ref .fields() .iter() .map(|field| col(field.name())) - .collect() + .collect(); + if let Some(filter) = &self.rlac_filter { + exprs.push(filter.clone()); + } + exprs } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -896,15 +948,41 @@ impl UserDefinedLogicalNodeCore for ModelPlanNode { fn with_exprs_and_inputs( &self, - _: Vec, + exprs: Vec, _: Vec, ) -> datafusion::common::Result { + // `expressions()` returns one `col(field)` per schema field followed by the + // optional RLAC filter. After DataFusion's traversal callbacks (e.g. + // `map_subqueries`) rewrite individual expressions, we need to put the updated + // RLAC filter back. The column refs are stable — only the trailing filter (if + // present) carries any meaningful change. + // + // Guard against a count mismatch: if DataFusion ever changes its traversal API + // and passes fewer expressions than we emitted, silently dropping `rlac_filter` + // would bypass row-level access control. Surface it as an internal error so the + // regression is caught immediately. + let field_count = self.schema_ref.fields().len(); + let rlac_filter = if self.rlac_filter.is_some() { + if exprs.len() <= field_count { + return internal_err!( + "ModelPlanNode::with_exprs_and_inputs received {} expressions but \ + expected at least {} (field count + 1 for rlac_filter); dropping \ + the filter would silently bypass row-level access control", + exprs.len(), + field_count + 1 + ); + } + exprs.get(field_count).cloned() + } else { + None + }; Ok(ModelPlanNode { model: self.model.clone(), required_exprs: self.required_exprs.clone(), relation_chain: self.relation_chain.clone(), schema_ref: self.schema_ref.clone(), original_table_scan: self.original_table_scan.clone(), + rlac_filter, }) } } diff --git a/core/wren-core/core/src/mdl/mod.rs b/core/wren-core/core/src/mdl/mod.rs index ae3bace2dc..e7697fd46d 100644 --- a/core/wren-core/core/src/mdl/mod.rs +++ b/core/wren-core/core/src/mdl/mod.rs @@ -3433,6 +3433,211 @@ mod test { Ok(()) } + /// RLAC condition references another model in a subquery. The subquery's TableScan + /// should be rewritten into a ModelPlanNode so the inner model's own RLAC/CLAC and + /// remote table reference apply transitively. + #[tokio::test] + async fn test_rlac_with_cross_model_subquery() -> Result<()> { + let ctx = create_wren_ctx(None, None); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer_remote") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "allowed_customers", + vec![SessionProperty::new_required("session_user")], + "c_custkey IN (SELECT allowed_id FROM allowed WHERE allowed_user = @session_user)", + ) + .build(), + ) + .model( + ModelBuilder::new("allowed") + .table_reference("allowed_remote") + .column(ColumnBuilder::new("allowed_id", "int").build()) + .column(ColumnBuilder::new("allowed_user", "string").build()) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let sql = "SELECT c_custkey FROM customer"; + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("'alice'".to_string()), + )])); + // The subquery's `allowed` reference must unparse to the remote table + // `allowed_remote`, proving ModelAnalyzeRule recursed into the RLAC subquery + // and rewrote its TableScan into a ModelPlanNode. + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer_remote AS __source) AS customer) AS customer WHERE customer.c_custkey IN (SELECT allowed.allowed_id FROM (SELECT allowed.allowed_id, allowed.allowed_user FROM (SELECT __source.allowed_id AS allowed_id, __source.allowed_user AS allowed_user FROM allowed_remote AS __source) AS allowed) AS allowed WHERE allowed.allowed_user = 'alice')) AS customer" + ); + Ok(()) + } + + /// When an RLAC subquery selects from a model that also has its own RLAC, that + /// inner model's RLAC must apply too — proving the rewritten ModelPlanNode inside + /// the subquery is processed by the same pipeline as the outer model. + #[tokio::test] + async fn test_rlac_subquery_applies_inner_rlac() -> Result<()> { + let ctx = create_wren_ctx(None, None); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer_remote") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .add_row_level_access_control( + "by_allowed", + vec![SessionProperty::new_required("session_user")], + "c_custkey IN (SELECT allowed_id FROM allowed)", + ) + .build(), + ) + .model( + ModelBuilder::new("allowed") + .table_reference("allowed_remote") + .column(ColumnBuilder::new("allowed_id", "int").build()) + .column(ColumnBuilder::new("allowed_user", "string").build()) + .add_row_level_access_control( + "by_user", + vec![SessionProperty::new_required("session_user")], + "allowed_user = @session_user", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let sql = "SELECT c_custkey FROM customer"; + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("'alice'".to_string()), + )])); + // The inner `allowed` model has its own RLAC filtering by `allowed_user = + // @session_user`; that filter must appear in the unparsed SQL because the + // subquery's TableScan was rewritten into a ModelPlanNode that goes through + // ModelGenerationRule's RLAC-application path. + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer_remote AS __source) AS customer) AS customer WHERE customer.c_custkey IN (SELECT allowed.allowed_id FROM (SELECT allowed.allowed_id, allowed.allowed_user FROM (SELECT allowed.allowed_id, allowed.allowed_user FROM (SELECT __source.allowed_id AS allowed_id, __source.allowed_user AS allowed_user FROM allowed_remote AS __source) AS allowed) AS allowed WHERE allowed.allowed_user = 'alice') AS allowed)) AS customer" + ); + Ok(()) + } + + /// Two models reference each other in their RLAC conditions. The analyzer must + /// detect the cycle and surface a planning error rather than recurse infinitely. + #[tokio::test] + async fn test_rlac_subquery_cycle_detected() -> Result<()> { + let ctx = create_wren_ctx(None, None); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("a") + .table_reference("a_remote") + .column(ColumnBuilder::new("id", "int").build()) + .add_row_level_access_control( + "via_b", + vec![SessionProperty::new_required("session_x")], + "id IN (SELECT id FROM b)", + ) + .build(), + ) + .model( + ModelBuilder::new("b") + .table_reference("b_remote") + .column(ColumnBuilder::new("id", "int").build()) + .add_row_level_access_control( + "via_a", + vec![SessionProperty::new_required("session_x")], + "id IN (SELECT id FROM a)", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let sql = "SELECT id FROM a"; + let headers = Arc::new(build_headers(&[( + "session_x".to_string(), + Some("1".to_string()), + )])); + let result = + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await; + match result { + Err(e) => { + let msg = e.to_string(); + assert!( + msg.contains("cycle in row level access control"), + "expected cycle detection error, got: {msg}" + ); + } + Ok(plan) => panic!("expected cycle error, got plan: {plan}"), + } + Ok(()) + } + + /// A self-referential RLAC condition (subquery selecting from the same model) + /// must be rejected as a cycle rather than looping infinitely. + #[tokio::test] + async fn test_rlac_self_reference_is_cycle() -> Result<()> { + let ctx = create_wren_ctx(None, None); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer_remote") + .column(ColumnBuilder::new("id", "int").build()) + .add_row_level_access_control( + "self_ref", + vec![SessionProperty::new_required("session_x")], + "id IN (SELECT id FROM customer)", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let sql = "SELECT id FROM customer"; + let headers = Arc::new(build_headers(&[( + "session_x".to_string(), + Some("1".to_string()), + )])); + match transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await + { + Err(e) => { + let msg = e.to_string(); + assert!( + msg.contains("cycle in row level access control"), + "expected cycle detection error, got: {msg}" + ); + } + Ok(plan) => panic!("expected cycle error, got plan: {plan}"), + } + Ok(()) + } + #[tokio::test] async fn test_disable_eliminate_limit() -> Result<()> { let ctx = create_wren_ctx(None, None);