|
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.hive

import org.apache.gluten.execution.{FileSourceScanExecTransformer, FilterExecTransformerBase}

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
import org.apache.spark.sql.hive.execution.HiveTableScanExec
import org.apache.spark.sql.internal.SQLConf
/**
 * Gluten-specific variant of Spark's dynamic partition pruning (DPP) suite for Hive scans.
 *
 * The inherited tests inspect the physical plan for DPP artifacts. Because Gluten replaces
 * several vanilla Spark operators with columnar "transformer" equivalents, the plan-inspection
 * helpers are overridden here so they recognize both the vanilla operators
 * (e.g. [[FileSourceScanExec]], [[SubqueryBroadcastExec]]) and their Gluten counterparts
 * (e.g. [[FileSourceScanExecTransformer]], [[ColumnarSubqueryBroadcastExec]]).
 */
abstract class GlutenDynamicPartitionPruningHiveScanSuiteBase
  extends DynamicPartitionPruningHiveScanSuiteBase
  with GlutenSQLTestsTrait {

  import testImplicits._

  override def beforeAll(): Unit = {
    prepareWorkDir()
    super.beforeAll()
    // Keep suite output readable: the inherited tests are chatty at INFO level.
    spark.sparkContext.setLogLevel("WARN")
  }

  /**
   * Collects the child expression of every [[DynamicPruningExpression]] found in the
   * partition filters / pruning predicates of both vanilla scans and Gluten transformers.
   */
  override protected def collectDynamicPruningExpressions(plan: SparkPlan): Seq[Expression] = {
    flatMap(plan) {
      case s: FileSourceScanExecTransformer =>
        s.partitionFilters.collect { case d: DynamicPruningExpression => d.child }
      case s: FileSourceScanExec =>
        s.partitionFilters.collect { case d: DynamicPruningExpression => d.child }
      case h: HiveTableScanExec =>
        h.partitionPruningPred.collect { case d: DynamicPruningExpression => d.child }
      case h: HiveTableScanExecTransformer =>
        h.partitionPruningPred.collect { case d: DynamicPruningExpression => d.child }
      case _ => Nil
    }
  }

  /**
   * Verifies how DPP was planned for `df`:
   *  - `withSubquery`: whether a plain (non-broadcast-reusing) DPP subquery is expected;
   *  - `withBroadcast`: whether DPP is expected to reuse a broadcast exchange.
   * Also checks that any broadcast used by DPP is actually reused from the main query, and
   * that subqueries are (non-)adaptive consistently with the main query plan.
   */
  override def checkPartitionPruningPredicate(
      df: DataFrame,
      withSubquery: Boolean,
      withBroadcast: Boolean): Unit = {
    // Materialize the query first so AQE (if enabled) finalizes the executed plan.
    df.collect()

    val plan = df.queryExecution.executedPlan
    val dpExprs = collectDynamicPruningExpressions(plan)
    val hasSubquery = dpExprs.exists {
      case InSubqueryExec(_, _: SubqueryExec, _, _, _, _) => true
      case _ => false
    }
    // Gluten may plan the DPP broadcast as ColumnarSubqueryBroadcastExec instead of
    // Spark's SubqueryBroadcastExec; accept either.
    val subqueryBroadcast = dpExprs.collect {
      case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) => b
      case InSubqueryExec(_, b: ColumnarSubqueryBroadcastExec, _, _, _, _) => b
    }

    val hasFilter = if (withSubquery) "Should" else "Shouldn't"
    assert(
      hasSubquery == withSubquery,
      s"$hasFilter trigger DPP with a subquery duplicate:\n${df.queryExecution}")
    val hasBroadcast = if (withBroadcast) "Should" else "Shouldn't"
    assert(
      subqueryBroadcast.nonEmpty == withBroadcast,
      s"$hasBroadcast trigger DPP with a reused broadcast exchange:\n${df.queryExecution}")

    // Every DPP broadcast must reference an exchange that is reused from the main query,
    // either directly or wrapped in an (adaptive) query stage.
    subqueryBroadcast.foreach {
      s =>
        s.child match {
          case _: ReusedExchangeExec => // reuse check ok.
          case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => // reuse check ok.
          case b: BroadcastExchangeLike =>
            val hasReuse = plan.find {
              case ReusedExchangeExec(_, e) => e eq b
              case _ => false
            }.isDefined
            assert(hasReuse, s"$s\nshould have been reused in\n$plan")
          case a: AdaptiveSparkPlanExec =>
            val broadcastQueryStage = collectFirst(a) { case b: BroadcastQueryStageExec => b }
            val broadcastPlan = broadcastQueryStage.get.broadcast
            val hasReuse = find(plan) {
              case ReusedExchangeExec(_, e) => e eq broadcastPlan
              case b: BroadcastExchangeLike => b eq broadcastPlan
              case _ => false
            }.isDefined
            assert(hasReuse, s"$s\nshould have been reused in\n$plan")
          case _ =>
            fail(s"Invalid child node found in\n$s")
        }
    }

    // Subqueries (other than the DPP broadcasts checked above) must be adaptive exactly
    // when the main query is adaptive.
    val isMainQueryAdaptive = plan.isInstanceOf[AdaptiveSparkPlanExec]
    subqueriesAll(plan).filterNot(subqueryBroadcast.contains).foreach {
      s =>
        val subquery = s match {
          case r: ReusedSubqueryExec => r.child
          case o => o
        }
        assert(
          subquery.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined == isMainQueryAdaptive)
    }
  }

  /**
   * Asserts that the executed plan of `df` contains exactly `n` distinct DPP broadcast
   * subqueries, distinguished by the build-side key indices they broadcast.
   */
  override def checkDistinctSubqueries(df: DataFrame, n: Int): Unit = {
    df.collect()

    val buf = collectDynamicPruningExpressions(df.queryExecution.executedPlan).collect {
      case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) =>
        b.indices
      case InSubqueryExec(_, b: ColumnarSubqueryBroadcastExec, _, _, _, _) =>
        b.indices
    }
    assert(buf.distinct.size == n)
  }

  /**
   * Returns true if any post-scan filter (vanilla [[FilterExec]] or a Gluten
   * [[FilterExecTransformerBase]]) still carries a [[DynamicPruningExpression]],
   * i.e. the DPP filter was not pushed down into the scan.
   */
  override def checkUnpushedFilters(df: DataFrame): Boolean = {
    find(df.queryExecution.executedPlan) {
      case FilterExec(condition, _) =>
        splitConjunctivePredicates(condition).exists {
          case _: DynamicPruningExpression => true
          case _ => false
        }
      case transformer: FilterExecTransformerBase =>
        splitConjunctivePredicates(transformer.cond).exists {
          case _: DynamicPruningExpression => true
          case _ => false
        }
      case _ => false
    }.isDefined
  }

  testGluten("Make sure dynamic pruning works on uncorrelated queries") {
    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
      val df = sql("""
                     |SELECT d.store_id,
                     |       SUM(f.units_sold),
                     |       (SELECT SUM(f.units_sold)
                     |        FROM fact_stats f JOIN dim_stats d ON d.store_id = f.store_id
                     |        WHERE d.country = 'US') AS total_prod
                     |FROM fact_stats f JOIN dim_stats d ON d.store_id = f.store_id
                     |WHERE d.country = 'US'
                     |GROUP BY 1
        """.stripMargin)
      checkAnswer(df, Row(4, 50, 70) :: Row(5, 10, 70) :: Row(6, 10, 70) :: Nil)

      val plan = df.queryExecution.executedPlan
      // The uncorrelated scalar subquery repeats the same join, so the DPP broadcast
      // should be planned once and reused once (count Gluten's columnar variant too).
      val countSubqueryBroadcasts =
        collectWithSubqueries(plan) {
          case _: SubqueryBroadcastExec => 1
          case _: ColumnarSubqueryBroadcastExec => 1
        }.sum

      val countReusedSubqueryBroadcasts =
        collectWithSubqueries(plan) {
          case ReusedSubqueryExec(_: SubqueryBroadcastExec) => 1
          case ReusedSubqueryExec(_: ColumnarSubqueryBroadcastExec) => 1
        }.sum

      assert(countSubqueryBroadcasts == 1)
      assert(countReusedSubqueryBroadcasts == 1)
    }
  }

  testGluten("SPARK-38674: Remove useless deduplicate in SubqueryBroadcastExec") {
    withTable("duplicate_keys") {
      withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
        Seq[(Int, String)]((1, "NL"), (1, "NL"), (3, "US"), (3, "US"), (3, "US"))
          .toDF("store_id", "country")
          .write
          .format(tableFormat)
          .saveAsTable("duplicate_keys")

        val df = sql("""
                       |SELECT date_id, product_id FROM fact_sk f
                       |JOIN duplicate_keys s
                       |ON f.store_id = s.store_id WHERE s.country = 'US' AND date_id > 1050
          """.stripMargin)

        checkPartitionPruningPredicate(df, withSubquery = false, withBroadcast = true)

        // The broadcast side holds five rows but only one distinct key; the DPP broadcast
        // must deduplicate, so exactly one output row is expected from its metrics.
        val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) {
          case s: ColumnarSubqueryBroadcastExec => s
        }
        assert(subqueryBroadcastExecs.size === 1)
        subqueryBroadcastExecs.foreach {
          subqueryBroadcastExec =>
            assert(subqueryBroadcastExec.metrics("numOutputRows").value === 1)
        }

        checkAnswer(df, Row(1060, 2) :: Row(1060, 2) :: Row(1060, 2) :: Nil)
      }
    }
  }
}
| 209 | + |
/** Runs the Gluten DPP Hive-scan suite with adaptive query execution disabled. */
class GlutenDynamicPartitionPruningHiveScanSuiteAEOff
  extends GlutenDynamicPartitionPruningHiveScanSuiteBase
  with DisableAdaptiveExecutionSuite
| 213 | + |
/** Runs the Gluten DPP Hive-scan suite with adaptive query execution enabled. */
class GlutenDynamicPartitionPruningHiveScanSuiteAEOn
  extends GlutenDynamicPartitionPruningHiveScanSuiteBase
  with EnableAdaptiveExecutionSuite