diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md
index 27b6ad3b59..2c18cbd08d 100644
--- a/docs/spark_expressions_support.md
+++ b/docs/spark_expressions_support.md
@@ -272,11 +272,11 @@
 - [ ] element_at
 - [ ] map
 - [ ] map_concat
-- [ ] map_contains_key
+- [x] map_contains_key
 - [ ] map_entries
 - [ ] map_from_arrays
 - [ ] map_from_entries
-- [ ] map_keys
+- [x] map_keys
 - [ ] map_values
 - [ ] str_to_map
 - [ ] try_element_at
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 73b88ae935..bac3929dca 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -127,6 +127,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[MapEntries] -> CometMapEntries,
     classOf[MapValues] -> CometMapValues,
     classOf[MapFromArrays] -> CometMapFromArrays,
+    classOf[MapContainsKey] -> CometMapContainsKey,
     classOf[MapFromEntries] -> CometMapFromEntries)
 
   private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala
index 78b2180756..34e76215f3 100644
--- a/spark/src/main/scala/org/apache/comet/serde/maps.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala
@@ -90,6 +90,23 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
   }
 }
 
+object CometMapContainsKey extends CometExpressionSerde[MapContainsKey] {
+
+  override def convert(
+      expr: MapContainsKey,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    // Rewrite as array_has(map_keys(map), key)
+    val mapExpr = exprToProtoInternal(expr.left, inputs, binding)
+    val keyExpr = exprToProtoInternal(expr.right, inputs, binding)
+
+    val mapKeysExpr = scalarFunctionExprToProto("map_keys", mapExpr)
+
+    val mapContainsKeyExpr = scalarFunctionExprToProto("array_has", mapKeysExpr, keyExpr)
+    optExprWithInfo(mapContainsKeyExpr, expr, expr.children: _*)
+  }
+}
+
 object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from_entries") {
   val keyUnsupportedReason = "Using BinaryType as Map keys is not allowed in map_from_entries"
   val valueUnsupportedReason = "Using BinaryType as Map values is not allowed in map_from_entries"
diff --git a/spark/src/test/resources/sql-tests/expressions/map/map_contains_key.sql b/spark/src/test/resources/sql-tests/expressions/map/map_contains_key.sql
new file mode 100644
index 0000000000..7dc3ce436d
--- /dev/null
+++ b/spark/src/test/resources/sql-tests/expressions/map/map_contains_key.sql
@@ -0,0 +1,75 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+
+-- ConfigMatrix: parquet.enable.dictionary=false,true
+
+-- TODO: replace map_from_arrays with map whenever map is supported in Comet
+
+-- Basic integer key tests with map literals
+query
+select map_contains_key(map_from_arrays(array(1, 2), array('a', 'b')), 5)
+
+query
+select map_contains_key(map_from_arrays(array(1, 2), array('a', 'b')), 1)
+
+-- Decimal type coercion tests
+-- TODO: requires map cast to be supported in Comet
+query spark_answer_only
+select map_contains_key(map_from_arrays(array(1, 2), array('a', 'b')), 5.0)
+
+query spark_answer_only
+select map_contains_key(map_from_arrays(array(1, 2), array('a', 'b')), 1.0)
+
+query spark_answer_only
+select map_contains_key(map_from_arrays(array(1.0, 2), array('a', 'b')), 5)
+
+query spark_answer_only
+select map_contains_key(map_from_arrays(array(1.0, 2), array('a', 'b')), 1)
+
+-- Empty map tests
+-- TODO: requires casting from NullType to be supported in Comet
+query spark_answer_only
+select map_contains_key(map_from_arrays(array(), array()), 0)
+
+-- Test with table data
+statement
+CREATE TABLE test_map_contains_key(m map<string, int>) USING parquet
+
+statement
+INSERT INTO test_map_contains_key VALUES (map_from_arrays(array('a', 'b', 'c'), array(1, 2, 3))), (map_from_arrays(array('x'), array(10))), (map_from_arrays(array(), array())), (NULL)
+
+query
+SELECT map_contains_key(m, 'a') FROM test_map_contains_key
+
+query
+SELECT map_contains_key(m, 'x') FROM test_map_contains_key
+
+query
+SELECT map_contains_key(m, 'missing') FROM test_map_contains_key
+
+-- Test with integer key map
+statement
+CREATE TABLE test_map_int_key(m map<int, string>) USING parquet
+
+statement
+INSERT INTO test_map_int_key VALUES (map_from_arrays(array(1, 2), array('a', 'b'))), (map_from_arrays(array(), array())), (NULL)
+
+query
+SELECT map_contains_key(m, 1) FROM test_map_int_key
+
+query
+SELECT map_contains_key(m, 5) FROM test_map_int_key