diff --git a/cel/inlining.go b/cel/inlining.go index a4530e19e..d9a5e89a5 100644 --- a/cel/inlining.go +++ b/cel/inlining.go @@ -178,9 +178,38 @@ func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, i )) return } + if zeroValExpr, ok := zeroValueExpr(ctx, inlinedType); ok { + ctx.UpdateExpr(prev, + ctx.NewCall(operators.NotEquals, + inlined, zeroValExpr)) + return + } ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType) } +// zeroValueExpr creates an expression representing the empty or zero value for the given type +// Note: bytes, lists, maps, and strings are supported via the `SizerType` trait. +func zeroValueExpr(ctx *OptimizerContext, t *Type) (ast.Expr, bool) { + // Note: bytes, strings, lists, and maps are covered by the "sizer-type" check + switch t.Kind() { + case types.BoolKind: + return ctx.NewLiteral(types.False), true + case types.DoubleKind: + return ctx.NewLiteral(types.Double(0)), true + case types.DurationKind: + return ctx.NewCall(overloads.TypeConvertDuration, ctx.NewLiteral(types.String("0s"))), true + case types.IntKind: + return ctx.NewLiteral(types.IntZero), true + case types.TimestampKind: + return ctx.NewCall(overloads.TypeConvertTimestamp, ctx.NewLiteral(types.Int(0))), true + case types.StructKind: + return ctx.NewStruct(t.TypeName(), []ast.EntryExpr{}), true + case types.UintKind: + return ctx.NewLiteral(types.Uint(0)), true + } + return nil, false +} + // isBindable indicates whether the inlined type can be used within a cel.bind() if the expression // being replaced occurs within a presence test. Value types with a size() method or field selection // support can be bound. @@ -212,17 +241,43 @@ func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type // field selection. This may be a future refinement. func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher { return func(e ast.NavigableExpr) bool { - if e.Kind() == ast.IdentKind && e.AsIdent() == varName { - return true + name, found := maybeAsVariableName(e) + if !found || name != varName { + return false + } + + // Determine whether the variable being referenced has been shadowed by a comprehension + p, hasParent := e.Parent() + for hasParent { + if p.Kind() != ast.ComprehensionKind { + p, hasParent = p.Parent() + continue + } + // If the inline variable name matches any of the comprehension variables at any scope, + // return false as the variable has been shadowed. + compre := p.AsComprehension() + if varName == compre.AccuVar() || varName == compre.IterVar() || varName == compre.IterVar2() { + return false + } + p, hasParent = p.Parent() } - if e.Kind() == ast.SelectKind { - sel := e.AsSelect() - // While the `ToQualifiedName` call could take the select directly, this - // would skip presence tests from possible matches, which we would like - // to include. - qualName, found := containers.ToQualifiedName(sel.Operand()) - return found && qualName+"."+sel.FieldName() == varName + + return true + } +} + +func maybeAsVariableName(e ast.NavigableExpr) (string, bool) { + if e.Kind() == ast.IdentKind { + return e.AsIdent(), true + } + if e.Kind() == ast.SelectKind { + sel := e.AsSelect() + // While the `ToQualifiedName` call could take the select directly, this + // would skip presence tests from possible matches, which we would like + // to include. + if qualName, found := containers.ToQualifiedName(sel.Operand()); found { + return qualName + "." + sel.FieldName(), true } - return false } + return "", false } diff --git a/cel/inlining_test.go b/cel/inlining_test.go index 4dec1074a..d9ad88b7a 100644 --- a/cel/inlining_test.go +++ b/cel/inlining_test.go @@ -22,6 +22,299 @@ import ( proto3pb "github.com/google/cel-go/test/proto3pb" ) +func TestInliningOptimizerNoopShadow(t *testing.T) { + type varExpr struct { + name string + alias string + t *cel.Type + expr string + } + tests := []struct { + name string + expr string + vars []varExpr + inlined string + }{ + { + name: "shadow at parent", + expr: `[0].exists(shadowed_ident, shadowed_ident == 0)`, + vars: []varExpr{ + { + name: "shadowed_ident", + t: cel.IntType, + expr: "1", + }, + }, + inlined: `[0].exists(shadowed_ident, shadowed_ident == 0)`, + }, + { + name: "shadow in ancestor", + expr: `[[1]].all(shadowed_ident, shadowed_ident.all(shadowed, shadowed + 1 == 2))`, + vars: []varExpr{ + { + name: "shadowed_ident", + t: cel.IntType, + expr: "42", + }, + }, + inlined: `[[1]].all(shadowed_ident, shadowed_ident.all(shadowed, shadowed + 1 == 2))`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + opts := []cel.EnvOption{ + cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + } + varDecls := make([]cel.EnvOption, len(tc.vars)) + for i, v := range tc.vars { + varDecls[i] = cel.Variable(v.name, v.t) + } + e, err := cel.NewEnv(append(varDecls, opts...)...) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + inlinedVars := []*cel.InlineVariable{} + for _, v := range tc.vars { + if v.expr == "" { + continue + } + checked, iss := e.Compile(v.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", v.expr, iss.Err()) + } + if v.alias == "" { + inlinedVars = append(inlinedVars, cel.NewInlineVariable(v.name, checked)) + } else { + inlinedVars = append(inlinedVars, cel.NewInlineVariableWithAlias(v.name, v.alias, checked)) + } + } + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + inlined, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if inlined != tc.inlined { + t.Errorf("inlined got %q, wanted %q", inlined, tc.inlined) + } + }) + } +} + +func TestInliningOptimizerPresenceTests(t *testing.T) { + type varExpr struct { + name string + alias string + t *cel.Type + expr string + } + tests := []struct { + name string + expr string + vars []varExpr + inlined string + }{ + { + name: "presence with bool literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.BoolType, + expr: "true", + }, + }, + inlined: `true != false`, + }, + { + name: "presence with bytes literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.BytesType, + expr: "b'abc'", + }, + }, + inlined: `b"\141\142\143".size() != 0`, + }, + { + name: "presence with double literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.DoubleType, + expr: "42.0", + }, + }, + inlined: `42.0 != 0.0`, + }, + { + name: "presence with duration literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.DurationType, + expr: "duration('1s')", + }, + }, + inlined: `duration("1s") != duration("0s")`, + }, + { + name: "presence with int literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.IntType, + expr: "1", + }, + }, + inlined: `1 != 0`, + }, + { + name: "presence with list literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.ListType(cel.StringType), + expr: "['foo', 'bar']", + }, + }, + inlined: `["foo", "bar"].size() != 0`, + }, + { + name: "presence with map literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.MapType(cel.StringType, cel.StringType), + expr: "{'foo': 'bar'}", + }, + }, + inlined: `{"foo": "bar"}.size() != 0`, + }, + { + name: "presence with string literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.StringType, + expr: "'foo'", + }, + }, + inlined: `"foo".size() != 0`, + }, + { + name: "presence with struct literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + expr: "TestAllTypes{single_int64: 1}", + }, + }, + inlined: `google.expr.proto3.test.TestAllTypes{single_int64: 1} != google.expr.proto3.test.TestAllTypes{}`, + }, + { + name: "presence with timestamp literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.TimestampType, + expr: "timestamp(123)", + }, + }, + inlined: `timestamp(123) != timestamp(0)`, + }, + { + name: "presence with uint literal", + expr: `has(msg.single_any.processing_purpose)`, + vars: []varExpr{ + { + name: "msg.single_any.processing_purpose", + t: cel.UintType, + expr: "1u", + }, + }, + inlined: `1u != 0u`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + opts := []cel.EnvOption{ + cel.Container("google.expr.proto3.test"), + cel.Types(&proto3pb.TestAllTypes{}), + cel.Variable("msg", cel.ObjectType("google.expr.proto3.test.TestAllTypes")), + cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + } + varDecls := make([]cel.EnvOption, len(tc.vars)) + for i, v := range tc.vars { + varDecls[i] = cel.Variable(v.name, v.t) + } + e, err := cel.NewEnv(append(varDecls, opts...)...) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + inlinedVars := []*cel.InlineVariable{} + for _, v := range tc.vars { + if v.expr == "" { + continue + } + checked, iss := e.Compile(v.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", v.expr, iss.Err()) + } + if v.alias == "" { + inlinedVars = append(inlinedVars, cel.NewInlineVariable(v.name, checked)) + } else { + inlinedVars = append(inlinedVars, cel.NewInlineVariableWithAlias(v.name, v.alias, checked)) + } + } + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + if err != nil { + t.Fatalf("NewStaticOptimizer() failed: %v", err) + } + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + inlined, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if inlined != tc.inlined { + t.Errorf("inlined got %q, wanted %q", inlined, tc.inlined) + } + }) + } +} + func TestInliningOptimizer(t *testing.T) { type varExpr struct { name string