Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ object VeloxRuleApi {
private def injectSpark(injector: SparkInjector): Unit = {
// Inject the regular Spark rules directly.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(BloomFilterMightContainJointRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply)
injector.injectOptimizerRule(RewriteCastFromArray.apply)
Expand Down Expand Up @@ -81,11 +82,6 @@ object VeloxRuleApi {
injector.injectPreTransform(c => FallbackMultiCodegens.apply(c.session))
injector.injectPreTransform(c => MergeTwoPhasesHashBaseAggregate(c.session))
injector.injectPreTransform(_ => RewriteSubqueryBroadcast())
injector.injectPreTransform(
c =>
BloomFilterMightContainJointRewriteRule.apply(
c.session,
c.caller.isBloomFilterStatFunction()))
injector.injectPreTransform(_ => EliminateRedundantGetTimestamp)

// Legacy: The legacy transform rule.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,63 +21,70 @@ import org.apache.gluten.expression.VeloxBloomFilterMightContain
import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BloomFilterMightContain, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.{Attribute, BloomFilterMightContain, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

case class BloomFilterMightContainJointRewriteRule(
spark: SparkSession,
isBloomFilterStatFunction: Boolean)
extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
if (isBloomFilterStatFunction || !GlutenConfig.get.enableNativeBloomFilter) {
/**
* Optimizer rule that rewrites `BloomFilterAggregate` -> `VeloxBloomFilterAggregate` and
* `BloomFilterMightContain` -> `VeloxBloomFilterMightContain` at the logical plan level.
*
* Running as an optimizer rule ensures the substitution is captured in the `originalPlan` snapshot
* that [[org.apache.gluten.extension.columnar.heuristic.ExpandFallbackPolicy]] uses when promoting
* an individual stage fallback to a whole-stage AQE fallback. This guarantees that both sides of
* the bloom-filter pair always produce and consume the same byte format, regardless of whether
* stages fall back to JVM execution after AQE re-planning.
*
* The aggregate (producer) and the might-contain (consumer) are always rewritten as a pair, or not
* at all, so they never end up on different serialized byte formats:
* - `might_contain(ScalarSubquery(...), col)` with a plain column value ([[Attribute]]): rewrite
* both to their Velox forms (version=1). This is the user-facing filter path that GLUTEN-12013
* protects across whole-stage AQE fallback.
* - `might_contain(ScalarSubquery(...), <non-column>)` (e.g. a literal, as in SPARK-54336): leave
* both vanilla (version=0). Rewriting only the outer side to Velox while the inner aggregate
* stayed vanilla `bloom_filter_agg` is exactly what caused the `kBloomFilterV1 == version` (1
* vs. 0) crash.
*
* Standalone `BloomFilterAggregate` (e.g., `DataFrame.stat.bloomFilter()`) is never matched, so its
* bytes stay in Spark-native format. DPP/runtime-filter `might_contain` expressions are injected by
* Spark's `InjectRuntimeFilter`, which runs after this rule's batch, so they are never seen here.
*/
case class BloomFilterMightContainJointRewriteRule(spark: SparkSession)
extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!GlutenConfig.get.enableNativeBloomFilter) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the early-exit if (!GlutenConfig.get.enableNativeBloomFilter) return plan is untested. Can we add a test case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test that sets spark.gluten.sql.native.bloomFilter=false and checks the query still returns the correct count and the optimized plan does not contain velox_might_contain.

return plan
}
val out = plan.transformWithSubqueries {
case p =>
applyForNode(p)
}
out
}

private def replaceBloomFilterAggregate[T](
expr: Expression,
bloomFilterAggReplacer: (
Expression,
Expression,
Expression,
Int,
Int) => TypedImperativeAggregate[T]): Expression = expr match {
case BloomFilterAggregate(
child,
estimatedNumItemsExpression,
numBitsExpression,
mutableAggBufferOffset,
inputAggBufferOffset) =>
bloomFilterAggReplacer(
child,
estimatedNumItemsExpression,
numBitsExpression,
mutableAggBufferOffset,
inputAggBufferOffset)
case other => other
}

private def replaceMightContain[T](
expr: Expression,
mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match {
case BloomFilterMightContain(bloomFilterExpression, valueExpression) =>
mightContainReplacer(bloomFilterExpression, valueExpression)
case other => other
}

private def applyForNode(p: SparkPlan) = {
p.transformExpressions {
case e =>
replaceMightContain(
replaceBloomFilterAggregate(e, VeloxBloomFilterAggregate.apply),
VeloxBloomFilterMightContain.apply)
plan.transformAllExpressions {
case BloomFilterMightContain(subq: ScalarSubquery, v: Attribute) =>
// User-facing bloom filter: value is a plain column reference.
// Rewrite both the outer might-contain and the inner aggregate to Velox format so that
// both sides always produce/consume the same byte layout (even when one stage falls back).
val rewrittenPlan = subq.plan.transformAllExpressions {
case ae @ AggregateExpression(b: BloomFilterAggregate, _, _, _, _) =>
ae.copy(aggregateFunction = VeloxBloomFilterAggregate(
b.child,
b.estimatedNumItemsExpression,
b.numBitsExpression,
b.mutableAggBufferOffset,
b.inputAggBufferOffset))
}
VeloxBloomFilterMightContain(subq.withNewPlan(rewrittenPlan), v)
case bfmc @ BloomFilterMightContain(_: ScalarSubquery, _) =>
// Inline subquery whose value is NOT a plain column -- e.g. a literal, as in SPARK-54336
// (`might_contain((SELECT bloom_filter_agg(col) FROM t), 0L)`). Leave BOTH the inner
// aggregate and the outer might-contain as vanilla Spark expressions so they stay on the
// same Spark-native (version=0) byte format, and so an empty aggregate input still yields
// vanilla's NULL bloom filter. Rewriting only the outer side to Velox (which expects
// version=1) while the inner aggregate stays vanilla `bloom_filter_agg` (no Substrait
// mapping -> JVM -> version=0) is what caused the `kBloomFilterV1 == version` (1 vs. 0)
// crash.
bfmc
case BloomFilterMightContain(bf, v) =>
// Pre-computed literal bloom filter bytes -- rewrite to consume Velox-format bytes.
VeloxBloomFilterMightContain(bf, v)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ trait CallerInfo {
def isAqe(): Boolean
def isCache(): Boolean
def isStreaming(): Boolean
def isBloomFilterStatFunction(): Boolean
}

object CallerInfo {
Expand All @@ -42,8 +41,7 @@ object CallerInfo {
private class Impl(
override val isAqe: Boolean,
override val isCache: Boolean,
override val isStreaming: Boolean,
override val isBloomFilterStatFunction: Boolean
override val isStreaming: Boolean
) extends CallerInfo

/*
Expand All @@ -57,8 +55,7 @@ object CallerInfo {
new Impl(
isAqe = inAqeCall(stack),
isCache = inCacheCall(stack),
isStreaming = inStreamingCall(stack),
isBloomFilterStatFunction = inBloomFilterStatFunctionCall(stack))
isStreaming = inStreamingCall(stack))
}

private def inAqeCall(stack: Seq[StackTraceElement]): Boolean = {
Expand All @@ -78,21 +75,13 @@ object CallerInfo {
stack.exists(_.getClassName.equals(streamName))
}

private def inBloomFilterStatFunctionCall(stack: Seq[StackTraceElement]): Boolean = {
val res = stack.exists(
_.getClassName.equals("org.apache.spark.sql.DataFrameStatFunctions")
&& stack.exists(_.getMethodName.equals("bloomFilter")))
res
}

/** For testing only. */
def withLocalValue[T](
isAqe: Boolean,
isCache: Boolean,
isStreaming: Boolean = false,
isBloomFilterStatFunction: Boolean = false)(body: => T): T = {
isStreaming: Boolean = false)(body: => T): T = {
val prevValue = localStorage.get()
val newValue = new Impl(isAqe, isCache, isStreaming, isBloomFilterStatFunction)
val newValue = new Impl(isAqe, isCache, isStreaming)
localStorage.set(Some(newValue))
try {
body
Expand Down
Loading
Loading