diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnion.java index f6e944a7a914f0..5eff82f73e7af2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnion.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; @@ -63,23 +64,23 @@ public class PushDownTopNDistinctThroughUnion implements RewriteRuleFactory { public List buildRules() { return ImmutableList.of( logicalTopN(logicalAggregate(logicalUnion().when(union -> union.getQualifier() == Qualifier.ALL)) - .when(agg -> agg.isDistinct())) + .when(Aggregate::isDistinct)) .then(topN -> { LogicalAggregate agg = topN.child(); LogicalUnion union = agg.child(); List newChildren = new ArrayList<>(); - for (Plan child : union.children()) { + for (int i = 0; i < union.arity(); ++i) { Map replaceMap = new HashMap<>(); - for (int i = 0; i < union.getOutputs().size(); ++i) { - NamedExpression output = union.getOutputs().get(i); - replaceMap.put(output, child.getOutput().get(i)); + for (int j = 0; j < union.getOutputs().size(); ++j) { + NamedExpression output = union.getOutputs().get(j); + replaceMap.put(output, union.getRegularChildOutput(i).get(j)); } List orderKeys = topN.getOrderKeys().stream() .map(orderKey -> orderKey.withExpression( ExpressionUtils.replace(orderKey.getExpr(), replaceMap))) .collect(ImmutableList.toImmutableList()); newChildren.add(new LogicalTopN<>(orderKeys, topN.getLimit() + topN.getOffset(), 0, - PlanUtils.distinct(child))); + PlanUtils.distinct(union.child(i)))); } if (union.children().equals(newChildren)) { return null; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnionTest.java new file mode 100644 index 00000000000000..00cfaf1f89fb14 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughUnionTest.java @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +class PushDownTopNDistinctThroughUnionTest extends TestWithFeService implements MemoPatternMatchSupported { + + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + useDatabase("test"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + + createTable("CREATE TABLE push_down_topn_distinct_union_t (\n" + + " zzc int NULL,\n" + + " sy_zcjlr int NULL\n" + + ") ENGINE=OLAP\n" + + "DISTRIBUTED BY HASH(zzc) BUCKETS 1\n" + + "PROPERTIES (\n" + + " \"replication_allocation\" = \"tag.location.default: 1\"\n" + + ");"); + } + + @Override + protected void runBeforeEach() throws Exception { + StatementScopeIdGenerator.clear(); + } + + @Test + void testUnionDistinctWithDuplicateOutputAndWindow() { + String sql = "SELECT *\n" + + "FROM (\n" + + " SELECT *\n" + + " FROM (\n" + + " SELECT zzc, sy_zcjlr, sy_zcjlr, ROW_NUMBER() OVER (ORDER BY zzc DESC)\n" + + " FROM push_down_topn_distinct_union_t\n" + + " ) u1\n" + + " UNION\n" + + " SELECT *\n" + + " FROM (\n" + + " SELECT zzc, sy_zcjlr, sy_zcjlr, ROW_NUMBER() OVER (ORDER BY zzc DESC)\n" + + " FROM push_down_topn_distinct_union_t\n" + + " ) u2\n" + + ") u\n" + + "ORDER BY 1\n" + + "LIMIT 10"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalUnion( + logicalTopN().when(topN -> topN.getLimit() == 10 && topN.getOffset() == 0), + logicalTopN().when(topN -> topN.getLimit() == 10 && topN.getOffset() == 0) + ) + ); + } +} diff --git a/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_union.groovy b/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_union.groovy index 16042fbebf0190..bebf8507aba4e3 100644 --- a/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_union.groovy +++ b/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_union.groovy @@ -72,7 +72,27 @@ suite("push_down_top_n_distinct_through_union") { explain shape plan select * from ((select * from table2 t1 limit 5) union (select * from table2 t2 limit 5)) sub order by id limit 10; """ + // Make sure to use getRegularChildOutput to obtain the child output corresponding to the union output in PushDownTopNDistinctThroughUnion + sql """ + select * + from ( + select * + from ( + select id, score, score, row_number() over (order by id desc) + from table2 + ) u1 + union + select * + from ( + select id, score, score, row_number() over (order by id desc) + from table2 + ) u2 + ) u + order by 1 + limit 10; + """ + qt_push_down_topn_union_complex_conditions """ explain shape plan select * from (select * from table2 t1 where t1.score > 10 and t1.name = 'Test' union select * from table2 t2 where t2.id < 5 and t2.score < 20) sub order by id limit 10; """ -} \ No newline at end of file +}