diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 87500f0ca5149..68a83a55c8daa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1561,7 +1561,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { * Check if the given expression is cheap that we can inline it. */ def isCheap(e: Expression): Boolean = e match { - case _: Attribute | _: OuterReference => true + // `BoundReference` is the codegen-bound form of an `Attribute`; a slot read, equally cheap. + case _: Attribute | _: OuterReference | _: BoundReference => true case _ if e.foldable => true // PythonUDF is handled by the rule ExtractPythonUDFs case _: PythonUDF => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 88c74ab7adc41..3eadfe7b865e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -316,12 +317,22 @@ case class FilterExec(condition: Expression, child: SparkPlan) // (e.g. 
decoding a decimal column for rows a cheaper earlier predicate would reject), so we // fall back to `generatePredicateCode`. // + // A *cheap* common subexpression does not count. `c BETWEEN lo AND hi` lowers to + // `c >= lo AND c <= hi`, so any `BETWEEN` (or a column referenced in several conjuncts) makes + // that column a common subexpression, but caching a cheap load saves nothing: the non-CSE path + // already loads each column lazily into a variable on demand. Taking the CSE path for it would + // only add the eager prologue that decodes every referenced column up front. Require a + // non-cheap common subexpression (per `CollapseProject.isCheap`) so filters like TPC-DS q28 + // (`ss_quantity BETWEEN ... AND (ss_list_price BETWEEN ... OR ...)`, whose only repeats are the + // bare columns) keep the lazy, short-circuiting path. + // // `subexpressionElimination.filterExec.enabled` additionally gates this path so it can be // turned off independently of subexpression elimination elsewhere. 
val (prologueCode, predicateCode) = if (conf.subexpressionEliminationEnabled && conf.subexpressionEliminationFilterExecEnabled && otherPreds.nonEmpty && - otherPredsEquivalentExpressions.getCommonSubexpressions.nonEmpty) { + otherPredsEquivalentExpressions.getCommonSubexpressions + .exists(!CollapseProject.isCheap(_))) { // Pre-evaluate input variables before CSE analysis: CSE clears // ctx.currentVars[i].code as a side effect; without this pre-evaluation, Janino // fails when otherPreds reference the same input columns that CSE already diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 886df9184aca4..e013e59597f69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -1225,6 +1225,48 @@ class WholeStageCodegenSuite extends SharedSparkSession "CSE-disabled codegen (i.e. fall back to the lazy, short-circuiting non-CSE path)") } + test("SPARK-56032: FilterExec skips CSE codegen when the only common subexpression is a leaf") { + // `c BETWEEN lo AND hi` lowers to `c >= lo AND c <= hi`, so a column used in a BETWEEN (or in + // several conjuncts) becomes a "common subexpression" -- but it is a bare leaf column whose + // load CSE cannot meaningfully cache, since the non-CSE path already loads columns lazily. The + // gate must not take the CSE path for it: doing so emits the eager prologue that decodes every + // referenced column (`p1`/`p2` here) up front, defeating the cheap `q` filter's + // short-circuiting. This is the TPC-DS q28 shape. Verify the leaf-only case falls back to the + // same code as CSE-disabled; `p1`/`p2` stand in for q28's decimal columns whose eager decode + // is the cost, though the fallback is type-independent.
+ val schema = StructType(Seq( + StructField("q", IntegerType, nullable = true), + StructField("p1", IntegerType, nullable = true), + StructField("p2", IntegerType, nullable = true))) + val data = spark.sparkContext.parallelize(Seq( + Row(4, 10, 7), Row(1, 10, 7), Row(null, 10, 7), + Row(5, 100, 7), Row(6, 100, 100), Row(3, 9, 1))) + val expected = Seq(Row(4, 10, 7), Row(5, 100, 7), Row(3, 9, 1)) + + def filterCode(cseEnabled: Boolean): String = { + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> cseEnabled.toString, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val df = spark.createDataFrame(data, schema) + // The only repeated expressions are the bare columns q, p1, p2 (each referenced by the two + // halves of its BETWEEN). No non-leaf expression is shared. + val filtered = df.where( + "q IS NOT NULL AND q BETWEEN 2 AND 6 AND (p1 BETWEEN 8 AND 18 OR p2 BETWEEN 5 AND 9)") + val plan = filtered.queryExecution.executedPlan + assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]), + "Filter should be in whole-stage codegen") + checkAnswer(filtered, expected) + codegenString(plan) + } + } + + def normalize(code: String): String = code.replaceAll("#\\d+", "#") + assert(normalize(filterCode(cseEnabled = true)) == normalize(filterCode(cseEnabled = false)), + "With only leaf common subexpressions, CSE-enabled FilterExec codegen should be identical " + + "to CSE-disabled codegen (i.e. fall back to the lazy, short-circuiting non-CSE path)") + } + test("SPARK-56032: subexpressionElimination.filterExec.enabled gates FilterExec CSE " + "independently of subexpression elimination") { // The conf disables CSE specifically for FilterExec while leaving subexpression elimination