diff --git a/crates/polyglot-sql/src/dialects/mod.rs b/crates/polyglot-sql/src/dialects/mod.rs index c28320a..7070c70 100644 --- a/crates/polyglot-sql/src/dialects/mod.rs +++ b/crates/polyglot-sql/src/dialects/mod.rs @@ -543,6 +543,15 @@ where { use crate::expressions::BinaryOp; + // Helper macro to recurse into a single-argument AggFunc-based expression. + macro_rules! recurse_agg { + ($variant:ident, $f:expr) => {{ + let mut f = $f; + f.this = transform_recursive(f.this, transform_fn)?; + Expression::$variant(f) + }}; + } + // Helper macro to transform binary ops with Box macro_rules! transform_binary { ($variant:ident, $op:expr) => {{ @@ -954,14 +963,9 @@ where f.this = transform_recursive(f.this, transform_fn)?; Expression::Date(f) } - Expression::Stddev(mut f) => { - f.this = transform_recursive(f.this, transform_fn)?; - Expression::Stddev(f) - } - Expression::Variance(mut f) => { - f.this = transform_recursive(f.this, transform_fn)?; - Expression::Variance(f) - } + Expression::Stddev(f) => recurse_agg!(Stddev, f), + Expression::StddevSamp(f) => recurse_agg!(StddevSamp, f), + Expression::Variance(f) => recurse_agg!(Variance, f), // ===== BinaryFunc-based expressions ===== Expression::ModFunc(mut f) => { @@ -1528,19 +1532,41 @@ where Expression::Filter(f) } - // BitwiseOrAgg/BitwiseAndAgg/BitwiseXorAgg: recurse into the aggregate argument - Expression::BitwiseOrAgg(mut f) => { - f.this = transform_recursive(f.this, transform_fn)?; - Expression::BitwiseOrAgg(f) - } - Expression::BitwiseAndAgg(mut f) => { - f.this = transform_recursive(f.this, transform_fn)?; - Expression::BitwiseAndAgg(f) - } - Expression::BitwiseXorAgg(mut f) => { - f.this = transform_recursive(f.this, transform_fn)?; - Expression::BitwiseXorAgg(f) + // Aggregate functions (AggFunc-based): recurse into the aggregate argument. + // Note: Stddev, StddevSamp, Variance, and ArrayAgg are already handled above. + Expression::Sum(f) => recurse_agg!(Sum, f), + Expression::Avg(f) => recurse_agg!(Avg, f), + Expression::Min(f) => recurse_agg!(Min, f), + Expression::Max(f) => recurse_agg!(Max, f), + Expression::CountIf(f) => recurse_agg!(CountIf, f), + Expression::StddevPop(f) => recurse_agg!(StddevPop, f), + Expression::VarPop(f) => recurse_agg!(VarPop, f), + Expression::VarSamp(f) => recurse_agg!(VarSamp, f), + Expression::Median(f) => recurse_agg!(Median, f), + Expression::Mode(f) => recurse_agg!(Mode, f), + Expression::First(f) => recurse_agg!(First, f), + Expression::Last(f) => recurse_agg!(Last, f), + Expression::AnyValue(f) => recurse_agg!(AnyValue, f), + Expression::ApproxDistinct(f) => recurse_agg!(ApproxDistinct, f), + Expression::ApproxCountDistinct(f) => recurse_agg!(ApproxCountDistinct, f), + Expression::LogicalAnd(f) => recurse_agg!(LogicalAnd, f), + Expression::LogicalOr(f) => recurse_agg!(LogicalOr, f), + Expression::Skewness(f) => recurse_agg!(Skewness, f), + Expression::ArrayConcatAgg(f) => recurse_agg!(ArrayConcatAgg, f), + Expression::ArrayUniqueAgg(f) => recurse_agg!(ArrayUniqueAgg, f), + Expression::BoolXorAgg(f) => recurse_agg!(BoolXorAgg, f), + Expression::BitwiseOrAgg(f) => recurse_agg!(BitwiseOrAgg, f), + Expression::BitwiseAndAgg(f) => recurse_agg!(BitwiseAndAgg, f), + Expression::BitwiseXorAgg(f) => recurse_agg!(BitwiseXorAgg, f), + + // Count has its own struct with an Option `this` field + Expression::Count(mut c) => { + if let Some(this) = c.this.take() { + c.this = Some(transform_recursive(this, transform_fn)?); + } + Expression::Count(c) } + Expression::PipeOperator(mut pipe) => { pipe.this = transform_recursive(pipe.this, transform_fn)?; pipe.expression = transform_recursive(pipe.expression, transform_fn)?; diff --git a/crates/polyglot-sql/src/optimizer/qualify_columns.rs b/crates/polyglot-sql/src/optimizer/qualify_columns.rs index cf1f3fb..8e547bb 100644 --- a/crates/polyglot-sql/src/optimizer/qualify_columns.rs +++ b/crates/polyglot-sql/src/optimizer/qualify_columns.rs @@ -3144,6 +3144,50 @@ mod tests { sql.contains("t1.a"), "column should be qualified with table name: {sql}" ); + + // test that columns in agg functions also get qualified + let expr = parse("SELECT MAX(a) FROM raw.t1"); + let result = + qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify"); + let sql = gen(&result); + assert!( + sql.contains("t1.a"), + "column in function should be qualified with table name: {sql}" + ); + + // test that columns in scalar functions also get qualified + let expr = parse("SELECT ABS(a) FROM raw.t1"); + let result = + qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify"); + let sql = gen(&result); + assert!( + sql.contains("t1.a"), + "column in function should be qualified with table name: {sql}" + ); + } + + #[test] + fn test_qualify_columns_count_star() { + // COUNT(*) uses Count { this: None } — verify qualify_columns handles it without panic + let expr = parse("SELECT COUNT(*) FROM t1"); + + let mut schema = MappingSchema::new(); + schema + .add_table( + "t1", + &[("id".to_string(), DataType::BigInt { length: None })], + None, + ) + .expect("schema setup"); + + let result = + qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify"); + let sql = gen(&result); + + assert!( + sql.contains("COUNT(*)"), + "COUNT(*) should be preserved: {sql}" + ); } #[test]