diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala
index 2e217f6af0..960c09a168 100644
--- a/spark/src/main/scala/org/apache/comet/serde/maps.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala
@@ -22,7 +22,7 @@ package org.apache.comet.serde
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{ArrayType, MapType}
 
-import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
+import org.apache.comet.serde.QueryPlanSerde.{createUnaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
 
 object CometMapKeys extends CometExpressionSerde[MapKeys] {
 
@@ -84,8 +84,33 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
     val keyType = expr.left.dataType.asInstanceOf[ArrayType].elementType
     val valueType = expr.right.dataType.asInstanceOf[ArrayType].elementType
     val returnType = MapType(keyType = keyType, valueType = valueType)
-    val mapFromArraysExpr =
-      scalarFunctionExprToProtoWithReturnType("map", returnType, false, keysExpr, valuesExpr)
-    optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*)
+    for {
+      isNotNullExprProto <- keyIsNotNullExpr(expr, inputs, binding)
+      mapFromArraysExprProto <- scalarFunctionExprToProto("map", keysExpr, valuesExpr)
+      nullLiteralExprProto <- exprToProtoInternal(Literal(null, returnType), inputs, binding)
+    } yield {
+      val caseWhenExprProto = ExprOuterClass.CaseWhen
+        .newBuilder()
+        .addWhen(isNotNullExprProto)
+        .addThen(mapFromArraysExprProto)
+        .setElseExpr(nullLiteralExprProto)
+        .build()
+      ExprOuterClass.Expr
+        .newBuilder()
+        .setCaseWhen(caseWhenExprProto)
+        .build()
+    }
+  }
+
+  private def keyIsNotNullExpr(
+      expr: MapFromArrays,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    createUnaryExpr(
+      expr,
+      expr.left,
+      inputs,
+      binding,
+      (builder, keyExpr) => builder.setIsNotNull(keyExpr))
   }
 }
diff --git a/spark/src/test/resources/sql-tests/expressions/map/map_from_arrays.sql b/spark/src/test/resources/sql-tests/expressions/map/map_from_arrays.sql
index 5d6ac3d550..59cc55e852 100644
--- a/spark/src/test/resources/sql-tests/expressions/map/map_from_arrays.sql
+++ b/spark/src/test/resources/sql-tests/expressions/map/map_from_arrays.sql
@@ -26,9 +26,7 @@ INSERT INTO test_map_from_arrays VALUES (array('a', 'b', 'c'), array(1, 2, 3)),
 query spark_answer_only
 SELECT map_from_arrays(k, v) FROM test_map_from_arrays WHERE k IS NOT NULL
 
--- Comet bug: map_from_arrays(NULL, NULL) causes native crash "map key cannot be null"
--- https://github.com/apache/datafusion-comet/issues/3327
-query ignore(https://github.com/apache/datafusion-comet/issues/3327)
+query
 SELECT map_from_arrays(k, v) FROM test_map_from_arrays WHERE k IS NULL
 
 -- literal arguments