Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ case class LogicalRDD(
}
}

private[sql] def isCheckpointedInput: Boolean = fromCheckpoint
private[sql] def isCheckpointedInput: Boolean = fromCheckpoint && rdd.isCheckpointed

override lazy val constraints: ExpressionSet = originConstraints.getOrElse(ExpressionSet())
// Subqueries can have non-deterministic results even when they only contain deterministic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,14 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
}


private def hasSelectivePredicate(plan: LogicalPlan): Boolean = plan.exists {
case f: Filter => isLikelySelective(f.condition)
case _ => false
}

/**
* Search for a selective filtering operation, a LocalRelation, or a checkpoint-derived
* LogicalRDD.
* Returns whether a plan can be evaluated repeatedly from materialized inputs and produce the
* same rows.
*
* LocalRelation rows are already locally available. A checkpoint-derived LogicalRDD establishes
* an explicit checkpoint boundary and can be used as a broadcast build side for DPP without
Expand All @@ -210,12 +215,28 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
* InMemoryRelation is intentionally excluded because cache() and persist() are lazy: its
* presence does not guarantee the cached data has been materialized, and missing or evicted
* blocks may require evaluating the upstream computation again.
*
* The supported operators are intentionally narrow. DPP is optional, and logical-plan
* determinism does not cover user functions stored outside Catalyst expressions.
*/
private def hasSelectivePredicateOrLocalOrCheckpointedInput(plan: LogicalPlan): Boolean = {
plan.exists {
case f: Filter => isLikelySelective(f.condition)
private def isRepeatableMaterializedPlan(plan: LogicalPlan): Boolean = {
def isRepeatableExpression(expression: Expression): Boolean = {
expression.deterministic && !SubqueryExpression.hasSubquery(expression) &&
!expression.exists {
case _: NonSQLExpression | _: UserDefinedExpression | _: UserDefinedGenerator => true
case _ => false
}
}

plan match {
case _: LocalRelation => true
case r: LogicalRDD => r.isCheckpointedInput
case Project(projectList, child) if projectList.forall(isRepeatableExpression) =>
isRepeatableMaterializedPlan(child)
case Filter(condition, child) if isRepeatableExpression(condition) =>
isRepeatableMaterializedPlan(child)
case u: Union => u.children.forall(isRepeatableMaterializedPlan)
case SubqueryAlias(_, child) => isRepeatableMaterializedPlan(child)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This SubqueryAlias case looks unreachable in the optimizer: EliminateSubqueryAliases runs in the catalyst FinishAnalysis batch (part of super.defaultBatches), which completes before the PartitionPruning batch in SparkOptimizer, so no SubqueryAlias node survives to reach this rule.

Non-blocking, and a question rather than a request: is it intended as defensive coding (in case the batch ordering changes), or can it be dropped to keep the allowlist as tight as the Scaladoc advertises? If kept, a one-word comment noting it's defensive would help future readers.

case _ => false
}
}
Expand All @@ -224,11 +245,10 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
* To be able to prune partitions on a join key, the filtering side needs to
* meet the following requirements:
* (1) it can not be a stream
* (2) it needs to contain a selective predicate, a LocalRelation, or a checkpoint-derived
* LogicalRDD
* (2) it needs to contain a selective predicate or have a repeatable materialized input
*/
private def hasPartitionPruningFilter(plan: LogicalPlan): Boolean = {
!plan.isStreaming && hasSelectivePredicateOrLocalOrCheckpointedInput(plan)
!plan.isStreaming && (hasSelectivePredicate(plan) || isRepeatableMaterializedPlan(plan))
}

private def prune(plan: LogicalPlan): LogicalPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1908,6 +1908,7 @@ class DatasetSuite extends SharedSparkSession
val treeString = cp.logicalPlan.treeString(verbose = true)
fail(s"Expecting a LogicalRDD, but got\n$treeString")
}
assert(logicalRDD.isCheckpointedInput === eager)

val dsPhysicalPlan = ds.queryExecution.executedPlan
val cpPhysicalPlan = cp.queryExecution.executedPlan
Expand All @@ -1928,6 +1929,7 @@ class DatasetSuite extends SharedSparkSession

// For a lazy checkpoint() call, the first check also materializes the checkpoint.
checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
assert(logicalRDD.isCheckpointedInput)

// Reads back from checkpointed data and check again.
checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import scala.collection.concurrent.TrieMap

import org.scalatest.GivenWhenThen

import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
Expand Down Expand Up @@ -1797,6 +1801,140 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
}
}

test("DPP requires every leaf of a materialized filtering side to be materialized") {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
withTable("events") {
Seq((1, "hour1", "a"), (2, "hour1", "b"), (3, "hour2", "a"))
.toDF("id", "hour", "category")
.write
.partitionBy("hour", "category")
.format(tableFormat)
.mode("overwrite")
.saveAsTable("events")

val checkpointedKeys = Seq("hour1||a").toDF("hc_key").localCheckpoint(eager = true)
val originalKeys = Seq("hour2||a").toDF("hc_key")
val nonCheckpointedKeys: DataFrame = LogicalRDD.fromDataset(
rdd = originalKeys.queryExecution.toRdd,
originDataset = originalKeys,
isStreaming = false)
val mixedKeys = checkpointedKeys.union(nonCheckpointedKeys)

val events = spark.table("events").as("events")
def joinWith(keys: DataFrame): DataFrame = events
.join(broadcast(keys.as("sampled")),
concat_ws("||", $"events.hour", $"events.category") === $"sampled.hc_key")
.select($"events.id")

val mixedJoin = joinWith(mixedKeys)
checkPartitionPruningPredicate(mixedJoin, withSubquery = false, withBroadcast = false)
checkAnswer(mixedJoin, Row(1) :: Row(3) :: Nil)

val fullyMaterializedJoin = joinWith(
checkpointedKeys.union(nonCheckpointedKeys.localCheckpoint(eager = true)))
checkPartitionPruningPredicate(
fullyMaterializedJoin, withSubquery = false, withBroadcast = true)
checkAnswer(fullyMaterializedJoin, Row(1) :: Row(3) :: Nil)
}
}
}

test("DPP materialized-input eligibility requires a repeatable plan") {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false",
SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "1000",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false",
SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false",
SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
withTable("events") {
val counterId = getClass.getName
spark.range(1, 11)
.select($"id".cast("int").as("p"), $"id".as("v"))
.write
.partitionBy("p")
.format(tableFormat)
.mode("overwrite")
.saveAsTable("events")

def activeDppSubqueries(df: DataFrame): Seq[InSubqueryExec] = {
collectDynamicPruningExpressions(df.queryExecution.executedPlan)
.collect { case in: InSubqueryExec => in }
}

def checkStandaloneDpp(keys: DataFrame): Unit = {
val df = spark.table("events").join(keys, Seq("p")).select("p")
DppMaterializedInputTestState.reset(counterId)
assert(df.collect().toSeq === Seq(Row(1)))
assert(activeDppSubqueries(df).exists {
case InSubqueryExec(_, _: SubqueryExec, _, _, _, _) => true
case _ => false
}, s"Should execute standalone DPP for a repeatable materialized plan:\n" +
df.queryExecution)
}

def checkNoDpp(keys: DataFrame): Unit = {
val df = spark.table("events").join(keys, Seq("p")).select("p")
DppMaterializedInputTestState.reset(counterId)
assert(df.collect().toSeq === Seq(Row(1)))
assert(activeDppSubqueries(df).isEmpty,
s"Shouldn't trigger DPP for a non-repeatable materialized plan:\n" +
df.queryExecution)
}

checkStandaloneDpp(Seq(1).toDF("p"))
checkStandaloneDpp(Seq(1).toDF("p").localCheckpoint(eager = true))

val checkpointed = Seq(1).toDS().localCheckpoint(eager = true)
val mappedKeys = checkpointed.mapPartitions { values =>
val key = DppMaterializedInputTestState.next(counterId)
values.map(_ => key)
}.toDF("p")
checkNoDpp(mappedKeys)

withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true") {
val broadcastJoin =
spark.table("events").join(broadcast(mappedKeys), Seq("p")).select("p")
DppMaterializedInputTestState.reset(counterId)
assert(broadcastJoin.collect().toSeq === Seq(Row(1)))
assert(activeDppSubqueries(broadcastJoin).isEmpty,
s"Shouldn't trigger DPP for a non-repeatable broadcast plan:\n" +
broadcastJoin.queryExecution)

val target = spark.table("events").hint("merge")
.join(mappedKeys.hint("merge"), Seq("p"))
.select($"p", lit("target").as("branch"))
val decoy = Seq(-1).toDF("p")
.join(broadcast(mappedKeys), Seq("p"))
.select($"p", lit("decoy").as("branch"))
val withSiblingBroadcast = target.union(decoy)

DppMaterializedInputTestState.reset(counterId)
val rows = withSiblingBroadcast.collect().toSeq
assert(rows.size === 1)
assert(rows.head.getString(1) === "target")
assert(activeDppSubqueries(withSiblingBroadcast).isEmpty,
s"A sibling broadcast shouldn't make a non-repeatable plan eligible for DPP:\n" +
withSiblingBroadcast.queryExecution)
}

withTempView("changing_keys") {
spark.sparkContext.parallelize(Seq(1), 1).mapPartitions { values =>
val key = DppMaterializedInputTestState.next(counterId)
values.map(_ => key)
}.toDF("p").createOrReplaceTempView("changing_keys")

val scalarSubqueryKeys = sql(
"""SELECT CAST((SELECT max(p) FROM changing_keys) AS INT) AS p
|FROM VALUES (1) AS outer(dummy)""".stripMargin)
checkNoDpp(scalarSubqueryKeys)
}
}
}
}

/**
* Check the static scan metrics with and without DPP
*/
Expand Down Expand Up @@ -1955,3 +2093,11 @@ class DynamicPartitionPruningV2FilterSuiteAEOff
class DynamicPartitionPruningV2FilterSuiteAEOn
extends DynamicPartitionPruningV2FilterSuite
with EnableAdaptiveExecutionSuite

private object DppMaterializedInputTestState {
private val counters = TrieMap.empty[String, AtomicInteger]

def reset(id: String): Unit = counters.getOrElseUpdate(id, new AtomicInteger()).set(0)

def next(id: String): Int = counters.getOrElseUpdate(id, new AtomicInteger()).incrementAndGet()
}