Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 46 additions & 20 deletions crates/polyglot-sql/src/dialects/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,15 @@ where
{
use crate::expressions::BinaryOp;

// Helper macro to recurse into a single-argument AggFunc-based expression.
macro_rules! recurse_agg {
($variant:ident, $f:expr) => {{
let mut f = $f;
f.this = transform_recursive(f.this, transform_fn)?;
Expression::$variant(f)
}};
}

// Helper macro to transform binary ops with Box<BinaryOp>
macro_rules! transform_binary {
($variant:ident, $op:expr) => {{
Expand Down Expand Up @@ -954,14 +963,9 @@ where
f.this = transform_recursive(f.this, transform_fn)?;
Expression::Date(f)
}
Expression::Stddev(mut f) => {
f.this = transform_recursive(f.this, transform_fn)?;
Expression::Stddev(f)
}
Expression::Variance(mut f) => {
f.this = transform_recursive(f.this, transform_fn)?;
Expression::Variance(f)
}
Expression::Stddev(f) => recurse_agg!(Stddev, f),
Expression::StddevSamp(f) => recurse_agg!(StddevSamp, f),
Expression::Variance(f) => recurse_agg!(Variance, f),

// ===== BinaryFunc-based expressions =====
Expression::ModFunc(mut f) => {
Expand Down Expand Up @@ -1528,19 +1532,41 @@ where
Expression::Filter(f)
}

// BitwiseOrAgg/BitwiseAndAgg/BitwiseXorAgg: recurse into the aggregate argument
Expression::BitwiseOrAgg(mut f) => {
f.this = transform_recursive(f.this, transform_fn)?;
Expression::BitwiseOrAgg(f)
}
Expression::BitwiseAndAgg(mut f) => {
f.this = transform_recursive(f.this, transform_fn)?;
Expression::BitwiseAndAgg(f)
}
Expression::BitwiseXorAgg(mut f) => {
f.this = transform_recursive(f.this, transform_fn)?;
Expression::BitwiseXorAgg(f)
// Aggregate functions (AggFunc-based): recurse into the aggregate argument.
// Note: Stddev, StddevSamp, Variance, and ArrayAgg are already handled above.
Expression::Sum(f) => recurse_agg!(Sum, f),
Expression::Avg(f) => recurse_agg!(Avg, f),
Expression::Min(f) => recurse_agg!(Min, f),
Expression::Max(f) => recurse_agg!(Max, f),
Expression::CountIf(f) => recurse_agg!(CountIf, f),
Expression::StddevPop(f) => recurse_agg!(StddevPop, f),
Expression::VarPop(f) => recurse_agg!(VarPop, f),
Expression::VarSamp(f) => recurse_agg!(VarSamp, f),
Expression::Median(f) => recurse_agg!(Median, f),
Expression::Mode(f) => recurse_agg!(Mode, f),
Expression::First(f) => recurse_agg!(First, f),
Expression::Last(f) => recurse_agg!(Last, f),
Expression::AnyValue(f) => recurse_agg!(AnyValue, f),
Expression::ApproxDistinct(f) => recurse_agg!(ApproxDistinct, f),
Expression::ApproxCountDistinct(f) => recurse_agg!(ApproxCountDistinct, f),
Expression::LogicalAnd(f) => recurse_agg!(LogicalAnd, f),
Expression::LogicalOr(f) => recurse_agg!(LogicalOr, f),
Expression::Skewness(f) => recurse_agg!(Skewness, f),
Expression::ArrayConcatAgg(f) => recurse_agg!(ArrayConcatAgg, f),
Expression::ArrayUniqueAgg(f) => recurse_agg!(ArrayUniqueAgg, f),
Expression::BoolXorAgg(f) => recurse_agg!(BoolXorAgg, f),
Expression::BitwiseOrAgg(f) => recurse_agg!(BitwiseOrAgg, f),
Expression::BitwiseAndAgg(f) => recurse_agg!(BitwiseAndAgg, f),
Expression::BitwiseXorAgg(f) => recurse_agg!(BitwiseXorAgg, f),

// Count has its own struct with an Option<Expression> `this` field
Expression::Count(mut c) => {
if let Some(this) = c.this.take() {
c.this = Some(transform_recursive(this, transform_fn)?);
}
Expression::Count(c)
}

Expression::PipeOperator(mut pipe) => {
pipe.this = transform_recursive(pipe.this, transform_fn)?;
pipe.expression = transform_recursive(pipe.expression, transform_fn)?;
Expand Down
44 changes: 44 additions & 0 deletions crates/polyglot-sql/src/optimizer/qualify_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3144,6 +3144,50 @@ mod tests {
sql.contains("t1.a"),
"column should be qualified with table name: {sql}"
);

// test that columns in agg functions also get qualified
let expr = parse("SELECT MAX(a) FROM raw.t1");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("t1.a"),
"column in function should be qualified with table name: {sql}"
);

// test that columns in scalar functions also get qualified
let expr = parse("SELECT ABS(a) FROM raw.t1");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("t1.a"),
"column in function should be qualified with table name: {sql}"
);
}

#[test]
fn test_qualify_columns_count_star() {
// COUNT(*) uses Count { this: None } — verify qualify_columns handles it without panic
let expr = parse("SELECT COUNT(*) FROM t1");

let mut schema = MappingSchema::new();
schema
.add_table(
"t1",
&[("id".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");

let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);

assert!(
sql.contains("COUNT(*)"),
"COUNT(*) should be preserved: {sql}"
);
}

#[test]
Expand Down