diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala index 606fa377b8a7..9b4e06172a4a 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala @@ -79,22 +79,27 @@ object GlutenWriterColumnarRules { // So FakeRowAdaptor will always consumes columnar data, // thus avoiding the case of c2r->aqe->r2c->writer case aqe: AdaptiveSparkPlanExec => - command.withNewChildren( - Array( - BackendsApiManager.getSparkPlanExecApiInstance.genColumnarToCarrierRow( - AdaptiveSparkPlanExec( - aqe.inputPlan, - aqe.context, - aqe.preprocessingRules, - aqe.isSubquery, - supportsColumnar = true - )))) + val newChild = BackendsApiManager.getSparkPlanExecApiInstance + .genColumnarToCarrierRow(aqe.inputPlan) + command.withNewChildren(Array(wrapColumnarToRowWithAqe(newChild, aqe))) case other => command.withNewChildren( Array(BackendsApiManager.getSparkPlanExecApiInstance.genColumnarToCarrierRow(other))) } } + private def wrapColumnarToRowWithAqe( + newChild: SparkPlan, + aqe: AdaptiveSparkPlanExec): AdaptiveSparkPlanExec = { + aqe.inputPlan.logicalLink.foreach(newChild.setLogicalLink) + AdaptiveSparkPlanExec( + newChild, + aqe.context, + aqe.preprocessingRules, + aqe.isSubquery, + supportsColumnar = false) + } + case class NativeWritePostRule(session: SparkSession) extends Rule[SparkPlan] { override def apply(p: SparkPlan): SparkPlan = p match { diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala index 618dc07431f7..f4450b0c6404 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala @@ -1255,12 +1255,12 @@ class VeloxAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with GlutenSQLT sparkContext.listenerBus.waitUntilEmpty() assert(plan.isInstanceOf[V2TableWriteExec]) val childPlan = plan.asInstanceOf[V2TableWriteExec].child - assert(childPlan.isInstanceOf[ColumnarToCarrierRowExecBase]) + assert(childPlan.isInstanceOf[AdaptiveSparkPlanExec]) assert( childPlan - .asInstanceOf[ColumnarToCarrierRowExecBase] - .child - .isInstanceOf[AdaptiveSparkPlanExec]) + .asInstanceOf[AdaptiveSparkPlanExec] + .inputPlan + .isInstanceOf[ColumnarToCarrierRowExecBase]) spark.listenerManager.unregister(listener) } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala index 0c5eae6e586c..f768209f5480 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala @@ -1259,12 +1259,12 @@ class VeloxAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with GlutenSQLT sparkContext.listenerBus.waitUntilEmpty() assert(plan.isInstanceOf[V2TableWriteExec]) val childPlan = plan.asInstanceOf[V2TableWriteExec].child - assert(childPlan.isInstanceOf[ColumnarToCarrierRowExecBase]) + assert(childPlan.isInstanceOf[AdaptiveSparkPlanExec]) assert( childPlan - .asInstanceOf[ColumnarToCarrierRowExecBase] - .child - .isInstanceOf[AdaptiveSparkPlanExec]) + .asInstanceOf[AdaptiveSparkPlanExec] + .inputPlan + .isInstanceOf[ColumnarToCarrierRowExecBase]) spark.listenerManager.unregister(listener) } diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala index 6acdfed41940..f25584380aaf 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala @@ -1208,12 +1208,12 @@ class VeloxAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with GlutenSQLT sparkContext.listenerBus.waitUntilEmpty() assert(plan.isInstanceOf[V2TableWriteExec]) val childPlan = plan.asInstanceOf[V2TableWriteExec].child - assert(childPlan.isInstanceOf[ColumnarToCarrierRowExecBase]) + assert(childPlan.isInstanceOf[AdaptiveSparkPlanExec]) assert( childPlan - .asInstanceOf[ColumnarToCarrierRowExecBase] - .child - .isInstanceOf[AdaptiveSparkPlanExec]) + .asInstanceOf[AdaptiveSparkPlanExec] + .inputPlan + .isInstanceOf[ColumnarToCarrierRowExecBase]) spark.listenerManager.unregister(listener) } diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala index 6acdfed41940..f25584380aaf 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala @@ -1208,12 +1208,12 @@ class VeloxAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with GlutenSQLT sparkContext.listenerBus.waitUntilEmpty() assert(plan.isInstanceOf[V2TableWriteExec]) val childPlan = plan.asInstanceOf[V2TableWriteExec].child - assert(childPlan.isInstanceOf[ColumnarToCarrierRowExecBase]) + assert(childPlan.isInstanceOf[AdaptiveSparkPlanExec]) assert( childPlan - .asInstanceOf[ColumnarToCarrierRowExecBase] - .child - .isInstanceOf[AdaptiveSparkPlanExec]) + .asInstanceOf[AdaptiveSparkPlanExec] + .inputPlan + .isInstanceOf[ColumnarToCarrierRowExecBase]) spark.listenerManager.unregister(listener) } diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala index 67e9baed0401..8ac43cdc94cc 100644 --- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala +++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala @@ -1214,12 +1214,12 @@ class VeloxAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with GlutenSQLT sparkContext.listenerBus.waitUntilEmpty() assert(plan.isInstanceOf[V2TableWriteExec]) val childPlan = plan.asInstanceOf[V2TableWriteExec].child - assert(childPlan.isInstanceOf[ColumnarToCarrierRowExecBase]) + assert(childPlan.isInstanceOf[AdaptiveSparkPlanExec]) assert( childPlan - .asInstanceOf[ColumnarToCarrierRowExecBase] - .child - .isInstanceOf[AdaptiveSparkPlanExec]) + .asInstanceOf[AdaptiveSparkPlanExec] + .inputPlan + .isInstanceOf[ColumnarToCarrierRowExecBase]) spark.listenerManager.unregister(listener) } diff --git a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala index 6ebefbe15a43..dd57813fca07 100644 --- a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala +++ b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/adaptive/velox/VeloxAdaptiveQueryExecSuite.scala @@ -1218,12 +1218,12 @@ class VeloxAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with GlutenSQLT sparkContext.listenerBus.waitUntilEmpty() assert(plan.isInstanceOf[V2TableWriteExec]) val childPlan = plan.asInstanceOf[V2TableWriteExec].child - assert(childPlan.isInstanceOf[ColumnarToCarrierRowExecBase]) + assert(childPlan.isInstanceOf[AdaptiveSparkPlanExec]) assert( childPlan - .asInstanceOf[ColumnarToCarrierRowExecBase] - .child - .isInstanceOf[AdaptiveSparkPlanExec]) + .asInstanceOf[AdaptiveSparkPlanExec] + .inputPlan + .isInstanceOf[ColumnarToCarrierRowExecBase]) spark.listenerManager.unregister(listener) }