|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.esql.optimizer.rules.logical; |
9 | 9 |
|
10 | | -import org.elasticsearch.common.util.Maps; |
11 | 10 | import org.elasticsearch.xpack.esql.core.expression.Alias; |
12 | 11 | import org.elasticsearch.xpack.esql.core.expression.Attribute; |
13 | 12 | import org.elasticsearch.xpack.esql.core.expression.AttributeMap; |
14 | 13 | import org.elasticsearch.xpack.esql.core.expression.Expression; |
15 | 14 | import org.elasticsearch.xpack.esql.core.expression.NamedExpression; |
16 | | -import org.elasticsearch.xpack.esql.core.tree.Source; |
17 | 15 | import org.elasticsearch.xpack.esql.core.util.Holder; |
18 | 16 | import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; |
19 | 17 | import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; |
20 | 18 | import org.elasticsearch.xpack.esql.plan.logical.Aggregate; |
21 | | -import org.elasticsearch.xpack.esql.plan.logical.Eval; |
22 | | -import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; |
23 | | -import org.elasticsearch.xpack.esql.plan.logical.Project; |
24 | 19 |
|
25 | | -import java.util.ArrayList; |
26 | 20 | import java.util.HashMap; |
27 | 21 | import java.util.List; |
28 | 22 | import java.util.Map; |
|
43 | 37 | * becomes |
44 | 38 | * stats a = min(x), c = count(*) by g | eval b = a, d = c | keep a, b, c, d, g |
45 | 39 | */ |
46 | | -public final class ReplaceAggregateAggExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> { |
47 | | - public ReplaceAggregateAggExpressionWithEval() { |
48 | | - super(OptimizerRules.TransformDirection.UP); |
49 | | - } |
| 40 | +public final class ReplaceAggregateAggExpressionWithEval extends AbstractAggregateDeduplicator { |
50 | 41 |
|
51 | 42 | @Override |
52 | | - protected LogicalPlan rule(Aggregate aggregate) { |
53 | | - // an alias map for evaluatable grouping functions |
54 | | - AttributeMap.Builder<Expression> aliasesBuilder = AttributeMap.builder(); |
55 | | - // a function map for non-evaluatable grouping functions |
56 | | - Map<GroupingFunction.NonEvaluatableGroupingFunction, Attribute> nonEvalGroupingAttributes = new HashMap<>( |
57 | | - aggregate.groupings().size() |
58 | | - ); |
| 43 | + protected AttributeMap<Expression> buildAliases(Aggregate aggregate) { |
| 44 | + AttributeMap.Builder<Expression> builder = AttributeMap.builder(); |
| 45 | + Map<GroupingFunction.NonEvaluatableGroupingFunction, Attribute> nonEvalGroupingAttributes = new HashMap<>(); |
59 | 46 | aggregate.forEachExpressionUp(Alias.class, a -> { |
60 | | - if (a.child() instanceof GroupingFunction.NonEvaluatableGroupingFunction groupingFunction) { |
61 | | - nonEvalGroupingAttributes.put(groupingFunction, a.toAttribute()); |
| 47 | + if (a.child() instanceof GroupingFunction.NonEvaluatableGroupingFunction gf) { |
| 48 | + nonEvalGroupingAttributes.put(gf, a.toAttribute()); |
62 | 49 | } else { |
63 | | - aliasesBuilder.put(a.toAttribute(), a.child()); |
| 50 | + builder.put(a.toAttribute(), a.child()); |
64 | 51 | } |
65 | 52 | }); |
66 | | - var aliases = aliasesBuilder.build(); |
67 | | - |
68 | | - // break down each aggregate into AggregateFunction and/or grouping key |
69 | | - // preserve the projection at the end |
70 | | - List<? extends NamedExpression> aggs = aggregate.aggregates(); |
71 | | - |
72 | | - // root/naked aggs |
73 | | - Map<AggregateFunction, Alias> rootAggs = Maps.newLinkedHashMapWithExpectedSize(aggs.size()); |
74 | | - // evals (original expression relying on multiple aggs) |
75 | | - List<Alias> newEvals = new ArrayList<>(); |
76 | | - List<NamedExpression> newProjections = new ArrayList<>(); |
77 | | - // track the aggregate aggs (including grouping which is not an AggregateFunction) |
78 | | - List<NamedExpression> newAggs = new ArrayList<>(); |
79 | | - |
80 | | - Holder<Boolean> changed = new Holder<>(false); |
81 | | - int[] counter = new int[] { 0 }; |
82 | | - |
83 | | - for (NamedExpression agg : aggs) { |
84 | | - if (agg instanceof Alias as) { |
85 | | - // use intermediate variable to mark child as final for lambda use |
86 | | - Expression child = as.child(); |
87 | | - |
88 | | - // common case - handle duplicates |
89 | | - if (child instanceof AggregateFunction af) { |
90 | | - // canonical representation, with resolved aliases |
91 | | - AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e)); |
92 | | - |
93 | | - Alias found = rootAggs.get(canonical); |
94 | | - // aggregate is new |
95 | | - if (found == null) { |
96 | | - rootAggs.put(canonical, as); |
97 | | - newAggs.add(as); |
98 | | - newProjections.add(as.toAttribute()); |
99 | | - } |
100 | | - // agg already exists - preserve the current alias but point it to the existing agg |
101 | | - // thus don't add it to the list of aggs as we don't want duplicated compute |
102 | | - else { |
103 | | - changed.set(true); |
104 | | - newProjections.add(as.replaceChild(found.toAttribute())); |
105 | | - } |
106 | | - } |
107 | | - // nested expression over aggregate function or groups |
108 | | - // replace them with reference and move the expression into a follow-up eval |
109 | | - else { |
110 | | - changed.set(true); |
111 | | - Expression aggExpression = child.transformUp(AggregateFunction.class, af -> { |
112 | | - // canonical representation, with resolved aliases |
113 | | - AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e)); |
114 | | - Alias alias = rootAggs.get(canonical); |
115 | | - if (alias == null) { |
116 | | - // create synthetic alias over the found agg function |
117 | | - alias = new Alias(af.source(), syntheticName(canonical, child, counter[0]++), af.canonical(), null, true); |
118 | | - // and remember it to remove duplicates |
119 | | - rootAggs.put(canonical, alias); |
120 | | - // add it to the list of aggregates and continue |
121 | | - newAggs.add(alias); |
122 | | - } |
123 | | - // (even when found) return a reference to it |
124 | | - return alias.toAttribute(); |
125 | | - }); |
126 | | - |
127 | | - // replace non-evaluatable grouping functions with their references |
128 | | - aggExpression = aggExpression.transformUp( |
129 | | - GroupingFunction.NonEvaluatableGroupingFunction.class, |
130 | | - nonEvalGroupingAttributes::get |
131 | | - ); |
132 | | - |
133 | | - Alias alias = as.replaceChild(aggExpression); |
134 | | - newEvals.add(alias); |
135 | | - newProjections.add(alias.toAttribute()); |
136 | | - } |
137 | | - } |
138 | | - // not an alias (e.g. grouping field) |
139 | | - else { |
140 | | - newAggs.add(agg); |
141 | | - newProjections.add(agg.toAttribute()); |
142 | | - } |
143 | | - } |
| 53 | + return builder.build(); |
| 54 | + } |
144 | 55 |
|
145 | | - LogicalPlan plan = aggregate; |
146 | | - if (changed.get()) { |
147 | | - Source source = aggregate.source(); |
148 | | - plan = aggregate.with(aggregate.child(), aggregate.groupings(), newAggs); |
149 | | - if (newEvals.size() > 0) { |
150 | | - plan = new Eval(source, plan, newEvals); |
| 56 | + @Override |
| 57 | + protected void processAlias( |
| 58 | + Alias as, |
| 59 | + AttributeMap<Expression> aliases, |
| 60 | + Map<AggregateFunction, Alias> rootAggs, |
| 61 | + List<NamedExpression> newAggs, |
| 62 | + List<NamedExpression> newProjections, |
| 63 | + List<Alias> evals, |
| 64 | + Holder<Boolean> changed, |
| 65 | + int[] counter |
| 66 | + ) { |
| 67 | + Expression child = as.child(); |
| 68 | + if (child instanceof AggregateFunction af) { |
| 69 | + AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e)); |
| 70 | + Alias found = rootAggs.get(canonical); |
| 71 | + if (found == null) { |
| 72 | + rootAggs.put(canonical, as); |
| 73 | + newAggs.add(as); |
| 74 | + newProjections.add(as.toAttribute()); |
| 75 | + } else { |
| 76 | + changed.set(true); |
| 77 | + newProjections.add(as.replaceChild(found.toAttribute())); |
151 | 78 | } |
152 | | - // preserve initial projection |
153 | | - plan = new Project(source, plan, newProjections); |
| 79 | + } else { |
| 80 | + changed.set(true); |
| 81 | + Expression aggExpression = child.transformUp(AggregateFunction.class, af -> { |
| 82 | + AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e)); |
| 83 | + Alias alias = rootAggs.get(canonical); |
| 84 | + if (alias == null) { |
| 85 | + alias = new Alias(af.source(), syntheticName(canonical, child, counter[0]++), af.canonical(), null, true); |
| 86 | + rootAggs.put(canonical, alias); |
| 87 | + newAggs.add(alias); |
| 88 | + } |
| 89 | + return alias.toAttribute(); |
| 90 | + }); |
| 91 | + Alias alias = as.replaceChild(aggExpression); |
| 92 | + evals.add(alias); |
| 93 | + newProjections.add(alias.toAttribute()); |
154 | 94 | } |
155 | | - |
156 | | - return plan; |
157 | 95 | } |
158 | 96 |
|
159 | 97 | private static String syntheticName(Expression expression, Expression af, int counter) { |
|
0 commit comments