diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 146e0feb8e..8c41070e0b 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -53,6 +53,7 @@ use datafusion_spark::function::math::hex::SparkHex; use datafusion_spark::function::math::width_bucket::SparkWidthBucket; use datafusion_spark::function::string::char::CharFunc; use datafusion_spark::function::string::concat::SparkConcat; +use datafusion_spark::function::string::elt::SparkElt; use futures::poll; use futures::stream::StreamExt; use jni::objects::JByteBuffer; @@ -355,6 +356,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkElt::default())); } /// Prepares arrow arrays for output. diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 73b88ae935..ffbfb2e39c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -174,7 +174,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[StringTrimRight] -> CometScalarFunction("rtrim"), classOf[Left] -> CometLeft, classOf[Substring] -> CometSubstring, - classOf[Upper] -> CometUpper) + classOf[Upper] -> CometUpper, + classOf[Elt] -> CometElt) private val bitwiseExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[BitwiseAnd] -> CometBitwiseAnd, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index ea42b245aa..94927a39f3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,8 +21,8 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Left, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} -import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Elt, Expression, InitCap, Left, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper} +import org.apache.spark.sql.types._ import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -289,6 +289,16 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { } } +object CometElt extends CometScalarFunction[Elt]("elt") { + + override def getSupportLevel(expr: Elt): SupportLevel = { + if (expr.failOnError) { + return Unsupported(Some("ANSI mode not supported")) + } + Compatible(None) + } +} + trait CommonStringExprs { def stringDecode( diff --git a/spark/src/test/resources/sql-tests/expressions/string/elt.sql b/spark/src/test/resources/sql-tests/expressions/string/elt.sql new file mode 100644 index 0000000000..fe49bd25fd --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/elt.sql @@ -0,0 +1,27 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_elt(a string, b string, c string, idx int) USING parquet + +statement +INSERT INTO test_elt VALUES ('a', 'b', 'c', 1), ('a', 'b', '', 2), (NULL, 'b', 'c', NULL), ('a', NULL, 'c', -100), (NULL, NULL, NULL, 0) + +query +SELECT elt(idx, a, b, c) FROM test_elt diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 2a2932c643..b5f21ec4e0 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -23,8 +23,9 @@ import scala.util.Random import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.spark.sql.types.{DataTypes, StringType, StructField, StructType} import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} @@ -378,4 +379,43 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("elt") { + val wrongNumArgsWithoutSuggestionExceptionMsg = + "[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `elt` requires > 1 parameters but the actual number is 1." + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "false", + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + val r = new Random(42) + val fieldsCount = 10 + val indexes = Seq.range(1, fieldsCount) + val edgeCasesIndexes = Seq(-1, 0, -100, fieldsCount + 100) + val schema = indexes + .foldLeft(new StructType())((schema, idx) => + schema.add(s"c$idx", StringType, nullable = true)) + val df = FuzzDataGenerator.generateDataFrame( + r, + spark, + schema, + 100, + DataGenOptions(maxStringLength = 6)) + df.withColumn( + "idx", + lit(Random.shuffle(indexes ++ edgeCasesIndexes).headOption.getOrElse(-1))) + .createOrReplaceTempView("t1") + checkSparkAnswerAndOperator( + sql(s"SELECT elt(idx, ${schema.fieldNames.mkString(",")}) FROM t1")) + checkSparkAnswerAndOperator( + sql(s"SELECT elt(cast(null as int), ${schema.fieldNames.mkString(",")}) FROM t1")) + checkSparkAnswerMaybeThrows(sql("SELECT elt(1) FROM t1")) match { + case (Some(spark), Some(comet)) => + assert(spark.getMessage.contains(wrongNumArgsWithoutSuggestionExceptionMsg)) + assert(comet.getMessage.contains(wrongNumArgsWithoutSuggestionExceptionMsg)) + case (spark, comet) => + fail( + s"Expected Spark and Comet to throw exception, but got\nSpark: $spark\nComet: $comet") + } + checkSparkAnswerAndOperator("SELECT elt(2, 'a', 'b', 'c')") + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala index c7c750aed6..d219477d77 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala @@ -76,7 +76,8 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase { StringExprConfig("substring", "select substring(c1, 1, 100) from parquetV1Table"), StringExprConfig("translate", "select translate(c1, '123456', 'aBcDeF') from parquetV1Table"), StringExprConfig("trim", "select trim(c1) from parquetV1Table"), - StringExprConfig("upper", "select upper(c1) from parquetV1Table")) + StringExprConfig("upper", "select upper(c1) from parquetV1Table"), + StringExprConfig("elt", "select elt(2, c1, c1) from parquetV1Table")) override def runCometBenchmark(mainArgs: Array[String]): Unit = { runBenchmarkWithTable("String expressions", 1024) { v =>