diff --git a/core/wren-core/core/src/logical_plan/analyze/access_control.rs b/core/wren-core/core/src/logical_plan/analyze/access_control.rs index f936bc9cb0..5d795c71ec 100644 --- a/core/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/core/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -708,6 +708,39 @@ mod test { unparser.expr_to_sql(expr).map(|sql| sql.to_string()) } + #[test] + pub fn test_build_filter_expression_with_bypass_function() -> Result<()> { + use crate::mdl::function::ByPassScalarUDF; + use datafusion::logical_expr::ScalarUDF; + + // A function unknown to wren-core, registered as an inferred bypass UDF + // (as the manifest scan does), can be used inside an RLAC condition. + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new_inferred( + "mask", + ))); + let state = ctx.state_ref(); + let model = ModelBuilder::new("m1") + .column(ColumnBuilder::new("id", "int").build()) + .column(ColumnBuilder::new("name", "varchar").build()) + .build(); + + let headers = Arc::new(build_headers(&[( + "session_name".to_string(), + Some("'test'".to_string()), + )])); + + let rule = RowLevelAccessControl { + condition: "mask(name) = @session_name".to_string(), + required_properties: vec![SessionProperty::new_required("session_name")], + name: "test".to_string(), + }; + + let expr = build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + assert_snapshot!(expr_to_sql(&expr)?, @"mask(m1.\"name\") = 'test'"); + Ok(()) + } + #[test] pub fn test_match_case_insensitive() -> Result<()> { let ctx = SessionContext::new(); diff --git a/core/wren-core/core/src/mdl/context.rs b/core/wren-core/core/src/mdl/context.rs index 0dbc67100d..9db9a811c8 100644 --- a/core/wren-core/core/src/mdl/context.rs +++ b/core/wren-core/core/src/mdl/context.rs @@ -75,12 +75,24 @@ pub async fn apply_wren_on_ctx( } let type_planner = Arc::new(WrenTypePlanner::default()); - let reset_default_catalog_schema = Arc::new(RwLock::new( - SessionStateBuilder::new_from_existing(ctx.state()) - .with_config(config.clone()) - .with_type_planner(type_planner) - .build(), - )); + let mut base_state = SessionStateBuilder::new_from_existing(ctx.state()) + .with_config(config.clone()) + .with_type_planner(type_planner) + .build(); + // Auto-register any function referenced in the manifest's expressions that is + // still unknown (after built-ins, dialect, and explicit remote functions) as + // an inferred bypass UDF. Registered into this locally-derived state — not the + // shared input `ctx` — so the registration is captured by the state snapshot + // the `ModelAnalyzeRule` resolves functions against without leaking into a + // context that may be reused across manifests. `apply_wren_on_ctx` is the + // choke point shared by every planning path (transform, Python session load, + // dry-run, permission analysis), so MDL authors can reference data-source- + // native functions without declaring each one. + crate::mdl::register_inferred_bypass_for_manifest( + &mut base_state, + &analyzed_mdl.wren_mdl.manifest, + )?; + let reset_default_catalog_schema = Arc::new(RwLock::new(base_state)); let new_state = SessionStateBuilder::new_from_existing( reset_default_catalog_schema.clone().read().deref().clone(), diff --git a/core/wren-core/core/src/mdl/function/remote_function.rs b/core/wren-core/core/src/mdl/function/remote_function.rs index 9f5bfa68c6..42c3903a11 100644 --- a/core/wren-core/core/src/mdl/function/remote_function.rs +++ b/core/wren-core/core/src/mdl/function/remote_function.rs @@ -4,6 +4,7 @@ use datafusion::common::{internal_err, not_impl_err}; use datafusion::logical_expr::function::{ AccumulatorArgs, PartitionEvaluatorArgs, WindowUDFFieldArgs, }; +use datafusion::logical_expr::type_coercion::binary::type_union_resolution; use datafusion::logical_expr::{ Accumulator, AggregateUDFImpl, ColumnarValue, DocSection, Documentation, DocumentationBuilder, PartitionEvaluator, ScalarFunctionArgs, ScalarUDFImpl, @@ -112,6 +113,12 @@ pub enum ReturnType { SameAsInputFirstArrayElement, /// The return type is the array of the first argument type ArrayOfInputFirstArgument, + /// The return type is inferred from the actual argument types at planning time. + /// Used for functions that are auto-registered as bypass UDFs (e.g. unknown + /// functions found in MDL expressions) where no return type is declared. + /// Inference rule: no args -> Utf8; one arg -> that arg's type; multiple args -> + /// the common super-type via type coercion, falling back to the first arg, then Utf8. + Inferred, } impl Display for ReturnType { @@ -125,6 +132,7 @@ impl Display for ReturnType { ReturnType::ArrayOfInputFirstArgument => { write!(f, "array_of_input_first_argument") } + ReturnType::Inferred => write!(f, "inferred"), } } } @@ -138,6 +146,7 @@ impl FromStr for ReturnType { Ok(ReturnType::SameAsInputFirstArrayElement) } "array_of_input_first_argument" => Ok(ReturnType::ArrayOfInputFirstArgument), + "inferred" => Ok(ReturnType::Inferred), _ => map_data_type(s) .map(ReturnType::Specific) .map_err(|e| e.to_string()), @@ -145,6 +154,15 @@ impl FromStr for ReturnType { } } +/// Signature accepting either no arguments or any variadic arguments, used by +/// bypass UDFs that are never executed and only need to pass planning. +fn bypass_any_signature() -> Signature { + Signature::one_of( + vec![TypeSignature::Nullary, TypeSignature::VariadicAny], + Volatility::Volatile, + ) +} + impl ReturnType { pub fn to_data_type(&self, arg_types: &[DataType]) -> Result { Ok(match self { @@ -168,6 +186,13 @@ impl ReturnType { } DataType::List(Arc::new(Field::new("item", arg_types[0].clone(), true))) } + ReturnType::Inferred => match arg_types { + [] => DataType::Utf8, + [single] => single.clone(), + many => type_union_resolution(many) + .or_else(|| many.first().cloned()) + .unwrap_or(DataType::Utf8), + }, }) } } @@ -249,6 +274,29 @@ impl ByPassScalarUDF { pub fn original_name(&self) -> Option<&str> { self.original_name.as_deref() } + + /// A bypass scalar UDF whose return type is inferred from the argument types + /// at planning time. Accepts any arguments and is never executed; it only + /// allows an otherwise-unknown function to pass planning and be unparsed back. + /// + /// The original (case-sensitive) name is kept for SQL generation while a + /// lowercase alias is added for DataFusion's name resolution during parsing. + pub fn new_inferred(name: &str) -> Self { + let normalized = name.to_lowercase(); + let aliases = if name != normalized { + vec![normalized] + } else { + vec![] + }; + Self { + name: name.to_string(), + original_name: Some(name.to_string()), + aliases, + return_type: ReturnType::Inferred, + signature: bypass_any_signature(), + doc: None, + } + } } impl From for ByPassScalarUDF { @@ -423,6 +471,9 @@ impl AggregateUDFImpl for ByPassAggregateUDF { #[derive(Debug, PartialEq, Eq, Hash)] pub struct ByPassWindowFunction { name: String, + /// Aliases for the function, including the lowercase name used for parsing + /// when it differs from the original (case-sensitive) name. + aliases: Vec, return_type: ReturnType, signature: Signature, doc: Option, @@ -437,6 +488,7 @@ impl ByPassWindowFunction { ) -> Self { Self { name: name.to_string(), + aliases: vec![], return_type, signature, doc, @@ -446,6 +498,7 @@ impl ByPassWindowFunction { pub fn new_with_return_type(name: &str, return_type: DataType) -> Self { Self { name: name.to_string(), + aliases: vec![], return_type: ReturnType::Specific(return_type), signature: Signature::one_of( vec![TypeSignature::VariadicAny, TypeSignature::Nullary], @@ -454,6 +507,27 @@ impl ByPassWindowFunction { doc: None, } } + + /// A bypass window UDF whose return type is inferred from the argument types + /// at planning time. See [`ByPassScalarUDF::new_inferred`]. + /// + /// The original (case-sensitive) name is kept for SQL generation while a + /// lowercase alias is added for DataFusion's name resolution during parsing. + pub fn new_inferred(name: &str) -> Self { + let normalized = name.to_lowercase(); + let aliases = if name != normalized { + vec![normalized] + } else { + vec![] + }; + Self { + name: name.to_string(), + aliases, + return_type: ReturnType::Inferred, + signature: bypass_any_signature(), + doc: None, + } + } } impl From for ByPassWindowFunction { @@ -465,6 +539,7 @@ impl From for ByPassWindowFunction { signature: func.get_signature(), doc: Some(build_document(&func)), name: func.name, + aliases: vec![], } } } @@ -474,6 +549,10 @@ impl WindowUDFImpl for ByPassWindowFunction { self } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn name(&self) -> &str { &self.name } @@ -510,6 +589,7 @@ mod test { use std::slice::from_ref; use std::sync::Arc; + use crate::mdl::function::remote_function::ReturnType; use crate::mdl::function::{ByPassScalarUDF, FunctionType, RemoteFunction}; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::types::logical_string; @@ -518,6 +598,53 @@ mod test { use datafusion::logical_expr::TypeSignatureClass; use datafusion::logical_expr::{Coercion, TypeSignature}; + #[test] + fn test_inferred_return_type() { + let inferred = ReturnType::Inferred; + // No arguments -> Utf8 fallback. + assert_eq!(inferred.to_data_type(&[]).unwrap(), DataType::Utf8); + // One argument -> that argument's type. + assert_eq!( + inferred.to_data_type(&[DataType::Int32]).unwrap(), + DataType::Int32 + ); + // Multiple compatible arguments -> common super-type. + assert_eq!( + inferred + .to_data_type(&[DataType::Int32, DataType::Int64]) + .unwrap(), + DataType::Int64 + ); + // A bypass UDF built with new_inferred carries the inferred return type. + let udf = ByPassScalarUDF::new_inferred("my_fn"); + assert_eq!(udf.name(), "my_fn"); + assert_eq!( + udf.return_type(&[DataType::Float64]).unwrap(), + DataType::Float64 + ); + } + + #[test] + fn test_inferred_bypass_preserves_casing() { + use crate::mdl::function::ByPassWindowFunction; + use datafusion::logical_expr::WindowUDFImpl; + + // Mixed-case names keep their original spelling for SQL generation and add + // a lowercase alias for DataFusion's name resolution during parsing. + let scalar = ByPassScalarUDF::new_inferred("toYear"); + assert_eq!(scalar.name(), "toYear"); + assert_eq!(scalar.aliases(), ["toyear".to_string()]); + + let window = ByPassWindowFunction::new_inferred("MyWinFn"); + assert_eq!(window.name(), "MyWinFn"); + assert_eq!(window.aliases(), ["mywinfn".to_string()]); + + // Already-lowercase names need no alias. + let lower = ByPassWindowFunction::new_inferred("rank_fn"); + assert_eq!(lower.name(), "rank_fn"); + assert!(lower.aliases().is_empty()); + } + #[tokio::test] async fn test_remote_function_to_bypass_func() -> Result<()> { // full information diff --git a/core/wren-core/core/src/mdl/mod.rs b/core/wren-core/core/src/mdl/mod.rs index 157f45189e..d9cadc642d 100644 --- a/core/wren-core/core/src/mdl/mod.rs +++ b/core/wren-core/core/src/mdl/mod.rs @@ -492,6 +492,9 @@ pub async fn transform_sql_with_ctx( register_remote_function(ctx, remote_function)?; Ok::<_, DataFusionError>(()) })?; + // Note: unknown functions referenced in the manifest's expressions are + // auto-registered as inferred bypass UDFs inside `apply_wren_on_ctx`, which is + // the choke point shared by every planning path (transform, load, dry-run). let ctx = apply_wren_on_ctx( ctx, Arc::clone(&analyzed_mdl), @@ -630,6 +633,129 @@ fn register_remote_function( Ok(()) } +/// Scan all free-form SQL expressions in the manifest — calculated column +/// expressions, row-level access control conditions, cube +/// measure/dimension/time-dimension expressions, and relationship conditions — +/// for function calls that are not registered in the context, and register each +/// unknown function as a bypass UDF whose return type is inferred from the +/// argument types at planning time. +/// +/// This lets MDL authors reference data-source-native functions that wren-core +/// does not know about (and pass them through to the unparser) without declaring +/// every one in the remote-functions list. Functions already registered — built +/// ins, dialect functions, and explicitly declared remote functions — are left +/// untouched, so explicit declarations always win. +/// +/// Aggregate functions cannot be reliably distinguished from scalar functions by +/// syntax alone, so unknown functions are registered as scalar UDFs unless they +/// carry an `OVER (...)` clause (registered as window UDFs). Unknown aggregates +/// still need an explicit remote-function declaration. +/// +/// Registers into the given [`SessionState`] (rather than a shared +/// [`SessionContext`]) so the caller can seed a derived state without leaking +/// inferred functions back into a context that may be reused across manifests. +pub(crate) fn register_inferred_bypass_for_manifest( + state: &mut SessionState, + manifest: &Manifest, +) -> Result<()> { + use datafusion::execution::FunctionRegistry; + use datafusion::sql::parser::DFParserBuilder; + use datafusion::sql::sqlparser::ast::{self, visit_expressions}; + use datafusion::sql::sqlparser::dialect::GenericDialect; + use std::collections::HashSet; + use std::ops::ControlFlow; + + // Collect every free-form expression string in the manifest. + let mut expressions: Vec<&str> = vec![]; + for model in &manifest.models { + for column in &model.columns { + if let Some(expr) = column.expression.as_deref() { + if !expr.is_empty() { + expressions.push(expr); + } + } + } + for rule in &model.row_level_access_controls { + expressions.push(rule.condition.as_str()); + } + } + for cube in &manifest.cubes { + for measure in &cube.measures { + expressions.push(measure.expression.as_str()); + } + for dimension in &cube.dimensions { + expressions.push(dimension.expression.as_str()); + } + for time_dimension in &cube.time_dimensions { + expressions.push(time_dimension.expression.as_str()); + } + } + for relationship in &manifest.relationships { + expressions.push(relationship.condition.as_str()); + } + + // Names already registered as scalar, aggregate, or window functions + // (built-ins, dialect functions, explicit remote functions). Registry keys are + // lowercase, matching DataFusion's name resolution. Snapshot them up front so + // the read does not borrow `state` while we later register into it. + let mut preexisting: HashSet = HashSet::new(); + preexisting.extend(state.scalar_functions().keys().cloned()); + preexisting.extend(state.aggregate_functions().keys().cloned()); + preexisting.extend(state.window_functions().keys().cloned()); + + // Scalar and window auto-registrations are tracked separately so discovering + // `foo(...)` as a scalar does not mask a later `foo(...) OVER (...)` window + // occurrence (and vice versa). Original casing is kept for SQL generation. + let dialect = GenericDialect {}; + let mut scalar_to_add: Vec = vec![]; + let mut window_to_add: Vec = vec![]; + let mut seen_scalar: HashSet = HashSet::new(); + let mut seen_window: HashSet = HashSet::new(); + for expr_sql in expressions { + // Skip expressions that don't parse; the analyzer reports those with a + // clearer error than we could here. + let Ok(mut parser) = DFParserBuilder::new(expr_sql) + .with_dialect(&dialect) + .build() + else { + continue; + }; + let Ok(expr) = parser.parse_expr() else { + continue; + }; + let _ = visit_expressions(&expr, |node| { + if let ast::Expr::Function(func) = node { + if let Some(ident) = func.name.0.last().and_then(|p| p.as_ident()) { + let original = ident.value.as_str(); + let normalized = original.to_lowercase(); + if !preexisting.contains(&normalized) { + if func.over.is_some() { + if seen_window.insert(normalized.clone()) { + window_to_add.push(original.to_string()); + } + } else if seen_scalar.insert(normalized.clone()) { + scalar_to_add.push(original.to_string()); + } + } + } + } + ControlFlow::<()>::Continue(()) + }); + } + + for name in scalar_to_add { + state.register_udf(Arc::new(ScalarUDF::new_from_impl( + ByPassScalarUDF::new_inferred(&name), + )))?; + } + for name in window_to_add { + state.register_udwf(Arc::new(WindowUDF::new_from_impl( + ByPassWindowFunction::new_inferred(&name), + )))?; + } + Ok(()) +} + /// Analyze the decision point. It's same as the /v1/analysis/sql API in wren engine pub fn decision_point_analyze(_wren_mdl: Arc, _sql: &str) {} @@ -922,6 +1048,88 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_inferred_bypass_unknown_function_in_expression() -> Result<()> { + // A column expression references `custom_mask`, a function wren-core does + // not know about and that is NOT declared as a remote function. It should + // be auto-registered as an inferred bypass UDF and passed through to the + // unparsed SQL with its original casing preserved. + let manifest = ManifestBuilder::new() + .catalog("CTest") + .schema("STest") + .model( + ModelBuilder::new("Customer") + .table_reference("datafusion.public.customer") + .column(ColumnBuilder::new("Custkey", "int").build()) + .column(ColumnBuilder::new("Name", "string").build()) + .column( + ColumnBuilder::new("Masked", "string") + .expression(r#"custom_mask("Name")"#) + .build(), + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let ctx = create_wren_ctx(None, analyzed_mdl.wren_mdl().data_source().as_ref()); + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + Arc::new(HashMap::new()), + r#"select "Masked" from "CTest"."STest"."Customer""#, + ) + .await?; + assert!( + actual.contains("custom_mask"), + "expected unknown function to be passed through, got: {actual}" + ); + Ok(()) + } + + #[test] + fn test_inferred_bypass_registers_scalar_and_window_independently() -> Result<()> { + // The same name appears as a plain call and as a window call. Both the + // scalar and the window bypass UDF must be registered — neither should + // mask the other. + let manifest = ManifestBuilder::new() + .catalog("CTest") + .schema("STest") + .model( + ModelBuilder::new("Customer") + .table_reference("datafusion.public.customer") + .column(ColumnBuilder::new("Custkey", "int").build()) + .column(ColumnBuilder::new("Name", "string").build()) + .column( + ColumnBuilder::new("Scalar", "string") + .expression(r#"foo("Name")"#) + .build(), + ) + .column( + ColumnBuilder::new("Windowed", "string") + .expression(r#"foo("Custkey") OVER (PARTITION BY "Name")"#) + .build(), + ) + .build(), + ) + .build(); + let mut state = create_wren_ctx(None, None).state(); + super::register_inferred_bypass_for_manifest(&mut state, &manifest)?; + assert!( + state.scalar_functions().contains_key("foo"), + "scalar bypass should be registered" + ); + assert!( + state.window_functions().contains_key("foo"), + "window bypass should be registered" + ); + Ok(()) + } + #[tokio::test] async fn test_unicode_remote_column_name() -> Result<()> { let ctx = create_wren_ctx(None, None);