Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.DistributeType;
import org.apache.doris.nereids.trees.plans.Plan;
Expand Down Expand Up @@ -74,16 +73,12 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -509,29 +504,34 @@ public Void visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg
addRequestPropertyToChildren(PhysicalProperties.GATHER);
return null;
}
List<ExprId> groupByExprIds = agg.getGroupByExpressions().stream()
.filter(SlotReference.class::isInstance)
.map(SlotReference.class::cast)
.map(SlotReference::getExprId)
.collect(Collectors.toList());
List<ExprId> groupByExprIds = new ArrayList<>();
Map<ExprId, Expression> groupByExprIdToExpr = Maps.newHashMap();
for (Expression groupByExpr : agg.getGroupByExpressions()) {
if (groupByExpr instanceof SlotReference) {
ExprId groupByExprId = ((SlotReference) groupByExpr).getExprId();
groupByExprIds.add(groupByExprId);
groupByExprIdToExpr.put(groupByExprId, groupByExpr);
}
}
DistributionSpec parentDist = requestPropertyFromParent.getDistributionSpec();
if (parentDist instanceof DistributionSpecHash) {
DistributionSpecHash distributionRequestFromParent = (DistributionSpecHash) parentDist;
List<ExprId> parentHashExprIds = distributionRequestFromParent.getOrderedShuffledColumns();
Set<ExprId> intersectIdSet = Sets.intersection(new HashSet<>(parentHashExprIds),
new HashSet<>(groupByExprIds));
if (!intersectIdSet.isEmpty() && intersectIdSet.size() < groupByExprIds.size()) {
List<ExprId> intersectIdList = new ArrayList<>();
for (ExprId exprId : parentHashExprIds) {
if (!intersectIdSet.contains(exprId)) {
continue;
}
intersectIdList.add(exprId);
}
if (shouldUseParent(intersectIdList, agg, context)) {
addRequestPropertyToChildren(
PhysicalProperties.createHash(intersectIdList, ShuffleType.REQUIRE));
List<ExprId> parentHashExprIdsInGroupBy = new ArrayList<>();
List<Expression> parentHashExprsInGroupBy = new ArrayList<>();
for (ExprId parentHashExprId : parentHashExprIds) {
Expression parentHashExpr = groupByExprIdToExpr.get(parentHashExprId);
if (parentHashExpr == null) {
continue;
}
parentHashExprIdsInGroupBy.add(parentHashExprId);
parentHashExprsInGroupBy.add(parentHashExpr);
}
if (!parentHashExprIdsInGroupBy.isEmpty()
&& parentHashExprIdsInGroupBy.size() < groupByExprIds.size()
&& shouldUseParent(parentHashExprsInGroupBy, agg, context)) {
addRequestPropertyToChildren(
PhysicalProperties.createHash(parentHashExprIdsInGroupBy, ShuffleType.REQUIRE));
}
}
addRequestPropertyToChildren(PhysicalProperties.createHash(groupByExprIds, ShuffleType.REQUIRE));
Expand All @@ -547,35 +547,24 @@ public Void visitPhysicalBucketedHashAggregate(
return null;
}

private boolean shouldUseParent(List<ExprId> parentHashExprIds, PhysicalHashAggregate<? extends Plan> agg,
private boolean shouldUseParent(List<Expression> parentHashExprs, PhysicalHashAggregate<? extends Plan> agg,
PlanContext context) {
if (!context.getConnectContext().getSessionVariable().aggShuffleUseParentKey) {
return false;
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call — if there is no group expression at all, we cannot derive stats and should not gamble on the parent subset key.

Optional<GroupExpression> groupExpression = agg.getGroupExpression();
if (!groupExpression.isPresent()) {
return true;
return false;
}
if (agg.hasSourceRepeat()) {
return false;
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core fix — previously returning true here meant the optimizer would use the narrower parent hash key even with no stats at all, which could lead to severe data skew and OOM. Returning false (fall through to the full group-by key) is the conservative and correct choice.

Consider adding a brief comment here explaining the rationale, e.g.:

// Without stats we cannot assess whether the parent subset key has enough
// NDV to avoid skew; fall back to the safe full group-by distribution.

Statistics aggChildStats = groupExpression.get().childStatistics(0);
if (aggChildStats == null) {
return true;
}
List<Slot> aggChildOutput = agg.child().getOutput();
Map<ExprId, Slot> exprIdSlotMap = new HashMap<>();
for (Slot slot : aggChildOutput) {
exprIdSlotMap.put(slot.getExprId(), slot);
}
List<Expression> parentHashExprs = new ArrayList<>(parentHashExprIds.size());
for (ExprId exprId : parentHashExprIds) {
if (exprIdSlotMap.containsKey(exprId)) {
parentHashExprs.add(exprIdSlotMap.get(exprId));
}
return false;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same pattern as above — hasUnknownStatistics returning true now correctly causes us to skip the parent subset optimization instead of blindly trying it.

}
if (AggregateUtils.hasUnknownStatistics(parentHashExprs, aggChildStats)) {
return true;
return false;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: NDV exactly equal to LOW_NDV_THRESHOLD (1024) is treated as insufficient — this is consistent with how SplitAggMultiPhase also uses > (strictly greater), so the threshold boundary is uniform across callers. 👍

}
double combinedNdv = StatsCalculator.estimateGroupByRowCount(parentHashExprs, aggChildStats);
return combinedNdv > AggregateUtils.LOW_NDV_THRESHOLD;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.AggregateUtils;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.ColumnStatisticBuilder;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -411,7 +415,7 @@ void testAggregateWithAggShuffleUseParentKeyDisabled() {
}

@Test
void testAggregateWithAggShuffleUseParentKeyEnabled() {
void testAggregateWithAggShuffleUseParentKeyEnabledAndUnknownStats() {
// Create ConnectContext with aggShuffleUseParentKey = true (default value)
ConnectContext testConnectContext = MemoTestUtils.createConnectContext();
testConnectContext.getSessionVariable().aggShuffleUseParentKey = true;
Expand Down Expand Up @@ -446,14 +450,108 @@ public org.apache.doris.statistics.Statistics childStatistics(int idx) {
List<List<PhysicalProperties>> actual
= requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);

// When aggShuffleUseParentKey is true, shouldUseParent may return true
// If shouldUseParent returns true, it will add parent key (key1) first, then all groupByExpressions (key1, key2)
Assertions.assertEquals(2, actual.size(), "Should have at least one property request");
List<List<PhysicalProperties>> expected = Lists.newArrayList();
expected.add(Lists.newArrayList(PhysicalProperties.createHash(
Lists.newArrayList(key1.getExprId(), key2.getExprId()), ShuffleType.REQUIRE)));
Assertions.assertEquals(expected, actual);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The childStatistics override returns null here, which exercises the "aggChildStats == null → return false" path. The test now correctly expects only the full group-by key distribution. Consider updating the comment above to reflect the new behavior (e.g., // When stats are null, parent subset should NOT be used).


@Test
void testAggregateWithAggShuffleUseParentKeyEnabledAndLowNdvStats() {
ConnectContext testConnectContext = MemoTestUtils.createConnectContext();
testConnectContext.getSessionVariable().aggShuffleUseParentKey = true;
testConnectContext.getSessionVariable().setBeNumberForTest(3);

SlotReference key1 = new SlotReference(new ExprId(0), "col1", IntegerType.INSTANCE, true, ImmutableList.of());
SlotReference key2 = new SlotReference(new ExprId(1), "col2", IntegerType.INSTANCE, true, ImmutableList.of());
GroupPlan childPlan = new GroupPlan(new Group(GroupId.createGenerator().getNextId(),
new GroupExpression(new LogicalOneRowRelation(new RelationId(6), ImmutableList.of(key1, key2)))
.getPlan().getLogicalProperties()));
PhysicalHashAggregate<GroupPlan> aggregate = new PhysicalHashAggregate<>(
Lists.newArrayList(key1, key2),
Lists.newArrayList(key1, key2),
new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT),
true,
logicalProperties,
false,
childPlan
);
Statistics childStats = new StatisticsBuilder()
.setRowCount(10000)
.putColumnStatistics(key1,
new ColumnStatisticBuilder(10000).setNdv(AggregateUtils.LOW_NDV_THRESHOLD).build())
.build();
GroupExpression groupExpression = new GroupExpression(aggregate) {
@Override
public Statistics childStatistics(int idx) {
return childStats;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice boundary test — setNdv(AggregateUtils.LOW_NDV_THRESHOLD) (1024) and correctly expecting the parent key NOT to be used, since combinedNdv > LOW_NDV_THRESHOLD is false when NDV is exactly at the threshold.

}
};
new Group(null, groupExpression, null);

PhysicalProperties parentProperties = PhysicalProperties.createHash(
Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);

Mockito.when(jobContext.getRequiredProperties()).thenReturn(parentProperties);

RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(testConnectContext, jobContext);
List<List<PhysicalProperties>> actual
= requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);

List<List<PhysicalProperties>> expected = Lists.newArrayList();
expected.add(Lists.newArrayList(PhysicalProperties.createHash(
Lists.newArrayList(key1.getExprId(), key2.getExprId()), ShuffleType.REQUIRE)));
Assertions.assertEquals(expected, actual);
}

/**
 * Verifies the child requests produced for a global aggregate grouped by (key1, key2)
 * when the parent requires a hash distribution on key1 only, aggShuffleUseParentKey is
 * enabled, and the child statistics report an NDV of 2000 for key1 over 10000 rows.
 *
 * <p>Expected: exactly two alternative requests — hash on the parent subset key (key1)
 * and hash on the full group-by key (key1, key2).
 */
@Test
void testAggregateWithAggShuffleUseParentKeyEnabledAndHighNdvStats() {
    ConnectContext testConnectContext = MemoTestUtils.createConnectContext();
    testConnectContext.getSessionVariable().aggShuffleUseParentKey = true;
    testConnectContext.getSessionVariable().setBeNumberForTest(3);

    SlotReference key1 = new SlotReference(new ExprId(0), "col1", IntegerType.INSTANCE, true, ImmutableList.of());
    SlotReference key2 = new SlotReference(new ExprId(1), "col2", IntegerType.INSTANCE, true, ImmutableList.of());
    GroupPlan childPlan = new GroupPlan(new Group(GroupId.createGenerator().getNextId(),
            new GroupExpression(new LogicalOneRowRelation(new RelationId(6), ImmutableList.of(key1, key2)))
                    .getPlan().getLogicalProperties()));
    PhysicalHashAggregate<GroupPlan> aggregate = new PhysicalHashAggregate<>(
            Lists.newArrayList(key1, key2),
            Lists.newArrayList(key1, key2),
            new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT),
            true,
            logicalProperties,
            false,
            childPlan
    );
    // Only key1 (the parent's hash key) gets column statistics; NDV is set to 2000.
    Statistics childStats = new StatisticsBuilder()
            .setRowCount(10000)
            .putColumnStatistics(key1, new ColumnStatisticBuilder(10000).setNdv(2000).build())
            .build();
    // Override childStatistics so the deriver sees the hand-built statistics above.
    GroupExpression groupExpression = new GroupExpression(aggregate) {
        @Override
        public Statistics childStatistics(int idx) {
            return childStats;
        }
    };
    new Group(null, groupExpression, null);

    PhysicalProperties parentProperties = PhysicalProperties.createHash(
            Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);

    Mockito.when(jobContext.getRequiredProperties()).thenReturn(parentProperties);

    RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(testConnectContext, jobContext);
    List<List<PhysicalProperties>> actual
            = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);

    PhysicalProperties parentProp = PhysicalProperties.createHash(
            Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);
    PhysicalProperties aggProp = PhysicalProperties.createHash(
            Lists.newArrayList(key1.getExprId(), key2.getExprId()), ShuffleType.REQUIRE);
    // Check the count first, then each alternative separately, so a failure
    // pinpoints which request is missing.
    Assertions.assertEquals(2, actual.size());
    Assertions.assertTrue(actual.contains(ImmutableList.of(parentProp)));
    Assertions.assertTrue(actual.contains(ImmutableList.of(aggProp)));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,16 @@ void countMultiColumnsWithGby() {
physicalHashJoin(
physicalProject(
physicalHashAggregate(
physicalHashAggregate(
physicalDistribute(any())))),
physicalDistribute(
physicalHashAggregate(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The physicalDistribute wrappers are now expected in the plan shape because shouldUseParent no longer returns true when stats are unknown (which is the case in this unit test). Previously, the parent subset key was blindly adopted, which could eliminate the distribute node. This test change correctly reflects the stricter stats gate — the full group-by key distribution is used, and the distribute is preserved.

This is an intended side effect of the fix, but worth confirming: is the plan shape here what you would expect to see in production queries after this change?

physicalHashAggregate(
physicalDistribute(any())))))),
physicalProject(
physicalHashAggregate(
physicalHashAggregate(
physicalDistribute(any()))))
physicalDistribute(
physicalHashAggregate(
physicalHashAggregate(
physicalDistribute(any()))))))
).when(join ->
join.getJoinType() == JoinType.INNER_JOIN && join.getHashJoinConjuncts().get(0) instanceof NullSafeEqual
)
Expand Down
Loading