Skip to content

Commit 66a441f

Browse files
authored
Add support for global constants to constant folder (#1180)
Global constants like enums are represented in the AST as identifiers. These identifiers can be looked up in the AST reference map. If there is an entry in the reference map, we can try to fold it as a constant and to turn it into a literal. This reduces the size of the AST and unlocks more optimization opportunities.
1 parent dc36eaa commit 66a441f

File tree

2 files changed

+102
-11
lines changed

2 files changed

+102
-11
lines changed

cel/folding.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,23 @@ func MaxConstantFoldIterations(limit int) ConstantFoldingOption {
3838
}
3939
}
4040

41+
// Adds an Activation which provides known values for the folding evaluator
42+
//
43+
// Any values the activation provides will be used by the constant folder and turned into
44+
// literals in the AST.
45+
//
46+
// Defaults to the NoVars() Activation
47+
func FoldKnownValues(knownValues Activation) ConstantFoldingOption {
48+
return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) {
49+
if knownValues != nil {
50+
opt.knownValues = knownValues
51+
} else {
52+
opt.knownValues = NoVars()
53+
}
54+
return opt, nil
55+
}
56+
}
57+
4158
// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate
4259
// literal values within function calls and select statements with their evaluated result.
4360
func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) {
@@ -56,6 +73,7 @@ func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, e
5673

5774
type constantFoldingOptimizer struct {
5875
maxFoldIterations int
76+
knownValues Activation
5977
}
6078

6179
// Optimize queries the expression graph for scalar and aggregate literal expressions within call and
@@ -68,7 +86,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
6886
// Walk the list of foldable expression and continue to fold until there are no more folds left.
6987
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
7088
// a logic bug with the selection of expressions.
71-
constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return constantExprMatcher(ctx, a, e) }
89+
constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return opt.constantExprMatcher(ctx, a, e) }
7290
foldableExprs := ast.MatchDescendants(root, constantExprMatcherCapture)
7391
foldCount := 0
7492
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
@@ -83,8 +101,10 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
83101
continue
84102
}
85103
// Otherwise, assume all context is needed to evaluate the expression.
86-
err := tryFold(ctx, a, fold)
87-
if err != nil {
104+
err := opt.tryFold(ctx, a, fold)
105+
// Ignore errors for identifiers, since there is no guarantee that the environment
106+
// has a value for them.
107+
if err != nil && fold.Kind() != ast.IdentKind {
88108
ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error())
89109
return a
90110
}
@@ -96,7 +116,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
96116
// one last time. In this case, there's no guarantee they'll run, so we only update the
97117
// target comprehension node with the literal value if the evaluation succeeds.
98118
for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) {
99-
tryFold(ctx, a, compre)
119+
opt.tryFold(ctx, a, compre)
100120
}
101121

102122
// If the output is a list, map, or struct which contains optional entries, then prune it
@@ -126,7 +146,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
126146
//
127147
// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise
128148
// the method will return an error.
129-
func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
149+
func (opt *constantFoldingOptimizer) tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
130150
// Assume all context is needed to evaluate the expression.
131151
subAST := &Ast{
132152
impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()),
@@ -135,7 +155,11 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
135155
if err != nil {
136156
return err
137157
}
138-
out, _, err := prg.Eval(NoVars())
158+
activation := opt.knownValues
159+
if activation == nil {
160+
activation = NoVars()
161+
}
162+
out, _, err := prg.Eval(activation)
139163
if err != nil {
140164
return err
141165
}
@@ -469,13 +493,15 @@ func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
469493
// Only comprehensions which are not nested are included as possible constant folds, and only
470494
// if all variables referenced in the comprehension stack exist are only iteration or
471495
// accumulation variables.
472-
func constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool {
496+
func (opt *constantFoldingOptimizer) constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool {
473497
switch e.Kind() {
474498
case ast.CallKind:
475499
return constantCallMatcher(e)
476500
case ast.SelectKind:
477501
sel := e.AsSelect() // guaranteed to be a navigable value
478502
return constantMatcher(sel.Operand().(ast.NavigableExpr))
503+
case ast.IdentKind:
504+
return opt.knownValues != nil && a.ReferenceMap()[e.ID()] != nil
479505
case ast.ComprehensionKind:
480506
if isNestedComprehension(e) {
481507
return false

cel/folding_test.go

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"google.golang.org/protobuf/proto"
2525

2626
"github.com/google/cel-go/common/ast"
27+
"github.com/google/cel-go/common/types"
2728
"github.com/google/cel-go/common/types/ref"
2829

2930
proto3pb "github.com/google/cel-go/test/proto3pb"
@@ -32,8 +33,9 @@ import (
3233

3334
func TestConstantFoldingOptimizer(t *testing.T) {
3435
tests := []struct {
35-
expr string
36-
folded string
36+
expr string
37+
folded string
38+
knownValues map[string]any
3739
}{
3840
{
3941
expr: `[1, 1 + 2, 1 + (2 + 3)]`,
@@ -279,23 +281,86 @@ func TestConstantFoldingOptimizer(t *testing.T) {
279281
expr: `1 + 2 + x == x + 2 + 1`,
280282
folded: `3 + x == x + 2 + 1`,
281283
},
284+
{
285+
expr: `google.expr.proto3.test.ImportedGlobalEnum.IMPORT_BAR`,
286+
folded: `1`,
287+
knownValues: map[string]any{},
288+
},
289+
{
290+
expr: `google.expr.proto3.test.ImportedGlobalEnum.IMPORT_BAR`,
291+
folded: `google.expr.proto3.test.ImportedGlobalEnum.IMPORT_BAR`,
292+
},
293+
{
294+
expr: `c == google.expr.proto3.test.ImportedGlobalEnum.IMPORT_BAZ ? "BAZ" : "Unknown"`,
295+
folded: `"BAZ"`,
296+
knownValues: map[string]any{},
297+
},
298+
{
299+
expr: `[
300+
google.expr.proto3.test.ImportedGlobalEnum.IMPORT_BAR,
301+
c,
302+
google.expr.proto3.test.ImportedGlobalEnum.IMPORT_FOO
303+
].exists(e, e == google.expr.proto3.test.ImportedGlobalEnum.IMPORT_FOO)
304+
? "has Foo" : "no Foo"`,
305+
folded: `"has Foo"`,
306+
knownValues: map[string]any{},
307+
},
308+
{
309+
expr: `l.exists(e, e == "foo") ? "has Foo" : "no Foo"`,
310+
folded: `"has Foo"`,
311+
knownValues: map[string]any{
312+
"l": []string{"foo", "bar", "baz"},
313+
},
314+
},
315+
{
316+
expr: `"foo" in l`,
317+
folded: `true`,
318+
knownValues: map[string]any{
319+
"l": []string{"foo", "bar", "baz"},
320+
},
321+
},
322+
{
323+
expr: `o.repeated_int32`,
324+
folded: `[1, 2, 3]`,
325+
knownValues: map[string]any{
326+
"o": &proto3pb.TestAllTypes{RepeatedInt32: []int32{1, 2, 3}},
327+
},
328+
},
282329
}
283330
e, err := NewEnv(
284331
OptionalTypes(),
285332
EnableMacroCallTracking(),
286333
Types(&proto3pb.TestAllTypes{}),
287-
Variable("x", DynType))
334+
Variable("x", DynType),
335+
Constant("c", IntType, types.Int(proto3pb.ImportedGlobalEnum_IMPORT_BAZ)),
336+
)
288337
if err != nil {
289338
t.Fatalf("NewEnv() failed: %v", err)
290339
}
340+
e, err = e.Extend(Variable("l", ListType(StringType)))
341+
if err != nil {
342+
t.Fatalf("Extend() failed: %v", err)
343+
}
344+
e, err = e.Extend(Variable("o", ObjectType("google.expr.proto3.test.TestAllTypes")))
345+
if err != nil {
346+
t.Fatalf("Extend() failed: %v", err)
347+
}
291348
for _, tst := range tests {
292349
tc := tst
293350
t.Run(tc.expr, func(t *testing.T) {
294351
checked, iss := e.Compile(tc.expr)
295352
if iss.Err() != nil {
296353
t.Fatalf("Compile() failed: %v", iss.Err())
297354
}
298-
folder, err := NewConstantFoldingOptimizer()
355+
var foldingOpts []ConstantFoldingOption
356+
if tc.knownValues != nil {
357+
knownValues, err := NewActivation(tc.knownValues)
358+
if err != nil {
359+
t.Fatalf("NewActivation() failed: %v", err)
360+
}
361+
foldingOpts = append(foldingOpts, FoldKnownValues(knownValues))
362+
}
363+
folder, err := NewConstantFoldingOptimizer(foldingOpts...)
299364
if err != nil {
300365
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
301366
}

0 commit comments

Comments
 (0)