diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 832932f0a49..36c9813fbb8 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -56,6 +56,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildS import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil} import org.apache.spark.sql.execution.window._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SparkVersionUtil @@ -403,6 +404,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { original: GetMapValue): ExpressionTransformer = GetMapValueTransformer(substraitExprName, left, right, failOnError = false, original) + /** Transform map_from_entries to Substrait; LAST_WIN dedup policy picks the last-win variant. */ + override def genMapFromEntriesTransformer( + substraitExprName: String, + child: ExpressionTransformer, + expr: Expression): ExpressionTransformer = { + val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY) + val chExprName = + if (mapKeyDedupPolicy.toString == SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) { + ExpressionNames.MAP_FROM_ENTRIES_LAST_WIN + } else { + substraitExprName + } + GenericExpressionTransformer(chExprName, Seq(child), expr) + } + /** * Generate ShuffleDependency for ColumnarShuffleExchangeExec. 
* diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 07d44cc309e..602d57d4a88 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.clickhouse.CHConfig import org.apache.gluten.config.GlutenConfig import org.apache.gluten.expression.{FlattenedAnd, FlattenedOr} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -1072,6 +1072,69 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS } } + test("Test map_from_entries") { + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + (ConstantFolding.ruleName + "," + NullPropagation.ruleName)) { + val query = + """ + |select id, map_from_entries(entries) from ( + | select id, + | case + | when id = 0 then array( + | named_struct('key', cast(1 as int), 'value', 'a'), + | named_struct('key', cast(2 as int), 'value', cast(null as string))) + | when id = 1 then cast(array() as array>) + | when id = 2 then cast(null as array>) + | else array( + | cast(null as struct), + | named_struct('key', cast(4 as int), 'value', 'd')) + | end as entries + | from range(4) + |) order by id + |""".stripMargin + runQueryAndCompare(query)(checkGlutenPlan[ProjectExecTransformer]) + runQueryAndCompare( + "select map_from_entries(cast(array() as array>)) " + + "from range(1)")(checkGlutenPlan[ProjectExecTransformer]) + /* A NULL map key is expected to fail the query with a SparkException. */ + intercept[SparkException] { + sql( + """ + |select map_from_entries(array( + | named_struct('key', cast(null as int), 
'value', 'a'))) + |from range(1) + |""".stripMargin).collect() + } + /* Duplicate keys must fail too: MAP_KEY_DEDUP_POLICY is not overridden in this test. */ + intercept[SparkException] { + sql( + """ + |select map_from_entries(array( + | named_struct('key', cast(1 as int), 'value', 'a'), + | named_struct('key', cast(1 as int), 'value', 'b'))) + |from range(1) + |""".stripMargin).collect() + } + } + } + /* Under LAST_WIN the duplicate-key query is run via runQueryAndCompare instead of intercept. */ + test("Test map_from_entries with LAST_WIN map key policy") { + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + (ConstantFolding.ruleName + "," + NullPropagation.ruleName), + SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString + ) { + runQueryAndCompare( + """ + |select map_from_entries(array( + | named_struct('key', cast(1 as int), 'value', 'a'), + | named_struct('key', cast(1 as int), 'value', 'b'))) + |from range(1) + |""".stripMargin)(checkGlutenPlan[ProjectExecTransformer]) + } + } + test("Test transform_keys/transform_values") { val sql = """ diff --git a/cpp-ch/local-engine/Functions/SparkFunctionMapFromEntries.cpp b/cpp-ch/local-engine/Functions/SparkFunctionMapFromEntries.cpp new file mode 100644 index 00000000000..eafc1fa3b3d --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionMapFromEntries.cpp @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int ILLEGAL_COLUMN; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; +} +/* Spark map_from_entries: turns an Array of (key, value) Tuples into a Map. A NULL array or an array with a NULL entry gives a NULL map; a NULL key throws BAD_ARGUMENTS; a duplicate key throws BAD_ARGUMENTS unless last_win, in which case the later value replaces the earlier one. */ +template +class SparkFunctionMapFromEntries : public IFunction +{ +public: + static constexpr auto name = last_win ? "sparkMapFromEntriesLastWin" : "sparkMapFromEntries"; + + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForLowCardinalityColumns() const override { return false; } + /* Returns Map(key, value) with Nullable / LowCardinality(Nullable) removed from the key type, wrapped in Nullable when the argument or its entry type is Nullable; a Nothing entry type is handled separately. Throws ILLEGAL_TYPE_OF_ARGUMENT unless the argument is an Array of 2-element Tuples. */ + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + const auto * array_type = checkAndGetDataType(removeNullable(arguments[0]).get()); + if (!array_type) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Argument for function {} must be Array, but it has type {}", + getName(), + arguments[0]->getName()); + + const auto & entry_type = array_type->getNestedType(); + const auto entry_type_without_nullable = removeNullable(entry_type); + if (isNothing(entry_type_without_nullable)) + { + auto map_type = std::make_shared( + std::make_shared(), + std::make_shared()); + if (arguments[0]->isNullable() || entry_type->isNullable()) + return makeNullable(map_type); + return map_type; + } + + const auto * tuple_type = 
checkAndGetDataType(entry_type_without_nullable.get()); + if (!tuple_type || tuple_type->getElements().size() != 2) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Argument for function {} must be Array of pair Tuple, but it has nested type {}", + getName(), + entry_type->getName()); + + const auto & elements = tuple_type->getElements(); + auto map_type = std::make_shared(removeNullableOrLowCardinalityNullable(elements[0]), elements[1]); + if (arguments[0]->isNullable() || entry_type->isNullable()) + return makeNullable(map_type); + return map_type; + } + /* Unwraps Const and Nullable around the array and Nullable around its entries, then builds the map row by row: NULL arrays and arrays with a NULL entry become NULL maps, all other rows receive the entries picked by selectEntriesForRow. */ + ColumnPtr executeImpl( + const ColumnsWithTypeAndName & arguments, + const DataTypePtr & result_type, + size_t input_rows_count) const override + { + ColumnPtr holder = arguments[0].column->convertToFullColumnIfConst(); + + const PaddedPODArray * input_null_map = nullptr; + if (const auto * nullable = checkAndGetColumn(holder.get())) + { + input_null_map = &nullable->getNullMapData(); + holder = nullable->getNestedColumnPtr(); + } + + const auto * entries_array = checkAndGetColumn(holder.get()); + if (!entries_array) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Argument column for function {} must be Array, but it is {}", + getName(), + holder->getName()); + + const auto & entries_offsets = entries_array->getOffsets(); + const IColumn * entries_data = &entries_array->getData(); + const PaddedPODArray * entry_null_map = nullptr; + if (const auto * nullable_entries = checkAndGetColumn(entries_data)) + { + entry_null_map = &nullable_entries->getNullMapData(); + entries_data = &nullable_entries->getNestedColumn(); + } + /* Entries of type Nothing hold no keys or values: every row gets zero entries and is NULL only if the array or one of its entries is NULL. */ + const auto & result_map_type = assert_cast(*removeNullable(result_type)); + if (isNothing(entries_data->getDataType())) + { + auto result_key_column = result_map_type.getKeyType()->createColumn(); + auto result_value_column = result_map_type.getValueType()->createColumn(); + auto result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count, 0); + + 
ColumnUInt8::MutablePtr result_null_map; + PaddedPODArray * result_null_map_data = nullptr; + if (result_type->isNullable()) + { + result_null_map = ColumnUInt8::create(input_rows_count, 0); + result_null_map_data = &result_null_map->getData(); + } + + size_t previous_entry_offset = 0; + for (size_t row = 0; row < input_rows_count; ++row) + { + const auto current_entry_offset = entries_offsets[row]; + const bool input_map_null = input_null_map && (*input_null_map)[row]; + bool entry_map_null = false; + if (!input_map_null && entry_null_map) + { + for (size_t entry = previous_entry_offset; entry < current_entry_offset; ++entry) + { + if ((*entry_null_map)[entry]) + { + entry_map_null = true; + break; + } + } + } + if (result_null_map_data) + (*result_null_map_data)[row] = input_map_null || entry_map_null; + previous_entry_offset = current_entry_offset; + } + + auto nested_column = ColumnArray::create( + ColumnTuple::create( + Columns{std::move(result_key_column), std::move(result_value_column)}), + std::move(result_offsets_column)); + auto result_column = ColumnMap::create(std::move(nested_column)); + if (result_type->isNullable()) + return ColumnNullable::create(std::move(result_column), std::move(result_null_map)); + return result_column; + } + + const auto * entries_tuple = checkAndGetColumn(entries_data); + if (!entries_tuple || entries_tuple->tupleSize() != 2) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Nested column for function {} must be Tuple with 2 elements, but it is {}", + getName(), + entries_data->getName()); + /* Keys are copied from a non-nullable view of the key column (nested column of Nullable, or cloneWithDefaultOnNull for a nullable LowCardinality); NULL checks and comparisons still use the original key_column. */ + const auto & key_column = entries_tuple->getColumn(0); + const auto & value_column = entries_tuple->getColumn(1); + ColumnPtr key_insert_holder; + const IColumn * key_insert_column = &key_column; + if (const auto * nullable_key_column = checkAndGetColumn(&key_column)) + key_insert_column = &nullable_key_column->getNestedColumn(); + else if (const auto * low_cardinality_key_column = checkAndGetColumn(&key_column); + 
low_cardinality_key_column && low_cardinality_key_column->nestedIsNullable()) + { + key_insert_holder = low_cardinality_key_column->cloneWithDefaultOnNull(); + key_insert_column = key_insert_holder.get(); + } + + auto result_key_column = result_map_type.getKeyType()->createColumn(); + auto result_value_column = result_map_type.getValueType()->createColumn(); + auto result_offsets_column = ColumnArray::ColumnOffsets::create(); + auto & result_offsets = result_offsets_column->getData(); + result_offsets.reserve(input_rows_count); + + ColumnUInt8::MutablePtr result_null_map; + PaddedPODArray * result_null_map_data = nullptr; + if (result_type->isNullable()) + { + result_null_map = ColumnUInt8::create(input_rows_count, 0); + result_null_map_data = &result_null_map->getData(); + } + + size_t previous_entry_offset = 0; + size_t result_offset = 0; + for (size_t row = 0; row < input_rows_count; ++row) + { + const auto current_entry_offset = entries_offsets[row]; + /* NULL array: emit a NULL map with no entries. */ + if (input_null_map && (*input_null_map)[row]) + { + appendNullMap(result_null_map_data, row, result_offsets, result_offset); + previous_entry_offset = current_entry_offset; + continue; + } + /* An array holding a NULL entry also becomes a NULL map. */ + if (hasNullEntry(entry_null_map, previous_entry_offset, current_entry_offset)) + { + appendNullMap(result_null_map_data, row, result_offsets, result_offset); + previous_entry_offset = current_entry_offset; + continue; + } + + auto selected_entries = + selectEntriesForRow(key_column, previous_entry_offset, current_entry_offset); + appendSelectedEntries( + selected_entries, + *key_insert_column, + value_column, + *result_key_column, + *result_value_column, + result_offset); + + result_offsets.push_back(result_offset); + previous_entry_offset = current_entry_offset; + } + + auto nested_column = ColumnArray::create( + ColumnTuple::create(Columns{std::move(result_key_column), std::move(result_value_column)}), + std::move(result_offsets_column)); + auto result_column = ColumnMap::create(std::move(nested_column)); + if 
(result_type->isNullable()) + return ColumnNullable::create(std::move(result_column), std::move(result_null_map)); + return result_column; + } + +private: + struct UInt128Hash + { + size_t operator()(const UInt128 & value) const + { + return std::hash{}(value.items[0]) + ^ (std::hash{}(value.items[1]) << 1); + } + }; + /* (first, second) are row indices into the flattened entry columns: first is where the key is read, second where the value is read. */ + using SelectedEntry = std::pair; + using SelectedEntries = std::vector; + /* Returns true if any entry in [previous_entry_offset, current_entry_offset) is marked NULL; false when there is no entry null map. */ + static bool hasNullEntry( + const PaddedPODArray * entry_null_map, + size_t previous_entry_offset, + size_t current_entry_offset) + { + if (!entry_null_map) + return false; + + for (size_t entry = previous_entry_offset; entry < current_entry_offset; ++entry) + { + if ((*entry_null_map)[entry]) + return true; + } + return false; + } + /* Flags row as NULL when a result null map exists and closes the row at the unchanged result_offset (no entries added). */ + static void appendNullMap( + PaddedPODArray * result_null_map_data, + size_t row, + ColumnArray::Offsets & result_offsets, + size_t result_offset) + { + if (result_null_map_data) + (*result_null_map_data)[row] = 1; + result_offsets.push_back(result_offset); + } + /* Walks one row's entries in order and returns the (key row, value row) pairs to emit. Throws BAD_ARGUMENTS for a NULL key and, unless last_win, for a duplicate key. Duplicates are looked up by 128-bit SipHash and confirmed with compareAt. */ + static SelectedEntries selectEntriesForRow( + const IColumn & key_column, + size_t previous_entry_offset, + size_t current_entry_offset) + { + SelectedEntries selected_entries; + selected_entries.reserve(current_entry_offset - previous_entry_offset); + + std::unordered_map first_selected_index_by_hash; + first_selected_index_by_hash.reserve(current_entry_offset - previous_entry_offset); + std::unordered_map, UInt128Hash> + collision_selected_indices_by_hash; + /* NOTE(review): a duplicate is only recognised when the hashes match, so keys that compareAt reports equal but updateHashWithValue hashes differently would both be kept - confirm this cannot occur for supported key types (possibly 0.0 vs -0.0). */ + for (size_t entry = previous_entry_offset; entry < current_entry_offset; ++entry) + { + if (key_column.isNullAt(entry)) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Cannot use NULL as map key in function {}", name); + + SipHash hash_function; + key_column.updateHashWithValue(entry, hash_function); + const UInt128 hash = hash_function.get128(); + + bool has_duplicate_key = false; + size_t duplicate_selected_index = 0; + const auto first_selected_index_it = first_selected_index_by_hash.find(hash); + if 
(first_selected_index_it != first_selected_index_by_hash.end()) + { + const auto first_selected_index = first_selected_index_it->second; + if (key_column.compareAt( + entry, + selected_entries[first_selected_index].first, + key_column, + 1) == 0) + { + has_duplicate_key = true; + duplicate_selected_index = first_selected_index; + } + else + { + const auto collision_selected_indices_it = + collision_selected_indices_by_hash.find(hash); + if (collision_selected_indices_it != collision_selected_indices_by_hash.end()) + { + for (const auto selected_index : collision_selected_indices_it->second) + { + if (key_column.compareAt( + entry, + selected_entries[selected_index].first, + key_column, + 1) == 0) + { + has_duplicate_key = true; + duplicate_selected_index = selected_index; + break; + } + } + } + } + } + /* On a duplicate, last_win keeps the first key row but redirects its value to this entry; otherwise the function throws. */ + if (has_duplicate_key) + { + if constexpr (last_win) + { + selected_entries[duplicate_selected_index].second = entry; + continue; + } + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Duplicate map key is found in function {}", + name); + } + /* New key: index it under its hash, or in the collision list when another key already owns that hash. */ + if (first_selected_index_it == first_selected_index_by_hash.end()) + { + first_selected_index_by_hash.emplace(hash, selected_entries.size()); + } + else + { + collision_selected_indices_by_hash[hash].push_back(selected_entries.size()); + } + selected_entries.emplace_back(entry, entry); + } + + return selected_entries; + } + /* Copies each selected key and value into the result columns and bumps result_offset once per entry. */ + static void appendSelectedEntries( + const SelectedEntries & selected_entries, + const IColumn & key_insert_column, + const IColumn & value_column, + IColumn & result_key_column, + IColumn & result_value_column, + size_t & result_offset) + { + for (const auto & selected_entry : selected_entries) + { + result_key_column.insertFrom(key_insert_column, selected_entry.first); + result_value_column.insertFrom(value_column, selected_entry.second); + ++result_offset; + } + } +}; +/* Registers two instantiations (template arguments not visible here; presumably last_win = false and true). */ +REGISTER_FUNCTION(SparkMapFromEntries) +{ + factory.registerFunction>(); + factory.registerFunction>(); +} + +} diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp index c1b7dcdb2eb..ae760543bb8 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp @@ -173,6 +173,8 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(GetMapValue, get_map_value, arrayElementO REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapKeys, map_keys, mapKeys); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapValues, map_values, mapValues); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapFromArrays, map_from_arrays, mapFromArrays); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapFromEntries, map_from_entries, sparkMapFromEntries); /* throws on duplicate keys */ +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapFromEntriesLastWin, map_from_entries_last_win, sparkMapFromEntriesLastWin); /* last duplicate wins */ // json functions REGISTER_COMMON_SCALAR_FUNCTION_PARSER(FlattenJsonStringOnRequired, flattenJSONStringOnRequired, flattenJSONStringOnRequired); diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 7969d305025..4f322a66d88 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -339,7 +339,8 @@ object ExpressionConverter extends SQLConfHelper with Logging { BackendsApiManager.getSparkPlanExecApiInstance.genMapFromEntriesTransformer( substraitExprName, replaceWithExpressionTransformer0(m.child, attributeSeq, expressionsMap), - m) + m + ) case e: Explode => ExplodeTransformer( substraitExprName, diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index 
ca05f0ade1a..dd5c3a188b6 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -308,6 +308,7 @@ object ExpressionNames { final val TRANSFORM_KEYS = "transform_keys" final val TRANSFORM_VALUES = "transform_values" final val MAP_FROM_ENTRIES = "map_from_entries" + final val MAP_FROM_ENTRIES_LAST_WIN = "map_from_entries_last_win" /* LAST_WIN dedup variant */ final val STR_TO_MAP = "str_to_map" final val MAP_FILTER = "map_filter" final val MAP_CONTAINS_KEY = "map_contains_key"