diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java index cca16d06639b02..cf136c11097f33 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java @@ -19,19 +19,20 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * merge one row relation into union, for easy to compute physical properties @@ -56,15 +57,13 @@ public Rule build() { } else { ImmutableList.Builder constantExprs = new Builder<>(); List projects = ((LogicalOneRowRelation) child).getProjects(); - for (int j = 0; j < projects.size(); j++) { - NamedExpression project = projects.get(j); - DataType targetType = u.getOutput().get(j).getDataType(); - if (project.getDataType().equals(targetType)) { - constantExprs.add(project); - } else { - constantExprs.add((NamedExpression) project.withChildren( - TypeCoercionUtils.castIfNotSameType(project.child(0), targetType))); - } + Map replaceMap = new HashMap<>(); + for (NamedExpression expr : projects) { + replaceMap.put(expr.toSlot(), expr); + } + for (Expression regularChildOutput : u.getRegularChildOutput(i)) { + constantExprs.add((NamedExpression) ExpressionUtils.replace( + regularChildOutput, replaceMap)); } constantExprsList.add(constantExprs.build()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnionTest.java new file mode 100644 index 00000000000000..99a18c75059499 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnionTest.java @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; +import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; + +public class MergeOneRowRelationIntoUnionTest { + + @Test + public void testMergeUsesRegularChildOutputToBuildConstantExprs() { + Alias firstProject = new Alias(new ExprId(1), new IntegerLiteral(10), "first_col"); + Alias secondProject = new Alias(new ExprId(2), new IntegerLiteral(20), "second_col"); + SlotReference secondProjectSlot = (SlotReference) secondProject.toSlot(); + LogicalOneRowRelation oneRowRelation = new LogicalOneRowRelation( + new RelationId(1), ImmutableList.of(firstProject, secondProject)); + + SlotReference unionOutput = new SlotReference(new ExprId(10), "second_col", + IntegerType.INSTANCE, false, ImmutableList.of()); + LogicalUnion union = new LogicalUnion(Qualifier.ALL, + ImmutableList.of(unionOutput), + ImmutableList.of(ImmutableList.of(secondProjectSlot)), + ImmutableList.of(), + false, + ImmutableList.of(oneRowRelation)); + + Plan rewritten = PlanChecker.from(MemoTestUtils.createConnectContext(), union) + .applyTopDown(new MergeOneRowRelationIntoUnion()) + .getPlan(); + + Assertions.assertInstanceOf(LogicalUnion.class, rewritten); + LogicalUnion rewrittenUnion = (LogicalUnion) rewritten; + Assertions.assertEquals(0, rewrittenUnion.children().size()); + Assertions.assertEquals(1, rewrittenUnion.getConstantExprsList().size()); + + List constantExprs = rewrittenUnion.getConstantExprsList().get(0); + Assertions.assertEquals(1, constantExprs.size()); + Assertions.assertEquals(secondProject.getExprId(), constantExprs.get(0).getExprId()); + Assertions.assertInstanceOf(IntegerLiteral.class, constantExprs.get(0).child(0)); + Assertions.assertEquals(20, ((IntegerLiteral) constantExprs.get(0).child(0)).getValue()); + } +} diff --git a/regression-test/data/query_p0/compress_materialize/compress_materialize.out b/regression-test/data/query_p0/compress_materialize/compress_materialize.out index 79a06a9640ab21..624962aef1f632 100644 --- a/regression-test/data/query_p0/compress_materialize/compress_materialize.out +++ b/regression-test/data/query_p0/compress_materialize/compress_materialize.out @@ -53,21 +53,11 @@ a 1 δΈ­ 8 bb 3 --- !explain_sort_agg -- -cost = 19.563333333333333 -PhysicalResultSink[294] ( outputExprs=[v1#1] ) -+--PhysicalProject[289]@5 ( stats=1, projects=[v1#1] ) - +--PhysicalQuickSort[284]@4 ( stats=1, orderKeys=[encode_as_bigint(v1)#4 asc null first], phase=MERGE_SORT ) - +--PhysicalDistribute[279]@7 ( stats=1, distributionSpec=DistributionSpecGather ) - +--PhysicalQuickSort[274]@7 ( stats=1, orderKeys=[encode_as_bigint(v1)#4 asc null first], phase=LOCAL_SORT ) - +--PhysicalProject[269]@3 ( stats=1, projects=[decode_as_varchar(encode_as_bigint(v1)#3) AS `v1`#1, encode_as_bigint(decode_as_varchar(encode_as_bigint(v1)#3)) AS `encode_as_bigint(decode_as_varchar(encode_as_bigint(v1)))`#4], multi_proj=l0([encode_as_bigint(v1)#3, decode_as_varchar(encode_as_bigint(v1)#3) AS `decode_as_varchar(encode_as_bigint(v1))`#5])l1([decode_as_varchar(encode_as_bigint(v1))#5 AS `v1`#1, encode_as_bigint(decode_as_varchar(encode_as_bigint(v1))#5) AS `encode_as_bigint(decode_as_varchar(encode_as_bigint(v1)))`#4]) ) - +--PhysicalHashAggregate[264]@2 ( stats=1, aggPhase=GLOBAL, aggMode=BUFFER_TO_RESULT, maybeUseStreaming=false, groupByExpr=[encode_as_bigint(v1)#3], outputExpr=[encode_as_bigint(v1)#3], partitionExpr=Optional[[encode_as_bigint(v1)#3]], topnFilter=false, topnPushDown=false ) - +--PhysicalDistribute[259]@8 ( stats=1, distributionSpec=DistributionSpecHash ( orderedShuffledColumns=[3], shuffleType=EXECUTION_BUCKETED, tableId=-1, selectedIndexId=-1, partitionIds=[], equivalenceExprIds=[[3]], exprIdToEquivalenceSet={3=0} ) ) - +--PhysicalHashAggregate[254]@8 ( stats=1, aggPhase=LOCAL, aggMode=INPUT_TO_BUFFER, maybeUseStreaming=true, groupByExpr=[encode_as_bigint(v1)#3], outputExpr=[encode_as_bigint(v1)#3], partitionExpr=Optional[[encode_as_bigint(v1)#3]], topnFilter=false, topnPushDown=false ) - +--PhysicalProject[249]@1 ( stats=1, projects=[encode_as_bigint(v1#1) AS `encode_as_bigint(v1)`#3] ) - +--PhysicalOlapScan[t1]@0 ( stats=1 ) +-- !const_union_group_by -- +alpha 1 +beta 1 --- !exec_sort_agg -- -a -b +-- !const_union_project_order -- +alpha +beta diff --git a/regression-test/suites/query_p0/compress_materialize/compress_materialize.groovy b/regression-test/suites/query_p0/compress_materialize/compress_materialize.groovy index 3e5266cd420332..a94b33c99f0b55 100644 --- a/regression-test/suites/query_p0/compress_materialize/compress_materialize.groovy +++ b/regression-test/suites/query_p0/compress_materialize/compress_materialize.groovy @@ -194,5 +194,26 @@ suite("compress_materialize") { qt_sort "select * from compressSort order by k desc, v"; qt_sort "select * from compressSort order by k desc nulls last"; qt_sort "select * from compressSort order by k desc nulls last, v limit 3"; -} + order_qt_const_union_group_by """ + SELECT /*+ SET_VAR(enable_compress_materialize=true) */ + v, + COUNT(*) AS c + FROM ( + SELECT CAST('alpha' AS CHAR(6)) AS v + UNION ALL + SELECT CAST('beta' AS CHAR(6)) + ) u + GROUP BY v + """ + + order_qt_const_union_project_order """ + SELECT v + FROM ( + SELECT CAST('unused' AS CHAR(6)) AS k, CAST('alpha' AS CHAR(6)) AS v + UNION ALL + SELECT CAST('unused' AS CHAR(6)), CAST('beta' AS CHAR(6)) + ) u + GROUP BY v + """ +}