diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/CudfNodeValidationRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/CudfNodeValidationRule.scala index 14029cf28ff..e3d43baedf4 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/CudfNodeValidationRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/CudfNodeValidationRule.scala @@ -54,30 +54,58 @@ case class CudfNodeValidationRule(glutenConf: GlutenConfig) extends Rule[SparkPl object CudfNodeValidationRule { def setTagForWholeStageTransformer(transformer: WholeStageTransformer): Unit = { - if (!VeloxConfig.get.cudfEnableTableScan) { - // Spark3.2 does not have exists - val hasLeaf = transformer.find { - case _: LeafTransformSupport => true - case _ => false - }.isDefined - if (!hasLeaf && VeloxConfig.get.cudfEnableValidation) { - if ( - VeloxCudfPlanValidatorJniWrapper.validate( - transformer.substraitPlan.toProtobuf.toByteArray) - ) { - transformer.foreach { - case _: LeafTransformSupport => - case t: TransformSupport => - t.setTagValue(CudfTag.CudfTag, true) - case _ => - } - transformer.setTagValue(CudfTag.CudfTag, true) - } - } else { - transformer.setTagValue(CudfTag.CudfTag, !hasLeaf) + // Spark 3.2 does not have TreeNode.exists, so use find(...).isDefined. + val hasLeaf = transformer.find { + case _: LeafTransformSupport => true + case _ => false + }.isDefined + + val canOffload = decideOffload( + hasLeaf, + VeloxConfig.get.cudfEnableTableScan, + VeloxConfig.get.cudfEnableValidation, + () => + VeloxCudfPlanValidatorJniWrapper.validate( + transformer.substraitPlan.toProtobuf.toByteArray)) + + if (canOffload) { + transformer.foreach { + case _: LeafTransformSupport => + case t: TransformSupport => + t.setTagValue(CudfTag.CudfTag, true) + case _ => } - } else { transformer.setTagValue(CudfTag.CudfTag, true) + } else { + transformer.setTagValue(CudfTag.CudfTag, false) + } + } + + /** + * Decide whether a whole-stage transformer can be offloaded to the cuDF GPU backend. + * + * Pure (no native calls) so the branching can be unit-tested: + * - a stage that reads a table is offloadable only when GPU table scan is enabled; + * - otherwise, when validation is enabled, the native validator decides (it exempts + * TableScan, so a scan-bearing stage passes only when every other operator is + * cuDF-capable); + * - when validation is disabled, the stage is offloaded optimistically. + * + * `validate` is invoked only on the validation path, never when a table-reading stage is + * rejected up front, so the previous "tag GPU unconditionally when table scan is enabled" + * behavior no longer skips validation. + */ + private[extension] def decideOffload( + hasLeaf: Boolean, + enableTableScan: Boolean, + enableValidation: Boolean, + validate: () => Boolean): Boolean = { + if (hasLeaf && !enableTableScan) { + false + } else if (!enableValidation) { + true + } else { + validate() } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/extension/CudfNodeValidationRuleSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/extension/CudfNodeValidationRuleSuite.scala new file mode 100644 index 00000000000..a4514f06c8d --- /dev/null +++ b/backends-velox/src/test/scala/org/apache/gluten/extension/CudfNodeValidationRuleSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.scalatest.funsuite.AnyFunSuite + +class CudfNodeValidationRuleSuite extends AnyFunSuite { + + // Records whether the (lazy) native validator was invoked, so we can assert it is NOT + // called on paths that must short-circuit before touching the GPU. + private class CountingValidator(result: Boolean) extends (() => Boolean) { + var called: Boolean = false + override def apply(): Boolean = { + called = true + result + } + } + + test("table-reading stage is rejected when GPU table scan is disabled, without validating") { + val validator = new CountingValidator(true) + val decision = CudfNodeValidationRule.decideOffload( + hasLeaf = true, + enableTableScan = false, + enableValidation = true, + validator) + assert(!decision) + assert(!validator.called, "native validator must not be called for a rejected scan stage") + } + + test("table-reading stage is validated when GPU table scan is enabled") { + // This is the regression guard: previously this path tagged GPU unconditionally and + // skipped validation. Now an unsupported stage (validator -> false) must NOT be offloaded. + val rejecting = new CountingValidator(false) + assert( + !CudfNodeValidationRule.decideOffload( + hasLeaf = true, + enableTableScan = true, + enableValidation = true, + rejecting)) + assert(rejecting.called, "native validator must run for a scan stage when table scan is on") + + // A supported scan stage (validator -> true) is still offloaded. + assert( + CudfNodeValidationRule.decideOffload( + hasLeaf = true, + enableTableScan = true, + enableValidation = true, + new CountingValidator(true))) + } + + test("non-scan stage is offloaded only when validation passes") { + assert( + CudfNodeValidationRule.decideOffload( + hasLeaf = false, + enableTableScan = false, + enableValidation = true, + new CountingValidator(true))) + assert( + !CudfNodeValidationRule.decideOffload( + hasLeaf = false, + enableTableScan = false, + enableValidation = true, + new CountingValidator(false))) + } + + test("validation disabled offloads optimistically without validating") { + Seq(true, false).foreach { + hasLeaf => + val validator = new CountingValidator(false) + assert( + CudfNodeValidationRule.decideOffload( + hasLeaf = hasLeaf, + enableTableScan = true, + enableValidation = false, + validator)) + assert(!validator.called, "native validator must not be called when validation is disabled") + } + } +}