diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java index cca16d06639b02..7208aeba6e9759 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java @@ -19,19 +19,20 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * merge one row relation into union, for easy to compute physical properties @@ -56,15 +57,13 @@ public Rule build() { } else { ImmutableList.Builder constantExprs = new Builder<>(); List projects = ((LogicalOneRowRelation) child).getProjects(); - for (int j = 0; j < projects.size(); j++) { - NamedExpression project = projects.get(j); - DataType targetType = u.getOutput().get(j).getDataType(); - if (project.getDataType().equals(targetType)) { - constantExprs.add(project); - } else { - constantExprs.add((NamedExpression) project.withChildren( - TypeCoercionUtils.castIfNotSameType(project.child(0), targetType))); - } + Map replaceMap = new HashMap<>(); + for (NamedExpression project : projects) { + replaceMap.put(project.toSlot(), project); + } + for (Expression regularChildOutput : u.getRegularChildOutput(i)) { + constantExprs.add((NamedExpression) ExpressionUtils.replace( + regularChildOutput, replaceMap)); } constantExprsList.add(constantExprs.build()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownLimitDistinctThroughUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownLimitDistinctThroughUnion.java index df3069105d2924..1a136673a1df04 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownLimitDistinctThroughUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownLimitDistinctThroughUnion.java @@ -68,11 +68,11 @@ public List buildRules() { LogicalUnion union = agg.child(); List newChildren = new ArrayList<>(); - for (Plan child : union.children()) { + for (int childIdx = 0; childIdx < union.arity(); ++childIdx) { Map replaceMap = new HashMap<>(); for (int i = 0; i < union.getOutputs().size(); ++i) { NamedExpression output = union.getOutputs().get(i); - replaceMap.put(output, child.getOutput().get(i)); + replaceMap.put(output, union.getRegularChildOutput(childIdx).get(i)); } List newGroupBy = agg.getGroupByExpressions().stream() @@ -82,7 +82,8 @@ public List buildRules() { .map(expr -> ExpressionUtils.replaceNameExpression(expr, replaceMap)) .collect(Collectors.toList()); - LogicalAggregate newAgg = new LogicalAggregate<>(newGroupBy, newOutputs, child); + LogicalAggregate newAgg = new LogicalAggregate<>( + newGroupBy, newOutputs, union.child(childIdx)); LogicalLimit newLimit = limit.withLimitChild(limit.getLimit() + limit.getOffset(), 0, newAgg); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnion.java index f6e944a7a914f0..cde42c7ad243d9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnion.java @@ -68,18 +68,18 @@ public List buildRules() { LogicalAggregate agg = topN.child(); LogicalUnion union = agg.child(); List newChildren = new ArrayList<>(); - for (Plan child : union.children()) { + for (int childIdx = 0; childIdx < union.arity(); ++childIdx) { Map replaceMap = new HashMap<>(); for (int i = 0; i < union.getOutputs().size(); ++i) { NamedExpression output = union.getOutputs().get(i); - replaceMap.put(output, child.getOutput().get(i)); + replaceMap.put(output, union.getRegularChildOutput(childIdx).get(i)); } List orderKeys = topN.getOrderKeys().stream() .map(orderKey -> orderKey.withExpression( ExpressionUtils.replace(orderKey.getExpr(), replaceMap))) .collect(ImmutableList.toImmutableList()); newChildren.add(new LogicalTopN<>(orderKeys, topN.getLimit() + topN.getOffset(), 0, - PlanUtils.distinct(child))); + PlanUtils.distinct(union.child(childIdx)))); } if (union.children().equals(newChildren)) { return null; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SetOperationOutputMappingTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SetOperationOutputMappingTest.java new file mode 100644 index 00000000000000..be2abce91289f3 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SetOperationOutputMappingTest.java @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; +import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; + +class SetOperationOutputMappingTest extends TestWithFeService implements MemoPatternMatchSupported { + + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + useDatabase("test"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + + createTable("CREATE TABLE set_operation_output_mapping_t (\n" + + " k int NULL,\n" + + " v int NULL\n" + + ") ENGINE=OLAP\n" + + "DISTRIBUTED BY HASH(k) BUCKETS 1\n" + + "PROPERTIES (\n" + + " \"replication_allocation\" = \"tag.location.default: 1\"\n" + + ");"); + } + + @Override + protected void runBeforeEach() throws Exception { + StatementScopeIdGenerator.clear(); + } + + @Test + void testMergeOneRowRelationUsesRegularChildOutput() { + Alias firstProject = new Alias(new ExprId(1), new IntegerLiteral(10), "first_col"); + Alias secondProject = new Alias(new ExprId(2), new IntegerLiteral(20), "second_col"); + SlotReference secondProjectSlot = (SlotReference) secondProject.toSlot(); + LogicalOneRowRelation oneRowRelation = new LogicalOneRowRelation( + new RelationId(1), ImmutableList.of(firstProject, secondProject)); + + SlotReference unionOutput = new SlotReference(new ExprId(10), "second_col", + IntegerType.INSTANCE, false, ImmutableList.of()); + LogicalUnion union = new LogicalUnion(Qualifier.ALL, + ImmutableList.of(unionOutput), + ImmutableList.of(ImmutableList.of(secondProjectSlot)), + ImmutableList.of(), + false, + ImmutableList.of(oneRowRelation)); + + Plan rewritten = PlanChecker.from(MemoTestUtils.createConnectContext(), union) + .applyTopDown(new MergeOneRowRelationIntoUnion()) + .getPlan(); + + Assertions.assertInstanceOf(LogicalUnion.class, rewritten); + LogicalUnion rewrittenUnion = (LogicalUnion) rewritten; + Assertions.assertEquals(0, rewrittenUnion.children().size()); + Assertions.assertEquals(1, rewrittenUnion.getConstantExprsList().size()); + + List constantExprs = rewrittenUnion.getConstantExprsList().get(0); + Assertions.assertEquals(1, constantExprs.size()); + Assertions.assertEquals(secondProject.getExprId(), constantExprs.get(0).getExprId()); + Assertions.assertInstanceOf(IntegerLiteral.class, constantExprs.get(0).child(0)); + Assertions.assertEquals(20, ((IntegerLiteral) constantExprs.get(0).child(0)).getValue()); + } + + @Test + void testPushDownTopNDistinctThroughUnionUsesRegularChildOutput() { + String sql = "SELECT *\n" + + "FROM (\n" + + " SELECT *\n" + + " FROM (\n" + + " SELECT k, v, v, ROW_NUMBER() OVER (ORDER BY k DESC)\n" + + " FROM set_operation_output_mapping_t\n" + + " ) u1\n" + + " UNION\n" + + " SELECT *\n" + + " FROM (\n" + + " SELECT k, v, v, ROW_NUMBER() OVER (ORDER BY k DESC)\n" + + " FROM set_operation_output_mapping_t\n" + + " ) u2\n" + + ") u\n" + + "ORDER BY 1\n" + + "LIMIT 10"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalUnion( + logicalTopN().when(topN -> topN.getLimit() == 10 && topN.getOffset() == 0), + logicalTopN().when(topN -> topN.getLimit() == 10 && topN.getOffset() == 0) + ) + ); + } + + @Test + void testPushDownLimitDistinctThroughUnionUsesRegularChildOutput() { + String sql = "SELECT DISTINCT *\n" + + "FROM (\n" + + " SELECT *\n" + + " FROM (\n" + + " SELECT k, v, v, ROW_NUMBER() OVER (ORDER BY k DESC)\n" + + " FROM set_operation_output_mapping_t\n" + + " ) u1\n" + + " UNION ALL\n" + + " SELECT *\n" + + " FROM (\n" + + " SELECT k, v, v, ROW_NUMBER() OVER (ORDER BY k DESC)\n" + + " FROM set_operation_output_mapping_t\n" + + " ) u2\n" + + ") u\n" + + "LIMIT 10"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalUnion( + logicalLimit().when(limit -> limit.getLimit() == 10 && limit.getOffset() == 0), + logicalLimit().when(limit -> limit.getLimit() == 10 && limit.getOffset() == 0) + ) + ); + } +} diff --git a/regression-test/suites/nereids_rules_p0/merge_one_row_relation/merge_one_row_relation_into_union.groovy b/regression-test/suites/nereids_rules_p0/merge_one_row_relation/merge_one_row_relation_into_union.groovy new file mode 100644 index 00000000000000..6e865d07c7df39 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/merge_one_row_relation/merge_one_row_relation_into_union.groovy @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("merge_one_row_relation_into_union") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + + sql """ + EXPLAIN SHAPE PLAN SELECT v + FROM ( + SELECT CAST('unused' AS CHAR(6)) AS k, CAST('alpha' AS CHAR(6)) AS v + UNION ALL + SELECT CAST('unused' AS CHAR(6)), CAST('beta' AS CHAR(6)) + ) u + GROUP BY v + """ +} diff --git a/regression-test/suites/nereids_rules_p0/push_down_limit_distinct/push_down_limit_distinct_through_union.groovy b/regression-test/suites/nereids_rules_p0/push_down_limit_distinct/push_down_limit_distinct_through_union.groovy new file mode 100644 index 00000000000000..69fbb8b8ae1a6a --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/push_down_limit_distinct/push_down_limit_distinct_through_union.groovy @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("push_down_limit_distinct_through_union") { + sql "set parallel_pipeline_task_num=2" + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql "set runtime_filter_mode=OFF" + sql "SET disable_join_reorder=true" + sql "set disable_nereids_rules=PRUNE_EMPTY_PARTITION" + + sql """ + DROP TABLE IF EXISTS push_down_limit_distinct_union_t; + """ + + sql """ + CREATE TABLE IF NOT EXISTS push_down_limit_distinct_union_t ( + `id` int NULL, + `score` int NULL + ) ENGINE = OLAP + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + + sql """ + explain shape plan select distinct * + from ( + select * + from ( + select id, score, score, row_number() over (order by id desc) + from push_down_limit_distinct_union_t + ) u1 + union all + select * + from ( + select id, score, score, row_number() over (order by id desc) + from push_down_limit_distinct_union_t + ) u2 + ) u + limit 10; + """ +} diff --git a/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_union.groovy b/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_union.groovy index 16042fbebf0190..0e81db6367d119 100644 --- a/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_union.groovy +++ b/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_union.groovy @@ -72,7 +72,26 @@ suite("push_down_top_n_distinct_through_union") { explain shape plan select * from ((select * from table2 t1 limit 5) union (select * from table2 t2 limit 5)) sub order by id limit 10; """ + sql """ + explain shape plan select * + from ( + select * + from ( + select id, score, score, row_number() over (order by id desc) + from table2 + ) u1 + union + select * + from ( + select id, score, score, row_number() over (order by id desc) + from table2 + ) u2 + ) u + order by 1 + limit 10; + """ + qt_push_down_topn_union_complex_conditions """ explain shape plan select * from (select * from table2 t1 where t1.score > 10 and t1.name = 'Test' union select * from table2 t2 where t2.id < 5 and t2.score < 20) sub order by id limit 10; """ -} \ No newline at end of file +}