Skip to content

Commit b541eb7

Browse files
committed
Reshapes some code
1 parent c3cfc7d commit b541eb7

File tree

5 files changed

+170
-220
lines changed

5 files changed

+170
-220
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
9+
10+
import org.elasticsearch.common.util.Maps;
11+
import org.elasticsearch.xpack.esql.core.expression.Alias;
12+
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
13+
import org.elasticsearch.xpack.esql.core.expression.Expression;
14+
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
15+
import org.elasticsearch.xpack.esql.core.util.Holder;
16+
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
17+
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
18+
import org.elasticsearch.xpack.esql.plan.logical.Eval;
19+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
20+
import org.elasticsearch.xpack.esql.plan.logical.Project;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.Map;
25+
26+
/**
 * Common skeleton for optimizer rules that rewrite the aggregate list of an {@link Aggregate}.
 * <p>
 * {@link #rule(Aggregate)} drives the rewrite and assembles the resulting plan; subclasses supply the alias map
 * ({@link #buildAliases(Aggregate)}) and the handling of each aliased aggregate ({@link #processAlias}).
 */
abstract class AbstractAggregateDeduplicator extends OptimizerRules.OptimizerRule<Aggregate> {
27+
/** Registers the rule with the {@code UP} transform direction. */
protected AbstractAggregateDeduplicator() {
28+
super(OptimizerRules.TransformDirection.UP);
29+
}
30+
31+
/**
 * Rewrites the aggregate list of {@code aggregate}.
 * <p>
 * Every {@link Alias} in the aggregate list is handed to {@code processAlias}; any other entry is kept as an
 * aggregate and projected through its attribute. If no call flagged a change, the original node is returned
 * untouched. Otherwise the result is a {@link Project} over an optional {@link Eval} over the re-created
 * {@link Aggregate}.
 *
 * @param aggregate the node being optimized
 * @return {@code aggregate} itself when nothing changed, otherwise the rewritten plan
 */
@Override
32+
protected LogicalPlan rule(Aggregate aggregate) {
33+
AttributeMap<Expression> aliases = buildAliases(aggregate);
34+
// lookup table of the aggregate functions kept so far; filled in by processAlias
Map<AggregateFunction, Alias> rootAggs = Maps.newLinkedHashMapWithExpectedSize(aggregate.aggregates().size());
35+
// entries that stay in the rewritten Aggregate
List<NamedExpression> newAggs = new ArrayList<>();
36+
// output of the final Project, in the order the aggregates are visited
List<NamedExpression> newProjections = new ArrayList<>();
37+
// expressions to compute in an Eval placed between the Aggregate and the Project
List<Alias> evals = new ArrayList<>();
38+
// set to true by processAlias when the plan needs rewriting
Holder<Boolean> changed = new Holder<>(false);
39+
// single-element array so that processAlias can increment the counter in place
int[] counter = new int[] { 0 };
40+
for (NamedExpression agg : aggregate.aggregates()) {
41+
if (agg instanceof Alias as) {
42+
processAlias(as, aliases, rootAggs, newAggs, newProjections, evals, changed, counter);
43+
} else {
44+
// not an alias: keep it as is and expose it through its attribute
newAggs.add(agg);
45+
newProjections.add(agg.toAttribute());
46+
}
47+
}
48+
if (!changed.get()) {
49+
return aggregate;
50+
}
51+
LogicalPlan plan = aggregate.with(aggregate.child(), aggregate.groupings(), newAggs);
52+
if (!evals.isEmpty()) {
53+
plan = new Eval(aggregate.source(), plan, evals);
54+
}
55+
return new Project(aggregate.source(), plan, newProjections);
56+
}
57+
58+
/**
 * Builds the alias map that is passed to every {@code processAlias} call for the given aggregate.
 * Abstract: each subclass must implement it and decides which aliases are included.
 *
 * @param aggregate the node being optimized
 * @return the alias map handed to {@code processAlias}
 */
59+
protected abstract AttributeMap<Expression> buildAliases(Aggregate aggregate);
60+
61+
/**
 * Handles one aliased entry of the aggregate list. Abstract: each subclass must implement it.
 * All collection arguments are created by {@link #rule(Aggregate)} and are meant to be mutated by the
 * implementation.
 *
 * @param as             the aliased aggregate being visited
 * @param aliases        the map returned by {@link #buildAliases(Aggregate)}
 * @param rootAggs       aggregate functions kept so far, shared across all calls for one aggregate
 * @param newAggs        entries of the rewritten {@link Aggregate}
 * @param newProjections entries of the final {@link Project}
 * @param evals          entries of the intermediate {@link Eval}; no {@code Eval} is created when left empty
 * @param changed        must be set to {@code true} for the rewritten plan to be used
 * @param counter        single-element mutable counter shared across all calls for one aggregate
 */
62+
protected abstract void processAlias(
63+
Alias as,
64+
AttributeMap<Expression> aliases,
65+
Map<AggregateFunction, Alias> rootAggs,
66+
List<NamedExpression> newAggs,
67+
List<NamedExpression> newProjections,
68+
List<Alias> evals,
69+
Holder<Boolean> changed,
70+
int[] counter
71+
);
72+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/DeduplicateAggs.java

Lines changed: 29 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,15 @@
77

88
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
99

10-
import org.elasticsearch.common.util.Maps;
1110
import org.elasticsearch.xpack.esql.core.expression.Alias;
1211
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
1312
import org.elasticsearch.xpack.esql.core.expression.Expression;
1413
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
15-
import org.elasticsearch.xpack.esql.core.tree.Source;
1614
import org.elasticsearch.xpack.esql.core.util.Holder;
1715
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
1816
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
1917
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
20-
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
21-
import org.elasticsearch.xpack.esql.plan.logical.Project;
2218

23-
import java.util.ArrayList;
2419
import java.util.List;
2520
import java.util.Map;
2621

@@ -30,74 +25,41 @@
3025
* becomes
3126
* stats a = min(x), c = count(*) by g | eval b = a, d = c | keep a, b, c, d, g
3227
*/
33-
public final class DeduplicateAggs extends OptimizerRules.OptimizerRule<Aggregate> implements OptimizerRules.CoordinatorOnly {
34-
public DeduplicateAggs() {
35-
super(OptimizerRules.TransformDirection.UP);
36-
}
37-
28+
public final class DeduplicateAggs extends AbstractAggregateDeduplicator implements OptimizerRules.CoordinatorOnly {
3829
/**
 * Collects, for every {@link Alias} visited by {@code forEachExpressionUp}, a mapping from the alias attribute
 * to the aliased expression. Aliases over a {@code NonEvaluatableGroupingFunction} are left out.
 */
@Override
39-
protected LogicalPlan rule(Aggregate aggregate) {
40-
// an alias map for evaluatable grouping functions
41-
AttributeMap.Builder<Expression> aliasesBuilder = AttributeMap.builder();
30+
protected AttributeMap<Expression> buildAliases(Aggregate aggregate) {
31+
AttributeMap.Builder<Expression> builder = AttributeMap.builder();
4232
aggregate.forEachExpressionUp(Alias.class, a -> {
43-
if (a.child() instanceof GroupingFunction.NonEvaluatableGroupingFunction == false) {
44-
aliasesBuilder.put(a.toAttribute(), a.child());
33+
if (!(a.child() instanceof GroupingFunction.NonEvaluatableGroupingFunction)) {
34+
builder.put(a.toAttribute(), a.child());
4535
}
4636
});
47-
var aliases = aliasesBuilder.build();
48-
49-
// break down each aggregate into AggregateFunction and/or grouping key
50-
// preserve the projection at the end
51-
List<? extends NamedExpression> aggs = aggregate.aggregates();
52-
53-
// root/naked aggs
54-
Map<AggregateFunction, Alias> rootAggs = Maps.newLinkedHashMapWithExpectedSize(aggs.size());
55-
List<NamedExpression> newProjections = new ArrayList<>();
56-
// track the aggregate aggs (including grouping which is not an AggregateFunction)
57-
List<NamedExpression> newAggs = new ArrayList<>();
58-
59-
Holder<Boolean> changed = new Holder<>(false);
60-
61-
for (NamedExpression agg : aggs) {
62-
if (agg instanceof Alias as) {
63-
// use intermediate variable to mark child as final for lambda use
64-
Expression child = as.child();
65-
66-
// common case - handle duplicates
67-
if (child instanceof AggregateFunction af) {
68-
// canonical representation, with resolved aliases
69-
AggregateFunction canonical = (AggregateFunction) af.transformUp(e -> aliases.resolve(e, e));
37+
return builder.build();
38+
}
7039

71-
Alias found = rootAggs.get(canonical);
72-
// aggregate is new
73-
if (found == null) {
74-
rootAggs.put(canonical, as);
75-
newAggs.add(as);
76-
newProjections.add(as.toAttribute());
77-
}
78-
// agg already exists - preserve the current alias but point it to the existing agg
79-
// thus don't add it to the list of aggs as we don't want duplicated compute
80-
else {
81-
changed.set(true);
82-
newProjections.add(as.replaceChild(found.toAttribute()));
83-
}
84-
}
85-
}
86-
// not an alias (e.g. grouping field)
87-
else {
88-
newAggs.add(agg);
89-
newProjections.add(agg.toAttribute());
40+
/**
 * Deduplicates an aliased aggregate function.
 * <p>
 * The first occurrence of an aggregate function is kept and projected; a later occurrence that resolves to the
 * same key is not computed again and its alias is re-pointed, in the projection only, at the attribute of the
 * first occurrence.
 * <p>
 * The {@code evals} and {@code counter} parameters are not used by this rule.
 * <p>
 * NOTE(review): an alias whose child is not an {@link AggregateFunction} is added to neither {@code newAggs}
 * nor {@code newProjections}. The removed code behaved the same way; presumably such aliases have already been
 * rewritten before this rule runs - confirm.
 */
@Override
41+
protected void processAlias(
42+
Alias as,
43+
AttributeMap<Expression> aliases,
44+
Map<AggregateFunction, Alias> rootAggs,
45+
List<NamedExpression> newAggs,
46+
List<NamedExpression> newProjections,
47+
List<Alias> evals,
48+
Holder<Boolean> changed,
49+
int[] counter
50+
) {
51+
Expression child = as.child();
52+
if (child instanceof AggregateFunction af) {
53+
// lookup key: the function with every aliased attribute replaced by the expression it stands for
// NOTE(review): unlike ReplaceAggregateAggExpressionWithEval, af.canonical() is not applied first (same as
// the removed code) - confirm this difference is intentional
AggregateFunction canonical = (AggregateFunction) af.transformUp(e -> aliases.resolve(e, e));
54+
Alias found = rootAggs.get(canonical);
55+
if (found == null) {
56+
// first occurrence: keep the aggregate and expose it
rootAggs.put(canonical, as);
57+
newAggs.add(as);
58+
newProjections.add(as.toAttribute());
59+
} else {
60+
// duplicate: keep the alias in the projection only, pointing at the existing aggregate
changed.set(true);
61+
newProjections.add(as.replaceChild(found.toAttribute()));
9062
}
9163
}
92-
93-
LogicalPlan plan = aggregate;
94-
if (changed.get()) {
95-
Source source = aggregate.source();
96-
plan = aggregate.with(aggregate.child(), aggregate.groupings(), newAggs);
97-
// preserve initial projection
98-
plan = new Project(source, plan, newProjections);
99-
}
100-
101-
return plan;
10264
}
10365
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java

Lines changed: 46 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,16 @@
77

88
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
99

10-
import org.elasticsearch.common.util.Maps;
1110
import org.elasticsearch.xpack.esql.core.expression.Alias;
1211
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1312
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
1413
import org.elasticsearch.xpack.esql.core.expression.Expression;
1514
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
16-
import org.elasticsearch.xpack.esql.core.tree.Source;
1715
import org.elasticsearch.xpack.esql.core.util.Holder;
1816
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
1917
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
2018
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
21-
import org.elasticsearch.xpack.esql.plan.logical.Eval;
22-
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
23-
import org.elasticsearch.xpack.esql.plan.logical.Project;
2419

25-
import java.util.ArrayList;
2620
import java.util.HashMap;
2721
import java.util.List;
2822
import java.util.Map;
@@ -43,117 +37,61 @@
4337
* becomes
4438
* stats a = min(x), c = count(*) by g | eval b = a, d = c | keep a, b, c, d, g
4539
*/
46-
public final class ReplaceAggregateAggExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> {
47-
public ReplaceAggregateAggExpressionWithEval() {
48-
super(OptimizerRules.TransformDirection.UP);
49-
}
40+
public final class ReplaceAggregateAggExpressionWithEval extends AbstractAggregateDeduplicator {
5041

5142
/**
 * Collects, for every {@link Alias} visited by {@code forEachExpressionUp}, a mapping from the alias attribute
 * to the aliased expression. Aliases over a {@code NonEvaluatableGroupingFunction} are not added to the
 * returned map.
 */
@Override
52-
protected LogicalPlan rule(Aggregate aggregate) {
53-
// an alias map for evaluatable grouping functions
54-
AttributeMap.Builder<Expression> aliasesBuilder = AttributeMap.builder();
55-
// a function map for non-evaluatable grouping functions
56-
Map<GroupingFunction.NonEvaluatableGroupingFunction, Attribute> nonEvalGroupingAttributes = new HashMap<>(
57-
aggregate.groupings().size()
58-
);
43+
protected AttributeMap<Expression> buildAliases(Aggregate aggregate) {
44+
AttributeMap.Builder<Expression> builder = AttributeMap.builder();
45+
// NOTE(review): this map is populated below but is never read or returned by the added code; the removed
// rule() used it to replace NonEvaluatableGroupingFunction nodes with their attributes
Map<GroupingFunction.NonEvaluatableGroupingFunction, Attribute> nonEvalGroupingAttributes = new HashMap<>();
5946
aggregate.forEachExpressionUp(Alias.class, a -> {
60-
if (a.child() instanceof GroupingFunction.NonEvaluatableGroupingFunction groupingFunction) {
61-
nonEvalGroupingAttributes.put(groupingFunction, a.toAttribute());
47+
if (a.child() instanceof GroupingFunction.NonEvaluatableGroupingFunction gf) {
48+
nonEvalGroupingAttributes.put(gf, a.toAttribute());
6249
} else {
63-
aliasesBuilder.put(a.toAttribute(), a.child());
50+
builder.put(a.toAttribute(), a.child());
6451
}
6552
});
66-
var aliases = aliasesBuilder.build();
67-
68-
// break down each aggregate into AggregateFunction and/or grouping key
69-
// preserve the projection at the end
70-
List<? extends NamedExpression> aggs = aggregate.aggregates();
71-
72-
// root/naked aggs
73-
Map<AggregateFunction, Alias> rootAggs = Maps.newLinkedHashMapWithExpectedSize(aggs.size());
74-
// evals (original expression relying on multiple aggs)
75-
List<Alias> newEvals = new ArrayList<>();
76-
List<NamedExpression> newProjections = new ArrayList<>();
77-
// track the aggregate aggs (including grouping which is not an AggregateFunction)
78-
List<NamedExpression> newAggs = new ArrayList<>();
79-
80-
Holder<Boolean> changed = new Holder<>(false);
81-
int[] counter = new int[] { 0 };
82-
83-
for (NamedExpression agg : aggs) {
84-
if (agg instanceof Alias as) {
85-
// use intermediate variable to mark child as final for lambda use
86-
Expression child = as.child();
87-
88-
// common case - handle duplicates
89-
if (child instanceof AggregateFunction af) {
90-
// canonical representation, with resolved aliases
91-
AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e));
92-
93-
Alias found = rootAggs.get(canonical);
94-
// aggregate is new
95-
if (found == null) {
96-
rootAggs.put(canonical, as);
97-
newAggs.add(as);
98-
newProjections.add(as.toAttribute());
99-
}
100-
// agg already exists - preserve the current alias but point it to the existing agg
101-
// thus don't add it to the list of aggs as we don't want duplicated compute
102-
else {
103-
changed.set(true);
104-
newProjections.add(as.replaceChild(found.toAttribute()));
105-
}
106-
}
107-
// nested expression over aggregate function or groups
108-
// replace them with reference and move the expression into a follow-up eval
109-
else {
110-
changed.set(true);
111-
Expression aggExpression = child.transformUp(AggregateFunction.class, af -> {
112-
// canonical representation, with resolved aliases
113-
AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e));
114-
Alias alias = rootAggs.get(canonical);
115-
if (alias == null) {
116-
// create synthetic alias over the found agg function
117-
alias = new Alias(af.source(), syntheticName(canonical, child, counter[0]++), af.canonical(), null, true);
118-
// and remember it to remove duplicates
119-
rootAggs.put(canonical, alias);
120-
// add it to the list of aggregates and continue
121-
newAggs.add(alias);
122-
}
123-
// (even when found) return a reference to it
124-
return alias.toAttribute();
125-
});
126-
127-
// replace non-evaluatable grouping functions with their references
128-
aggExpression = aggExpression.transformUp(
129-
GroupingFunction.NonEvaluatableGroupingFunction.class,
130-
nonEvalGroupingAttributes::get
131-
);
132-
133-
Alias alias = as.replaceChild(aggExpression);
134-
newEvals.add(alias);
135-
newProjections.add(alias.toAttribute());
136-
}
137-
}
138-
// not an alias (e.g. grouping field)
139-
else {
140-
newAggs.add(agg);
141-
newProjections.add(agg.toAttribute());
142-
}
143-
}
53+
return builder.build();
54+
}
14455

145-
LogicalPlan plan = aggregate;
146-
if (changed.get()) {
147-
Source source = aggregate.source();
148-
plan = aggregate.with(aggregate.child(), aggregate.groupings(), newAggs);
149-
if (newEvals.size() > 0) {
150-
plan = new Eval(source, plan, newEvals);
56+
/**
 * Handles one aliased entry of the aggregate list.
 * <p>
 * A direct aggregate function is deduplicated: the first occurrence is kept and projected, a later occurrence
 * with the same key is only projected, re-pointed at the attribute of the first one. Any other expression is
 * moved to {@code evals}, with each aggregate function nested in it replaced by a reference to an aggregate in
 * {@code newAggs}; a synthetic alias is created for functions not seen before.
 */
@Override
57+
protected void processAlias(
58+
Alias as,
59+
AttributeMap<Expression> aliases,
60+
Map<AggregateFunction, Alias> rootAggs,
61+
List<NamedExpression> newAggs,
62+
List<NamedExpression> newProjections,
63+
List<Alias> evals,
64+
Holder<Boolean> changed,
65+
int[] counter
66+
) {
67+
Expression child = as.child();
68+
if (child instanceof AggregateFunction af) {
69+
// lookup key: canonical form of the function with aliased attributes replaced by their expressions
AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e));
70+
Alias found = rootAggs.get(canonical);
71+
if (found == null) {
72+
// first occurrence: keep the aggregate and expose it
rootAggs.put(canonical, as);
73+
newAggs.add(as);
74+
newProjections.add(as.toAttribute());
75+
} else {
76+
// duplicate: keep the alias in the projection only, pointing at the existing aggregate
changed.set(true);
77+
newProjections.add(as.replaceChild(found.toAttribute()));
15178
}
152-
// preserve initial projection
153-
plan = new Project(source, plan, newProjections);
79+
} else {
80+
// expression over aggregate functions and/or groups: always forces a rewrite
changed.set(true);
81+
Expression aggExpression = child.transformUp(AggregateFunction.class, af -> {
82+
AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e));
83+
Alias alias = rootAggs.get(canonical);
84+
if (alias == null) {
85+
// not seen before: add a synthetic aggregate and remember it for later lookups
alias = new Alias(af.source(), syntheticName(canonical, child, counter[0]++), af.canonical(), null, true);
86+
rootAggs.put(canonical, alias);
87+
newAggs.add(alias);
88+
}
89+
// whether found or newly created, the nested function becomes a reference to the aggregate
return alias.toAttribute();
90+
});
91+
// NOTE(review): the removed implementation applied
// aggExpression.transformUp(GroupingFunction.NonEvaluatableGroupingFunction.class, nonEvalGroupingAttributes::get)
// at this point. No equivalent step is visible in the added code, so non-evaluatable grouping functions
// appear to be left unreplaced in the expression added to evals - confirm whether this is intentional.
Alias alias = as.replaceChild(aggExpression);
92+
evals.add(alias);
93+
newProjections.add(alias.toAttribute());
15494
}
155-
156-
return plan;
15795
}
15896

15997
private static String syntheticName(Expression expression, Expression af, int counter) {

0 commit comments

Comments
 (0)