Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildS
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SparkVersionUtil
Expand Down Expand Up @@ -403,6 +404,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
original: GetMapValue): ExpressionTransformer =
GetMapValueTransformer(substraitExprName, left, right, failOnError = false, original)

/** Transform map_from_entries to Substrait. */
override def genMapFromEntriesTransformer(
substraitExprName: String,
child: ExpressionTransformer,
expr: Expression): ExpressionTransformer = {
val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)
val chExprName =
if (mapKeyDedupPolicy.toString == SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
ExpressionNames.MAP_FROM_ENTRIES_LAST_WIN
} else {
substraitExprName
}
GenericExpressionTransformer(chExprName, Seq(child), expr)
}

/**
* Generate ShuffleDependency for ColumnarShuffleExchangeExec.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.clickhouse.CHConfig
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.expression.{FlattenedAnd, FlattenedOr}

import org.apache.spark.SparkConf
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -1072,6 +1072,69 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
}
}

test("Test map_from_entries") {
withSQLConf(
SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
(ConstantFolding.ruleName + "," + NullPropagation.ruleName)) {
val query =
"""
|select id, map_from_entries(entries) from (
| select id,
| case
| when id = 0 then array(
| named_struct('key', cast(1 as int), 'value', 'a'),
| named_struct('key', cast(2 as int), 'value', cast(null as string)))
| when id = 1 then cast(array() as array<struct<key:int,value:string>>)
| when id = 2 then cast(null as array<struct<key:int,value:string>>)
| else array(
| cast(null as struct<key:int,value:string>),
| named_struct('key', cast(4 as int), 'value', 'd'))
| end as entries
| from range(4)
|) order by id
|""".stripMargin
runQueryAndCompare(query)(checkGlutenPlan[ProjectExecTransformer])
runQueryAndCompare(
"select map_from_entries(cast(array() as array<struct<key:int,value:string>>)) " +
"from range(1)")(checkGlutenPlan[ProjectExecTransformer])

intercept[SparkException] {
sql(
"""
|select map_from_entries(array(
| named_struct('key', cast(null as int), 'value', 'a')))
|from range(1)
|""".stripMargin).collect()
}

intercept[SparkException] {
sql(
"""
|select map_from_entries(array(
| named_struct('key', cast(1 as int), 'value', 'a'),
| named_struct('key', cast(1 as int), 'value', 'b')))
|from range(1)
|""".stripMargin).collect()
}
}
}

test("Test map_from_entries with LAST_WIN map key policy") {
withSQLConf(
SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
(ConstantFolding.ruleName + "," + NullPropagation.ruleName),
SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString
) {
runQueryAndCompare(
"""
|select map_from_entries(array(
| named_struct('key', cast(1 as int), 'value', 'a'),
| named_struct('key', cast(1 as int), 'value', 'b')))
|from range(1)
|""".stripMargin)(checkGlutenPlan[ProjectExecTransformer])
}
}

test("Test transform_keys/transform_values") {
val sql =
"""
Expand Down
Loading
Loading