Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions core/wren-core/core/src/logical_plan/analyze/access_control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,39 @@ mod test {
unparser.expr_to_sql(expr).map(|sql| sql.to_string())
}

#[test]
pub fn test_build_filter_expression_with_bypass_function() -> Result<()> {
use crate::mdl::function::ByPassScalarUDF;
use datafusion::logical_expr::ScalarUDF;

// A function unknown to wren-core, registered as an inferred bypass UDF
// (as the manifest scan does), can be used inside an RLAC condition.
let ctx = SessionContext::new();
ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new_inferred(
"mask",
)));
let state = ctx.state_ref();
let model = ModelBuilder::new("m1")
.column(ColumnBuilder::new("id", "int").build())
.column(ColumnBuilder::new("name", "varchar").build())
.build();

let headers = Arc::new(build_headers(&[(
"session_name".to_string(),
Some("'test'".to_string()),
)]));

let rule = RowLevelAccessControl {
condition: "mask(name) = @session_name".to_string(),
required_properties: vec![SessionProperty::new_required("session_name")],
name: "test".to_string(),
};

let expr = build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?;
assert_snapshot!(expr_to_sql(&expr)?, @"mask(m1.\"name\") = 'test'");
Ok(())
}

#[test]
pub fn test_match_case_insensitive() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
24 changes: 18 additions & 6 deletions core/wren-core/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,24 @@ pub async fn apply_wren_on_ctx(
}

let type_planner = Arc::new(WrenTypePlanner::default());
let reset_default_catalog_schema = Arc::new(RwLock::new(
SessionStateBuilder::new_from_existing(ctx.state())
.with_config(config.clone())
.with_type_planner(type_planner)
.build(),
));
let mut base_state = SessionStateBuilder::new_from_existing(ctx.state())
.with_config(config.clone())
.with_type_planner(type_planner)
.build();
// Auto-register any function referenced in the manifest's expressions that is
// still unknown (after built-ins, dialect, and explicit remote functions) as
// an inferred bypass UDF. Registered into this locally-derived state — not the
// shared input `ctx` — so the registration is captured by the state snapshot
// the `ModelAnalyzeRule` resolves functions against without leaking into a
// context that may be reused across manifests. `apply_wren_on_ctx` is the
// choke point shared by every planning path (transform, Python session load,
// dry-run, permission analysis), so MDL authors can reference data-source-
// native functions without declaring each one.
crate::mdl::register_inferred_bypass_for_manifest(
&mut base_state,
&analyzed_mdl.wren_mdl.manifest,
)?;
let reset_default_catalog_schema = Arc::new(RwLock::new(base_state));

let new_state = SessionStateBuilder::new_from_existing(
reset_default_catalog_schema.clone().read().deref().clone(),
Expand Down
127 changes: 127 additions & 0 deletions core/wren-core/core/src/mdl/function/remote_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use datafusion::common::{internal_err, not_impl_err};
use datafusion::logical_expr::function::{
AccumulatorArgs, PartitionEvaluatorArgs, WindowUDFFieldArgs,
};
use datafusion::logical_expr::type_coercion::binary::type_union_resolution;
use datafusion::logical_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, DocSection, Documentation,
DocumentationBuilder, PartitionEvaluator, ScalarFunctionArgs, ScalarUDFImpl,
Expand Down Expand Up @@ -112,6 +113,12 @@ pub enum ReturnType {
SameAsInputFirstArrayElement,
/// The return type is the array of the first argument type
ArrayOfInputFirstArgument,
/// The return type is inferred from the actual argument types at planning time.
/// Used for functions that are auto-registered as bypass UDFs (e.g. unknown
/// functions found in MDL expressions) where no return type is declared.
/// Inference rule: no args -> Utf8; one arg -> that arg's type; multiple args ->
/// the common super-type via type coercion, falling back to the first arg, then Utf8.
Inferred,
}

impl Display for ReturnType {
Expand All @@ -125,6 +132,7 @@ impl Display for ReturnType {
ReturnType::ArrayOfInputFirstArgument => {
write!(f, "array_of_input_first_argument")
}
ReturnType::Inferred => write!(f, "inferred"),
}
}
}
Expand All @@ -138,13 +146,23 @@ impl FromStr for ReturnType {
Ok(ReturnType::SameAsInputFirstArrayElement)
}
"array_of_input_first_argument" => Ok(ReturnType::ArrayOfInputFirstArgument),
"inferred" => Ok(ReturnType::Inferred),
_ => map_data_type(s)
.map(ReturnType::Specific)
.map_err(|e| e.to_string()),
}
}
}

/// Signature accepting either no arguments or any variadic arguments, used by
/// bypass UDFs that are never executed and only need to pass planning.
fn bypass_any_signature() -> Signature {
Signature::one_of(
vec![TypeSignature::Nullary, TypeSignature::VariadicAny],
Volatility::Volatile,
)
}

impl ReturnType {
pub fn to_data_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(match self {
Expand All @@ -168,6 +186,13 @@ impl ReturnType {
}
DataType::List(Arc::new(Field::new("item", arg_types[0].clone(), true)))
}
ReturnType::Inferred => match arg_types {
[] => DataType::Utf8,
[single] => single.clone(),
many => type_union_resolution(many)
.or_else(|| many.first().cloned())
.unwrap_or(DataType::Utf8),
},
})
}
}
Expand Down Expand Up @@ -249,6 +274,29 @@ impl ByPassScalarUDF {
pub fn original_name(&self) -> Option<&str> {
self.original_name.as_deref()
}

/// A bypass scalar UDF whose return type is inferred from the argument types
/// at planning time. Accepts any arguments and is never executed; it only
/// allows an otherwise-unknown function to pass planning and be unparsed back.
///
/// The original (case-sensitive) name is kept for SQL generation while a
/// lowercase alias is added for DataFusion's name resolution during parsing.
pub fn new_inferred(name: &str) -> Self {
let normalized = name.to_lowercase();
let aliases = if name != normalized {
vec![normalized]
} else {
vec![]
};
Self {
name: name.to_string(),
original_name: Some(name.to_string()),
aliases,
return_type: ReturnType::Inferred,
signature: bypass_any_signature(),
doc: None,
}
}
}

impl From<RemoteFunction> for ByPassScalarUDF {
Expand Down Expand Up @@ -423,6 +471,9 @@ impl AggregateUDFImpl for ByPassAggregateUDF {
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ByPassWindowFunction {
name: String,
/// Aliases for the function, including the lowercase name used for parsing
/// when it differs from the original (case-sensitive) name.
aliases: Vec<String>,
return_type: ReturnType,
signature: Signature,
doc: Option<Documentation>,
Expand All @@ -437,6 +488,7 @@ impl ByPassWindowFunction {
) -> Self {
Self {
name: name.to_string(),
aliases: vec![],
return_type,
signature,
doc,
Expand All @@ -446,6 +498,7 @@ impl ByPassWindowFunction {
pub fn new_with_return_type(name: &str, return_type: DataType) -> Self {
Self {
name: name.to_string(),
aliases: vec![],
return_type: ReturnType::Specific(return_type),
signature: Signature::one_of(
vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
Expand All @@ -454,6 +507,27 @@ impl ByPassWindowFunction {
doc: None,
}
}

/// A bypass window UDF whose return type is inferred from the argument types
/// at planning time. See [`ByPassScalarUDF::new_inferred`].
///
/// The original (case-sensitive) name is kept for SQL generation while a
/// lowercase alias is added for DataFusion's name resolution during parsing.
pub fn new_inferred(name: &str) -> Self {
let normalized = name.to_lowercase();
let aliases = if name != normalized {
vec![normalized]
} else {
vec![]
};
Self {
name: name.to_string(),
aliases,
return_type: ReturnType::Inferred,
signature: bypass_any_signature(),
doc: None,
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

impl From<RemoteFunction> for ByPassWindowFunction {
Expand All @@ -465,6 +539,7 @@ impl From<RemoteFunction> for ByPassWindowFunction {
signature: func.get_signature(),
doc: Some(build_document(&func)),
name: func.name,
aliases: vec![],
}
}
}
Expand All @@ -474,6 +549,10 @@ impl WindowUDFImpl for ByPassWindowFunction {
self
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn name(&self) -> &str {
&self.name
}
Expand Down Expand Up @@ -510,6 +589,7 @@ mod test {
use std::slice::from_ref;
use std::sync::Arc;

use crate::mdl::function::remote_function::ReturnType;
use crate::mdl::function::{ByPassScalarUDF, FunctionType, RemoteFunction};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::types::logical_string;
Expand All @@ -518,6 +598,53 @@ mod test {
use datafusion::logical_expr::TypeSignatureClass;
use datafusion::logical_expr::{Coercion, TypeSignature};

#[test]
fn test_inferred_return_type() {
let inferred = ReturnType::Inferred;
// No arguments -> Utf8 fallback.
assert_eq!(inferred.to_data_type(&[]).unwrap(), DataType::Utf8);
// One argument -> that argument's type.
assert_eq!(
inferred.to_data_type(&[DataType::Int32]).unwrap(),
DataType::Int32
);
// Multiple compatible arguments -> common super-type.
assert_eq!(
inferred
.to_data_type(&[DataType::Int32, DataType::Int64])
.unwrap(),
DataType::Int64
);
// A bypass UDF built with new_inferred carries the inferred return type.
let udf = ByPassScalarUDF::new_inferred("my_fn");
assert_eq!(udf.name(), "my_fn");
assert_eq!(
udf.return_type(&[DataType::Float64]).unwrap(),
DataType::Float64
);
}

#[test]
fn test_inferred_bypass_preserves_casing() {
use crate::mdl::function::ByPassWindowFunction;
use datafusion::logical_expr::WindowUDFImpl;

// Mixed-case names keep their original spelling for SQL generation and add
// a lowercase alias for DataFusion's name resolution during parsing.
let scalar = ByPassScalarUDF::new_inferred("toYear");
assert_eq!(scalar.name(), "toYear");
assert_eq!(scalar.aliases(), ["toyear".to_string()]);

let window = ByPassWindowFunction::new_inferred("MyWinFn");
assert_eq!(window.name(), "MyWinFn");
assert_eq!(window.aliases(), ["mywinfn".to_string()]);

// Already-lowercase names need no alias.
let lower = ByPassWindowFunction::new_inferred("rank_fn");
assert_eq!(lower.name(), "rank_fn");
assert!(lower.aliases().is_empty());
}

#[tokio::test]
async fn test_remote_function_to_bypass_func() -> Result<()> {
// full information
Expand Down
Loading
Loading