From 3a75aca17daea30933946401a664bc82c95465ca Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Sun, 14 Jun 2026 13:49:31 -0700 Subject: [PATCH] [SPARK-54593][SQL] Fix DPP eligibility for materialized filtering sides Co-authored-by: Tri Tam Hoang Co-authored-by: Dustin Smith --- .../spark/sql/execution/ExistingRDD.scala | 2 +- .../dynamicpruning/PartitionPruning.scala | 36 ++++- .../org/apache/spark/sql/DatasetSuite.scala | 2 + .../sql/DynamicPartitionPruningSuite.scala | 146 ++++++++++++++++++ 4 files changed, 177 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index f2e87568632a5..a449f4f171440 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -157,7 +157,7 @@ case class LogicalRDD( } } - private[sql] def isCheckpointedInput: Boolean = fromCheckpoint + private[sql] def isCheckpointedInput: Boolean = fromCheckpoint && rdd.isCheckpointed override lazy val constraints: ExpressionSet = originConstraints.getOrElse(ExpressionSet()) // Subqueries can have non-deterministic results even when they only contain deterministic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index ca7c8442d5f90..93e388c45af00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -199,9 +199,14 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join } + private def hasSelectivePredicate(plan: LogicalPlan): Boolean = plan.exists { + case f: Filter => isLikelySelective(f.condition) + case _ => false + } + /** - * Search for a selective filtering operation, a LocalRelation, or a checkpoint-derived - * LogicalRDD. + * Returns whether a plan can be evaluated repeatedly from materialized inputs and produce the + * same rows. * * LocalRelation rows are already locally available. A checkpoint-derived LogicalRDD establishes * an explicit checkpoint boundary and can be used as a broadcast build side for DPP without @@ -210,12 +215,28 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join * InMemoryRelation is intentionally excluded because cache() and persist() are lazy: its * presence does not guarantee the cached data has been materialized, and missing or evicted * blocks may require evaluating the upstream computation again. + * + * The supported operators are intentionally narrow. DPP is optional, and logical-plan + * determinism does not cover user functions stored outside Catalyst expressions. */ - private def hasSelectivePredicateOrLocalOrCheckpointedInput(plan: LogicalPlan): Boolean = { - plan.exists { - case f: Filter => isLikelySelective(f.condition) + private def isRepeatableMaterializedPlan(plan: LogicalPlan): Boolean = { + def isRepeatableExpression(expression: Expression): Boolean = { + expression.deterministic && !SubqueryExpression.hasSubquery(expression) && + !expression.exists { + case _: NonSQLExpression | _: UserDefinedExpression | _: UserDefinedGenerator => true + case _ => false + } + } + + plan match { case _: LocalRelation => true case r: LogicalRDD => r.isCheckpointedInput + case Project(projectList, child) if projectList.forall(isRepeatableExpression) => + isRepeatableMaterializedPlan(child) + case Filter(condition, child) if isRepeatableExpression(condition) => + isRepeatableMaterializedPlan(child) + case u: Union => u.children.forall(isRepeatableMaterializedPlan) + case SubqueryAlias(_, child) => isRepeatableMaterializedPlan(child) case _ => false } } @@ -224,11 +245,10 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join * To be able to prune partitions on a join key, the filtering side needs to * meet the following requirements: * (1) it can not be a stream - * (2) it needs to contain a selective predicate, a LocalRelation, or a checkpoint-derived - * LogicalRDD + * (2) it needs to contain a selective predicate or have a repeatable materialized input */ private def hasPartitionPruningFilter(plan: LogicalPlan): Boolean = { - !plan.isStreaming && hasSelectivePredicateOrLocalOrCheckpointedInput(plan) + !plan.isStreaming && (hasSelectivePredicate(plan) || isRepeatableMaterializedPlan(plan)) } private def prune(plan: LogicalPlan): LogicalPlan = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dc930af874908..3b28cae31a134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1908,6 +1908,7 @@ class DatasetSuite extends SharedSparkSession val treeString = cp.logicalPlan.treeString(verbose = true) fail(s"Expecting a LogicalRDD, but got\n$treeString") } + assert(logicalRDD.isCheckpointedInput === eager) val dsPhysicalPlan = ds.queryExecution.executedPlan val cpPhysicalPlan = cp.queryExecution.executedPlan @@ -1928,6 +1929,7 @@ class DatasetSuite extends SharedSparkSession // For a lazy checkpoint() call, the first check also materializes the checkpoint. checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) + assert(logicalRDD.isCheckpointedInput) // Reads back from checkpointed data and check again. checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index d303a03ba64b7..4db67ec77479c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.concurrent.TrieMap + import org.scalatest.GivenWhenThen import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} @@ -1797,6 +1801,140 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat } } + test("DPP requires every leaf of a materialized filtering side to be materialized") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + withTable("events") { + Seq((1, "hour1", "a"), (2, "hour1", "b"), (3, "hour2", "a")) + .toDF("id", "hour", "category") + .write + .partitionBy("hour", "category") + .format(tableFormat) + .mode("overwrite") + .saveAsTable("events") + + val checkpointedKeys = Seq("hour1||a").toDF("hc_key").localCheckpoint(eager = true) + val originalKeys = Seq("hour2||a").toDF("hc_key") + val nonCheckpointedKeys: DataFrame = LogicalRDD.fromDataset( + rdd = originalKeys.queryExecution.toRdd, + originDataset = originalKeys, + isStreaming = false) + val mixedKeys = checkpointedKeys.union(nonCheckpointedKeys) + + val events = spark.table("events").as("events") + def joinWith(keys: DataFrame): DataFrame = events + .join(broadcast(keys.as("sampled")), + concat_ws("||", $"events.hour", $"events.category") === $"sampled.hc_key") + .select($"events.id") + + val mixedJoin = joinWith(mixedKeys) + checkPartitionPruningPredicate(mixedJoin, withSubquery = false, withBroadcast = false) + checkAnswer(mixedJoin, Row(1) :: Row(3) :: Nil) + + val fullyMaterializedJoin = joinWith( + checkpointedKeys.union(nonCheckpointedKeys.localCheckpoint(eager = true))) + checkPartitionPruningPredicate( + fullyMaterializedJoin, withSubquery = false, withBroadcast = true) + checkAnswer(fullyMaterializedJoin, Row(1) :: Row(3) :: Nil) + } + } + } + + test("DPP materialized-input eligibility requires a repeatable plan") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "1000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false", + SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false", + SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { + withTable("events") { + val counterId = getClass.getName + spark.range(1, 11) + .select($"id".cast("int").as("p"), $"id".as("v")) + .write + .partitionBy("p") + .format(tableFormat) + .mode("overwrite") + .saveAsTable("events") + + def activeDppSubqueries(df: DataFrame): Seq[InSubqueryExec] = { + collectDynamicPruningExpressions(df.queryExecution.executedPlan) + .collect { case in: InSubqueryExec => in } + } + + def checkStandaloneDpp(keys: DataFrame): Unit = { + val df = spark.table("events").join(keys, Seq("p")).select("p") + DppMaterializedInputTestState.reset(counterId) + assert(df.collect().toSeq === Seq(Row(1))) + assert(activeDppSubqueries(df).exists { + case InSubqueryExec(_, _: SubqueryExec, _, _, _, _) => true + case _ => false + }, s"Should execute standalone DPP for a repeatable materialized plan:\n" + + df.queryExecution) + } + + def checkNoDpp(keys: DataFrame): Unit = { + val df = spark.table("events").join(keys, Seq("p")).select("p") + DppMaterializedInputTestState.reset(counterId) + assert(df.collect().toSeq === Seq(Row(1))) + assert(activeDppSubqueries(df).isEmpty, + s"Shouldn't trigger DPP for a non-repeatable materialized plan:\n" + + df.queryExecution) + } + + checkStandaloneDpp(Seq(1).toDF("p")) + checkStandaloneDpp(Seq(1).toDF("p").localCheckpoint(eager = true)) + + val checkpointed = Seq(1).toDS().localCheckpoint(eager = true) + val mappedKeys = checkpointed.mapPartitions { values => + val key = DppMaterializedInputTestState.next(counterId) + values.map(_ => key) + }.toDF("p") + checkNoDpp(mappedKeys) + + withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true") { + val broadcastJoin = + spark.table("events").join(broadcast(mappedKeys), Seq("p")).select("p") + DppMaterializedInputTestState.reset(counterId) + assert(broadcastJoin.collect().toSeq === Seq(Row(1))) + assert(activeDppSubqueries(broadcastJoin).isEmpty, + s"Shouldn't trigger DPP for a non-repeatable broadcast plan:\n" + + broadcastJoin.queryExecution) + + val target = spark.table("events").hint("merge") + .join(mappedKeys.hint("merge"), Seq("p")) + .select($"p", lit("target").as("branch")) + val decoy = Seq(-1).toDF("p") + .join(broadcast(mappedKeys), Seq("p")) + .select($"p", lit("decoy").as("branch")) + val withSiblingBroadcast = target.union(decoy) + + DppMaterializedInputTestState.reset(counterId) + val rows = withSiblingBroadcast.collect().toSeq + assert(rows.size === 1) + assert(rows.head.getString(1) === "target") + assert(activeDppSubqueries(withSiblingBroadcast).isEmpty, + s"A sibling broadcast shouldn't make a non-repeatable plan eligible for DPP:\n" + + withSiblingBroadcast.queryExecution) + } + + withTempView("changing_keys") { + spark.sparkContext.parallelize(Seq(1), 1).mapPartitions { values => + val key = DppMaterializedInputTestState.next(counterId) + values.map(_ => key) + }.toDF("p").createOrReplaceTempView("changing_keys") + + val scalarSubqueryKeys = sql( + """SELECT CAST((SELECT max(p) FROM changing_keys) AS INT) AS p + |FROM VALUES (1) AS outer(dummy)""".stripMargin) + checkNoDpp(scalarSubqueryKeys) + } + } + } + } + /** * Check the static scan metrics with and without DPP */ @@ -1955,3 +2093,11 @@ class DynamicPartitionPruningV2FilterSuiteAEOff class DynamicPartitionPruningV2FilterSuiteAEOn extends DynamicPartitionPruningV2FilterSuite with EnableAdaptiveExecutionSuite + +private object DppMaterializedInputTestState { + private val counters = TrieMap.empty[String, AtomicInteger] + + def reset(id: String): Unit = counters.getOrElseUpdate(id, new AtomicInteger()).set(0) + + def next(id: String): Int = counters.getOrElseUpdate(id, new AtomicInteger()).incrementAndGet() +}