diff --git a/checker/checker.go b/checker/checker.go index 0057c16cc..3bb61c19c 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -19,6 +19,8 @@ package checker import ( "fmt" "reflect" + "slices" + "strings" "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" @@ -104,11 +106,15 @@ func (c *checker) check(e ast.Expr) { func (c *checker) checkIdent(e ast.Expr) { identName := e.AsIdent() // Check to see if the identifier is declared. - if ident := c.env.LookupIdent(identName); ident != nil { + if ident := c.env.resolveSimpleIdent(identName); ident != nil { + name := strings.TrimPrefix(ident.Name(), ".") + if ident.requiresDisambiguation { + name = "." + name + } c.setType(e, ident.Type()) - c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value())) + c.setReference(e, ast.NewIdentReference(name, ident.Value())) // Overwrite the identifier with its fully qualified name. - e.SetKindCase(c.NewIdent(e.ID(), ident.Name())) + e.SetKindCase(c.NewIdent(e.ID(), name)) return } @@ -119,18 +125,22 @@ func (c *checker) checkIdent(e ast.Expr) { func (c *checker) checkSelect(e ast.Expr) { sel := e.AsSelect() // Before traversing down the tree, try to interpret as qualified name. - qname, found := containers.ToQualifiedName(e) + qualifiers, found := c.computeQualifiers(e) if found { - ident := c.env.LookupIdent(qname) + ident := c.env.resolveQualifiedIdent(qualifiers...) if ident != nil { // We don't check for a TestOnly expression here since the `found` result is // always going to be false for TestOnly expressions. // Rewrite the node to be a variable reference to the resolved fully-qualified // variable name. + name := ident.Name() + if ident.requiresDisambiguation { + name = "." + name + } c.setType(e, ident.Type()) - c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value())) - e.SetKindCase(c.NewIdent(e.ID(), ident.Name())) + c.setReference(e, ast.NewIdentReference(name, ident.Value())) + e.SetKindCase(c.NewIdent(e.ID(), name)) return } } @@ -142,6 +152,29 @@ func (c *checker) checkSelect(e ast.Expr) { c.setType(e, substitute(c.mappings, resultType, false)) } +// computeQualifiers computes the qualified names parts of a select expression. +func (c *checker) computeQualifiers(e ast.Expr) ([]string, bool) { + var qualifiers []string + for e.Kind() == ast.SelectKind { + sel := e.AsSelect() + // test only expressions are not considered for qualified name selection. + if sel.IsTestOnly() { + return qualifiers, false + } + // otherwise append the select field name to the qualifier list (reverse order) + qualifiers = append(qualifiers, sel.FieldName()) + e = sel.Operand() + // If the next operand is an identifier, then append it, reverse the name sequence + // and return it to the caller.s + if e.Kind() == ast.IdentKind { + qualifiers = append(qualifiers, e.AsIdent()) + slices.Reverse(qualifiers) + return qualifiers, true + } + } + return qualifiers, false +} + func (c *checker) checkOptSelect(e ast.Expr) { // Collect metadata related to the opt select call packaged by the parser. call := e.AsCall() @@ -234,7 +267,7 @@ func (c *checker) checkCall(e ast.Expr) { // Regular static call with simple name. if !call.IsMemberFunction() { // Check for the existence of the function. - fn := c.env.LookupFunction(fnName) + fn := c.env.lookupFunction(fnName) if fn == nil { c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName) c.setType(e, types.ErrorType) @@ -256,7 +289,7 @@ func (c *checker) checkCall(e ast.Expr) { qualifiedPrefix, maybeQualified := containers.ToQualifiedName(target) if maybeQualified { maybeQualifiedName := qualifiedPrefix + "." + fnName - fn := c.env.LookupFunction(maybeQualifiedName) + fn := c.env.lookupFunction(maybeQualifiedName) if fn != nil { // The function name is namespaced and so preserving the target operand would // be an inaccurate representation of the desired evaluation behavior. @@ -269,7 +302,7 @@ func (c *checker) checkCall(e ast.Expr) { // Regular instance call. c.check(target) - fn := c.env.LookupFunction(fnName) + fn := c.env.lookupFunction(fnName) // Function found, attempt overload resolution. if fn != nil { c.resolveOverloadOrError(e, fn, target, args) @@ -441,7 +474,7 @@ func (c *checker) checkCreateStruct(e ast.Expr) { msgVal := e.AsStruct() // Determine the type of the message. resultType := types.ErrorType - ident := c.env.LookupIdent(msgVal.TypeName()) + ident := c.env.resolveTypeIdent(msgVal.TypeName()) if ident == nil { c.errors.undeclaredReference( e.ID(), c.location(e), c.env.container.Name(), msgVal.TypeName()) diff --git a/checker/checker_test.go b/checker/checker_test.go index 42ae9cd50..025de9615 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -2229,9 +2229,9 @@ _&&_(_==_(list~type(list(dyn))^list, decls.NewVariable("NotAMessage", types.NewNullableType(types.IntType)), }, }, - err: `ERROR: :1:12: 'wrapper(int)' is not a type - | NotAMessage{} - | ...........^`, + err: `ERROR: :1:12: undeclared reference to 'NotAMessage' (in container '') + | NotAMessage{} + | ...........^`, }, { in: `{}.map(c,[c,type(c)])`, @@ -2262,6 +2262,156 @@ _&&_(_==_(list~type(list(dyn))^list, @result~list(list(dyn))^@result)~list(list(dyn))`, outType: types.NewListType(types.NewListType(types.DynType)), }, + { + in: `[{'z': 0}].exists(y, y.z == 0)`, + env: testEnv{ + idents: []*decls.VariableDecl{ + decls.NewVariable("cel.example.y", types.NewMapType(types.StringType, types.IntType)), + }, + }, + out: `__comprehension__( + // Variable + y, + // Target + [ + { + "z"~string:0~int + }~map(string, int) + ]~list(map(string, int)), + // Accumulator + @result, + // Init + false~bool, + // LoopCondition + @not_strictly_false( + !_( + @result~bool^@result + )~bool^logical_not + )~bool^not_strictly_false, + // LoopStep + _||_( + @result~bool^@result, + _==_( + y~map(string, int)^y.z~int, + 0~int + )~bool^equals + )~bool^logical_or, + // Result + @result~bool^@result)~bool`, + outType: types.BoolType, + }, + { + in: `[{'y': 0}].exists(x, x.y == 0)`, + env: testEnv{ + idents: []*decls.VariableDecl{ + decls.NewVariable("x", types.NewMapType(types.StringType, types.IntType)), + }, + }, + out: `__comprehension__( + // Variable + x, + // Target + [ + { + "y"~string:0~int + }~map(string, int) + ]~list(map(string, int)), + // Accumulator + @result, + // Init + false~bool, + // LoopCondition + @not_strictly_false( + !_( + @result~bool^@result + )~bool^logical_not + )~bool^not_strictly_false, + // LoopStep + _||_( + @result~bool^@result, + _==_( + x~map(string, int)^x.y~int, + 0~int + )~bool^equals + )~bool^logical_or, + // Result + @result~bool^@result)~bool`, + outType: types.BoolType, + }, + { + in: `[0].exists(x, x != .x)`, + env: testEnv{ + idents: []*decls.VariableDecl{ + decls.NewVariable("x", types.IntType), + }, + }, + out: `__comprehension__( + // Variable + x, + // Target + [ + 0~int + ]~list(int), + // Accumulator + @result, + // Init + false~bool, + // LoopCondition + @not_strictly_false( + !_( + @result~bool^@result + )~bool^logical_not + )~bool^not_strictly_false, + // LoopStep + _||_( + @result~bool^@result, + _!=_( + x~int^x, + .x~int^.x + )~bool^not_equals + )~bool^logical_or, + // Result + @result~bool^@result)~bool`, + outType: types.BoolType, + }, + { + in: `[{'z': 0}].exists(y, .y.z == y.z)`, + env: testEnv{ + idents: []*decls.VariableDecl{ + decls.NewVariable("y.z", types.IntType), + }, + }, + out: `__comprehension__( + // Variable + y, + // Target + [ + { + "z"~string:0~int + }~map(string, int) + ]~list(map(string, int)), + // Accumulator + @result, + // Init + false~bool, + // LoopCondition + @not_strictly_false( + !_( + @result~bool^@result + )~bool^logical_not + )~bool^not_strictly_false, + // LoopStep + _||_( + @result~bool^@result, + _==_( + .y.z~int^.y.z, + y~map(string, int)^y.z~int + )~bool^equals + )~bool^logical_or, + // Result + @result~bool^@result)~bool`, + outType: types.BoolType, + }, } } diff --git a/checker/env.go b/checker/env.go index 8e9aec809..6d991eba1 100644 --- a/checker/env.go +++ b/checker/env.go @@ -129,36 +129,111 @@ func (e *Env) AddFunctions(declarations ...*decls.FunctionDecl) error { return formatError(errMsgs) } -// LookupIdent returns a Decl proto for typeName as an identifier in the Env. -// Returns nil if no such identifier is found in the Env. -func (e *Env) LookupIdent(name string) *decls.VariableDecl { +// newAttrResolution creates a new attribute resolution value. +func newAttrResolution(ident *decls.VariableDecl, requiresDisambiguation bool) *attributeResolution { + return &attributeResolution{ + VariableDecl: ident, + requiresDisambiguation: requiresDisambiguation, + } +} + +// attributeResolution wraps an existing variable and denotes whether disambiguation is needed +// during variable resolution. +type attributeResolution struct { + *decls.VariableDecl + + // requiresDisambiguation indicates the variable name should be dot-prefixed. + requiresDisambiguation bool +} + +// resolveSimpleIdent determines the resolved attribute for a single identifier. +func (e *Env) resolveSimpleIdent(name string) *attributeResolution { + local := e.lookupLocalIdent(name) + if local != nil && !strings.HasPrefix(name, ".") { + return newAttrResolution(local, false) + } for _, candidate := range e.container.ResolveCandidateNames(name) { - if ident := e.declarations.FindIdent(candidate); ident != nil { - return ident + if ident := e.lookupGlobalIdent(candidate); ident != nil { + return newAttrResolution(ident, local != nil) } + } + return nil +} - // Next try to import the name as a reference to a message type. - if t, found := e.provider.FindStructType(candidate); found { - return decls.NewVariable(candidate, t) +// resolveQualifiedIdent determines the resolved attribute for a qualified identifier. +func (e *Env) resolveQualifiedIdent(qualifiers ...string) *attributeResolution { + if len(qualifiers) == 1 { + return e.resolveSimpleIdent(qualifiers[0]) + } + local := e.lookupLocalIdent(qualifiers[0]) + if local != nil && !strings.HasPrefix(qualifiers[0], ".") { + // this should resolve through a field selection rather than a qualified identifier + return nil + } + // The qualifiers are concatenated together to indicate the qualified name to search + // for as a global identifier. Since select expressions are resolved from leaf to root + // if the fully concatenated string doesn't match a global identifier, indicate that + // no variable was found to continue the traversal up to the next simpler name. + varName := strings.Join(qualifiers, ".") + for _, candidate := range e.container.ResolveCandidateNames(varName) { + if ident := e.lookupGlobalIdent(candidate); ident != nil { + return newAttrResolution(ident, local != nil) } + } + return nil +} + +// resolveTypeIdent returns a Decl proto for typeName as an identifier in the Env. +// Returns nil if no such identifier is found in the Env. +func (e *Env) resolveTypeIdent(name string) *decls.VariableDecl { + for _, candidate := range e.container.ResolveCandidateNames(name) { + // Try to import the name as a reference to a message type. if i, found := e.provider.FindIdent(candidate); found { if t, ok := i.(*types.Type); ok { return decls.NewVariable(candidate, types.NewTypeTypeWithParam(t)) } } + // Next, try to find the struct type. + if t, found := e.provider.FindStructType(candidate); found { + return decls.NewVariable(candidate, t) + } + } + return nil +} + +// lookupLocalIdent finds the variable candidate in a local scope, returning nil if +// the candidate variable name is not a local variable. +func (e *Env) lookupLocalIdent(candidate string) *decls.VariableDecl { + return e.declarations.FindLocalIdent(candidate) +} - // Next try to import this as an enum value by splitting the name in a type prefix and - // the enum inside. - if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType { - return decls.NewConstant(candidate, types.IntType, enumValue) +// lookupGlobalIdent finds a candidate variable name in the root scope, returning +// nil if the identifier is not in the global scope. +func (e *Env) lookupGlobalIdent(candidate string) *decls.VariableDecl { + // Try to resolve the global identifier first. + if ident := e.declarations.FindGlobalIdent(candidate); ident != nil { + return ident + } + // Next try to import the name as a reference to a message type. + if i, found := e.provider.FindIdent(candidate); found { + if t, ok := i.(*types.Type); ok { + return decls.NewVariable(candidate, types.NewTypeTypeWithParam(t)) } } + if t, found := e.provider.FindStructType(candidate); found { + return decls.NewVariable(candidate, t) + } + // Next try to import this as an enum value by splitting the name in a type prefix and + // the enum inside. + if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType { + return decls.NewConstant(candidate, types.IntType, enumValue) + } return nil } -// LookupFunction returns a Decl proto for typeName as a function in env. +// lookupFunction returns a Decl proto for typeName as a function in env. // Returns nil if no such function is found in env. -func (e *Env) LookupFunction(name string) *decls.FunctionDecl { +func (e *Env) lookupFunction(name string) *decls.FunctionDecl { for _, candidate := range e.container.ResolveCandidateNames(name) { if fn := e.declarations.FindFunction(candidate); fn != nil { return fn diff --git a/checker/scopes.go b/checker/scopes.go index 8bb73ddb6..9ae9832e1 100644 --- a/checker/scopes.go +++ b/checker/scopes.go @@ -15,6 +15,8 @@ package checker import ( + "strings" + "github.com/google/cel-go/common/decls" ) @@ -76,6 +78,7 @@ func (s *Scopes) AddIdent(decl *decls.VariableDecl) { // found. // Note: The search is performed from innermost to outermost. func (s *Scopes) FindIdent(name string) *decls.VariableDecl { + name = strings.TrimPrefix(name, ".") if ident, found := s.scopes.idents[name]; found { return ident } @@ -89,12 +92,33 @@ func (s *Scopes) FindIdent(name string) *decls.VariableDecl { // nil if one does not exist. // Note: The search is only performed on the current scope and does not search outer scopes. func (s *Scopes) FindIdentInScope(name string) *decls.VariableDecl { + name = strings.TrimPrefix(name, ".") if ident, found := s.scopes.idents[name]; found { return ident } return nil } +// FindLocalIdent finds a locally scoped variable with a given name, ignoring the root scope. +func (s *Scopes) FindLocalIdent(name string) *decls.VariableDecl { + if s == nil || s.parent == nil { + return nil + } + if ident := s.FindIdentInScope(name); ident != nil { + return ident + } + return s.parent.FindLocalIdent(name) +} + +// FindGlobalIdent finds an identifier in the global scope, ignoring all local scopes. +func (s *Scopes) FindGlobalIdent(name string) *decls.VariableDecl { + scope := s + for scope.parent != nil { + scope = scope.parent + } + return scope.FindIdentInScope(name) +} + // SetFunction adds the function Decl to the current scope. // Note: Any previous entry for a function in the current scope with the same name is overwritten. func (s *Scopes) SetFunction(fn *decls.FunctionDecl) { @@ -105,6 +129,7 @@ func (s *Scopes) SetFunction(fn *decls.FunctionDecl) { // The search is performed from innermost to outermost. // Returns nil if no such function in Scopes. func (s *Scopes) FindFunction(name string) *decls.FunctionDecl { + name = strings.TrimPrefix(name, ".") if fn, found := s.scopes.functions[name]; found { return fn } diff --git a/conformance/BUILD.bazel b/conformance/BUILD.bazel index cdb8d6dfe..3ab686f2d 100644 --- a/conformance/BUILD.bazel +++ b/conformance/BUILD.bazel @@ -37,16 +37,9 @@ _ALL_TESTS = [ ] _TESTS_TO_SKIP = [ - "comparisons/eq_literal/eq_mixed_types_error,eq_list_elem_mixed_types_error,eq_map_value_mixed_types_error", - "comparisons/ne_literal/ne_mixed_types_error", - "comparisons/in_list_literal/elem_in_mixed_type_list_error", - "comparisons/in_map_literal/key_in_mixed_key_type_map_error", - "macros/exists/list_elem_type_exhaustive,map_key_type_exhaustive", - # Failing conformance tests. "fields/qualified_identifier_resolution/map_key_float,map_key_null,map_value_repeat_key", "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous", - "macros/map/map_extract_keys", "timestamps/duration_converters/get_milliseconds", "optionals/optionals/map_optional_select_has", diff --git a/ext/bindings_test.go b/ext/bindings_test.go index df89eec46..ca570772c 100644 --- a/ext/bindings_test.go +++ b/ext/bindings_test.go @@ -90,6 +90,72 @@ var bindingTests = []struct { estimatedCost: checker.CostEstimate{Min: 38, Max: 40}, actualCost: 39, }, + { + name: "shadowed binding", + expr: `cel.bind(x, 0, x == 0)`, + vars: []cel.EnvOption{cel.Variable("x", cel.StringType)}, + in: map[string]any{ + "cel.example.x": "1", + }, + estimatedCost: checker.FixedCostEstimate(12), + actualCost: 12, + }, + { + name: "container shadowed binding", + expr: `cel.bind(x, 0, x == 0)`, + vars: []cel.EnvOption{ + cel.Container("cel.example"), + cel.Variable("cel.example.x", cel.StringType), + }, + in: map[string]any{ + "cel.example.x": "1", + }, + estimatedCost: checker.FixedCostEstimate(12), + actualCost: 12, + }, + { + name: "shadowing namespace resolution selector", + expr: `cel.bind(x, {'y': 0}, x.y == 0)`, + vars: []cel.EnvOption{ + cel.Container("cel.example"), + cel.Variable("cel.example.x.y", cel.IntType), + }, + in: map[string]any{ + "cel.example.x.y": 1, + }, + estimatedCost: checker.FixedCostEstimate(43), + actualCost: 43, + }, + { + name: "shadowing namespace resolution selector with local", + expr: `cel.bind(x, {'y': 0}, .x.y == x.y)`, + vars: []cel.EnvOption{ + cel.Variable("x.y", cel.IntType), + }, + in: map[string]any{ + "x.y": 0, + }, + estimatedCost: checker.FixedCostEstimate(44), + actualCost: 44, + }, + { + name: "namespace disambiguation", + expr: `cel.bind(y, 0, .y != y)`, + vars: []cel.EnvOption{ + cel.Variable("y", cel.IntType), + }, + in: map[string]any{ + "y": 1, + }, + estimatedCost: checker.FixedCostEstimate(13), + actualCost: 13, + }, + { + name: "nesting shadowing", + expr: `cel.bind(y, 0, cel.bind(y, 1, y != 0))`, + estimatedCost: checker.FixedCostEstimate(22), + actualCost: 22, + }, } func TestBindings(t *testing.T) { diff --git a/interpreter/attribute_patterns.go b/interpreter/attribute_patterns.go index 7d0759e37..41ca5cd21 100644 --- a/interpreter/attribute_patterns.go +++ b/interpreter/attribute_patterns.go @@ -16,6 +16,7 @@ package interpreter import ( "fmt" + "strings" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/types" @@ -207,10 +208,19 @@ func (fac *partialAttributeFactory) AbsoluteAttribute(id int64, names ...string) // 'maybe' NamespacedAttribute values are produced using the partialAttributeFactory rather than // the base AttributeFactory implementation. func (fac *partialAttributeFactory) MaybeAttribute(id int64, name string) Attribute { + var names []string + // When there's a single name with a dot prefix, it indicates that the 'maybe' attribute is a + // globally namespaced identifier. + if strings.HasPrefix(name, ".") { + names = append(names, name) + } else { + // In all other cases, the candidate names should be inferred. + names = fac.container.ResolveCandidateNames(name) + } return &maybeAttribute{ id: id, attrs: []NamespacedAttribute{ - fac.AbsoluteAttribute(id, fac.container.ResolveCandidateNames(name)...), + fac.AbsoluteAttribute(id, names...), }, adapter: fac.adapter, provider: fac.provider, diff --git a/interpreter/attributes.go b/interpreter/attributes.go index b1b3aacc8..4cf852d21 100644 --- a/interpreter/attributes.go +++ b/interpreter/attributes.go @@ -166,9 +166,17 @@ type attrFactory struct { // The namespaceNames represent the names the variable could have based on namespace // resolution rules. func (r *attrFactory) AbsoluteAttribute(id int64, names ...string) NamespacedAttribute { + disambiguateNames := make(map[int]bool) + for idx, name := range names { + if strings.HasPrefix(name, ".") { + disambiguateNames[idx] = true + names[idx] = strings.TrimPrefix(name, ".") + } + } return &absoluteAttribute{ id: id, namespaceNames: names, + disambiguateNames: disambiguateNames, qualifiers: []Qualifier{}, adapter: r.adapter, provider: r.provider, @@ -193,10 +201,19 @@ func (r *attrFactory) ConditionalAttribute(id int64, expr Interpretable, t, f At // MaybeAttribute collects variants of unchecked AbsoluteAttribute values which could either be // direct variable accesses or some combination of variable access with qualification. func (r *attrFactory) MaybeAttribute(id int64, name string) Attribute { + var names []string + // When there's a single name with a dot prefix, it indicates that the 'maybe' attribute is a + // globally namespaced identifier. + if strings.HasPrefix(name, ".") { + names = append(names, name) + } else { + // In all other cases, the candidate names should be inferred. + names = r.container.ResolveCandidateNames(name) + } return &maybeAttribute{ id: id, attrs: []NamespacedAttribute{ - r.AbsoluteAttribute(id, r.container.ResolveCandidateNames(name)...), + r.AbsoluteAttribute(id, names...), }, adapter: r.adapter, provider: r.provider, @@ -242,10 +259,13 @@ type absoluteAttribute struct { // namespaceNames represent the names the variable could have based on declared container // (package) of the expression. namespaceNames []string - qualifiers []Qualifier - adapter types.Adapter - provider types.Provider - fac AttributeFactory + // disambiguateNames stores a list of indices to namespaceNames which require disambiguation + disambiguateNames map[int]bool + + qualifiers []Qualifier + adapter types.Adapter + provider types.Provider + fac AttributeFactory errorOnBadPresenceTest bool } @@ -304,15 +324,34 @@ func (a *absoluteAttribute) String() string { // a type, then the result is `nil`, `error` with the error indicating the name of the first // variable searched as missing. func (a *absoluteAttribute) Resolve(vars Activation) (any, error) { - for _, nm := range a.namespaceNames { + // unwrap any local activations to ensure that we reach the variables provided as input + // to the expression in the event that we need to disambiguate between global and local + // variables. + // + // Presently, only dynamic and constant slot activations created during comprehensions + // support 'unwrapping', which is consistent with how local variables are introduced into CEL. + var inputVars Activation + if len(a.disambiguateNames) > 0 { + inputVars = vars + wrapped, ok := inputVars.(activationWrapper) + for ok { + inputVars = wrapped.Unwrap() + wrapped, ok = inputVars.(activationWrapper) + } + } + for idx, nm := range a.namespaceNames { // If the variable is found, process it. Otherwise, wait until the checks to // determine whether the type is unknown before returning. - obj, found := vars.ResolveName(nm) + v := vars + if disambiguate, found := a.disambiguateNames[idx]; found && disambiguate { + v = inputVars + } + obj, found := v.ResolveName(nm) if found { if celErr, ok := obj.(*types.Err); ok { return nil, celErr.Unwrap() } - obj, isOpt, err := applyQualifiers(vars, obj, a.qualifiers) + obj, isOpt, err := applyQualifiers(v, obj, a.qualifiers) if err != nil { return nil, err } diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 96b5a8ffc..9c8575db5 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -1404,6 +1404,11 @@ func (f *folder) Parent() Activation { return f.activation } +// Unwrap returns the parent activation, thus omitting access to local state +func (f *folder) Unwrap() Activation { + return f.activation +} + // UnknownAttributePatterns implements the PartialActivation interface returning the unknown patterns // if they were provided to the input activation, or an empty set if the proxied activation is not partial. func (f *folder) UnknownAttributePatterns() []*AttributePattern { diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 1ed5e57f6..d81ef1280 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -137,8 +137,10 @@ func (esa evalStateActivation) asEvalState() EvalState { return esa.state } -// activationWrapper identifies an object that can be unwrapped to access the underlying activation. +// activationWrapper identifies an object carrying local variables which should not be exposed to the user +// Activations used for such purposes can be unwrapped to return the activation which omits local state. type activationWrapper interface { + // Unwrap returns the Activation which omits local state. Unwrap() Activation } diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index b664520eb..6ca88ad40 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -1511,6 +1511,102 @@ func testData(t testing.TB) []testCase { expr: `{"invalid": dyn(null)}.?invalid.?nested`, out: types.OptionalNone, }, + { + name: "local_shadow_identifier_in_select", + expr: `[{'z': 0}].exists(y, y.z == 0)`, + container: "cel.example", + vars: []*decls.VariableDecl{ + decls.NewVariable("cel.example.y", types.IntType), + }, + in: map[string]any{ + "cel.example.y": map[string]int{"z": 1}, + }, + out: types.True, + }, + { + name: "local_shadow_identifier_in_select_global_disambiguation", + expr: `[{'z': 0}].exists(y, y.z == 0 && .y.z == 1)`, + container: "y", + vars: []*decls.VariableDecl{ + decls.NewVariable("y.z", types.IntType), + }, + in: map[string]any{ + "y.z": 1, + }, + out: types.True, + }, + { + name: "local_shadow_identifier_with_global_disambiguation", + expr: `[0].exists(x, x == 0 && .x == 1)`, + vars: []*decls.VariableDecl{ + decls.NewVariable("x", types.IntType), + }, + in: map[string]any{ + "x": 1, + }, + out: types.True, + }, + { + name: "local_double_shadow_identifier_with_global_disambiguation", + expr: `[0].exists(x, [x+1].exists(x, x == .x))`, + vars: []*decls.VariableDecl{ + decls.NewVariable("x", types.IntType), + }, + in: map[string]any{ + "x": 1, + }, + out: types.True, + }, + { + name: "unchecked_local_shadow_identifier_in_select", + expr: `[{'z': 0}].exists(y, y.z == 0)`, + unchecked: true, + container: "cel.example", + vars: []*decls.VariableDecl{ + decls.NewVariable("cel.example.y", types.IntType), + }, + in: map[string]any{ + "cel.example.y": map[string]int{"z": 1}, + }, + out: types.True, + }, + { + name: "unchecked_local_shadow_identifier_in_select_global_disambiguation", + expr: `[{'z': 0}].exists(y, y.z == 0 && .y.z == 1)`, + container: "y", + unchecked: true, + vars: []*decls.VariableDecl{ + decls.NewVariable("y.z", types.IntType), + }, + in: map[string]any{ + "y.z": 1, + }, + out: types.True, + }, + { + name: "unchecked_local_shadow_identifier_with_global_disambiguation", + expr: `[0].exists(x, x == 0 && .x == 1)`, + unchecked: true, + vars: []*decls.VariableDecl{ + decls.NewVariable("x", types.IntType), + }, + in: map[string]any{ + "x": 1, + }, + out: types.True, + }, + { + name: "unchecked_local_double_shadow_identifier_with_global_disambiguation", + expr: `[0].exists(x, [x+1].exists(x, x == .x))`, + unchecked: true, + vars: []*decls.VariableDecl{ + decls.NewVariable("x", types.IntType), + }, + in: map[string]any{ + "x": 1, + }, + out: types.True, + }, } } diff --git a/interpreter/planner.go b/interpreter/planner.go index f0e0d4305..0bc38449c 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -61,13 +61,20 @@ type planner struct { observers []StatefulObserver } +type planBuilder struct { + *planner + + localVars map[string]int +} + // Plan implements the interpretablePlanner interface. This implementation of the Plan method also // applies decorators to each Interpretable generated as part of the overall plan. Decorators are // useful for layering functionality into the evaluation that is not natively understood by CEL, // such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of // repeated expressions. func (p *planner) Plan(expr ast.Expr) (Interpretable, error) { - i, err := p.plan(expr) + pb := &planBuilder{planner: p, localVars: make(map[string]int)} + i, err := pb.plan(expr) if err != nil { return nil, err } @@ -77,7 +84,7 @@ func (p *planner) Plan(expr ast.Expr) (Interpretable, error) { return &ObservableInterpretable{Interpretable: i, observers: p.observers}, nil } -func (p *planner) plan(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) plan(expr ast.Expr) (Interpretable, error) { switch expr.Kind() { case ast.CallKind: return p.decorate(p.planCall(expr)) @@ -102,7 +109,7 @@ func (p *planner) plan(expr ast.Expr) (Interpretable, error) { // decorate applies the InterpretableDecorator functions to the given Interpretable. // Both the Interpretable and error generated by a Plan step are accepted as arguments // for convenience. -func (p *planner) decorate(i Interpretable, err error) (Interpretable, error) { +func (p *planBuilder) decorate(i Interpretable, err error) (Interpretable, error) { if err != nil { return nil, err } @@ -116,20 +123,26 @@ func (p *planner) decorate(i Interpretable, err error) (Interpretable, error) { } // planIdent creates an Interpretable that resolves an identifier from an Activation. -func (p *planner) planIdent(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planIdent(expr ast.Expr) (Interpretable, error) { // Establish whether the identifier is in the reference map. if identRef, found := p.refMap[expr.ID()]; found { return p.planCheckedIdent(expr.ID(), identRef) } // Create the possible attribute list for the unresolved reference. ident := expr.AsIdent() + if p.isLocalVar(ident) { + return &evalAttr{ + adapter: p.adapter, + attr: p.attrFactory.AbsoluteAttribute(expr.ID(), ident), + }, nil + } return &evalAttr{ adapter: p.adapter, attr: p.attrFactory.MaybeAttribute(expr.ID(), ident), }, nil } -func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Interpretable, error) { +func (p *planBuilder) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Interpretable, error) { // Plan a constant reference if this is the case for this simple identifier. if identRef.Value != nil { return NewConstValue(id, identRef.Value), nil @@ -158,7 +171,7 @@ func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Inter // a) selects a field from a map or proto. // b) creates a field presence test for a select within a has() macro. // c) resolves the select expression to a namespaced identifier. -func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planSelect(expr ast.Expr) (Interpretable, error) { // If the Select id appears in the reference map from the CheckedExpr proto then it is either // a namespaced identifier or enum value. if identRef, found := p.refMap[expr.ID()]; found { @@ -214,7 +227,7 @@ func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) { // planCall creates a callable Interpretable while specializing for common functions and invocation // patterns. Specifically, conditional operators &&, ||, ?:, and (in)equality functions result in // optimized Interpretable values. -func (p *planner) planCall(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planCall(expr ast.Expr) (Interpretable, error) { call := expr.AsCall() target, fnName, oName := p.resolveFunction(expr) argCount := len(call.Args()) @@ -291,7 +304,7 @@ func (p *planner) planCall(expr ast.Expr) (Interpretable, error) { } // planCallZero generates a zero-arity callable Interpretable. -func (p *planner) planCallZero(expr ast.Expr, +func (p *planBuilder) planCallZero(expr ast.Expr, function string, overload string, impl *functions.Overload) (Interpretable, error) { @@ -307,7 +320,7 @@ func (p *planner) planCallZero(expr ast.Expr, } // planCallUnary generates a unary callable Interpretable. -func (p *planner) planCallUnary(expr ast.Expr, +func (p *planBuilder) planCallUnary(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -335,7 +348,7 @@ func (p *planner) planCallUnary(expr ast.Expr, } // planCallBinary generates a binary callable Interpretable. -func (p *planner) planCallBinary(expr ast.Expr, +func (p *planBuilder) planCallBinary(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -364,7 +377,7 @@ func (p *planner) planCallBinary(expr ast.Expr, } // planCallVarArgs generates a variable argument callable Interpretable. -func (p *planner) planCallVarArgs(expr ast.Expr, +func (p *planBuilder) planCallVarArgs(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -392,7 +405,7 @@ func (p *planner) planCallVarArgs(expr ast.Expr, } // planCallEqual generates an equals (==) Interpretable. -func (p *planner) planCallEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalEq{ id: expr.ID(), lhs: args[0], @@ -401,7 +414,7 @@ func (p *planner) planCallEqual(expr ast.Expr, args []Interpretable) (Interpreta } // planCallNotEqual generates a not equals (!=) Interpretable. -func (p *planner) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalNe{ id: expr.ID(), lhs: args[0], @@ -410,7 +423,7 @@ func (p *planner) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpr } // planCallLogicalAnd generates a logical and (&&) Interpretable. -func (p *planner) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalAnd{ id: expr.ID(), terms: args, @@ -418,7 +431,7 @@ func (p *planner) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Inter } // planCallLogicalOr generates a logical or (||) Interpretable. -func (p *planner) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalOr{ id: expr.ID(), terms: args, @@ -426,7 +439,7 @@ func (p *planner) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interp } // planCallConditional generates a conditional / ternary (c ? t : f) Interpretable. -func (p *planner) planCallConditional(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallConditional(expr ast.Expr, args []Interpretable) (Interpretable, error) { cond := args[0] t := args[1] var tAttr Attribute @@ -454,7 +467,7 @@ func (p *planner) planCallConditional(expr ast.Expr, args []Interpretable) (Inte // planCallIndex either extends an attribute with the argument to the index operation, or creates // a relative attribute based on the return of a function call or operation. -func (p *planner) planCallIndex(expr ast.Expr, args []Interpretable, optional bool) (Interpretable, error) { +func (p *planBuilder) planCallIndex(expr ast.Expr, args []Interpretable, optional bool) (Interpretable, error) { op := args[0] ind := args[1] opType := p.typeMap[op.ID()] @@ -489,7 +502,7 @@ func (p *planner) planCallIndex(expr ast.Expr, args []Interpretable, optional bo } // planCreateList generates a list construction Interpretable. -func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planCreateList(expr ast.Expr) (Interpretable, error) { list := expr.AsList() optionalIndices := list.OptionalIndices() elements := list.Elements() @@ -518,7 +531,7 @@ func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) { } // planCreateStruct generates a map or object construction Interpretable. -func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planCreateMap(expr ast.Expr) (Interpretable, error) { m := expr.AsMap() entries := m.Entries() optionals := make([]bool, len(entries)) @@ -552,7 +565,7 @@ func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) { } // planCreateObj generates an object construction Interpretable. -func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planCreateStruct(expr ast.Expr) (Interpretable, error) { obj := expr.AsStruct() typeName, defined := p.resolveTypeName(obj.TypeName()) if !defined { @@ -586,7 +599,7 @@ func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) { } // planComprehension generates an Interpretable fold operation. -func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planComprehension(expr ast.Expr) (Interpretable, error) { fold := expr.AsComprehension() accu, err := p.plan(fold.AccuInit()) if err != nil { @@ -596,6 +609,7 @@ func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { if err != nil { return nil, err } + p.pushLocalVars(fold.AccuVar(), fold.IterVar(), fold.IterVar2()) cond, err := p.plan(fold.LoopCondition()) if err != nil { return nil, err @@ -604,10 +618,12 @@ func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { if err != nil { return nil, err } + p.popLocalVars(fold.IterVar(), fold.IterVar2()) result, err := p.plan(fold.Result()) if err != nil { return nil, err } + p.popLocalVars(fold.AccuVar()) return &evalFold{ id: expr.ID(), accuVar: fold.AccuVar(), @@ -623,13 +639,13 @@ func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { } // planConst generates a constant valued Interpretable. -func (p *planner) planConst(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planConst(expr ast.Expr) (Interpretable, error) { return NewConstValue(expr.ID(), expr.AsLiteral()), nil } // resolveTypeName takes a qualified string constructed at parse time, applies the proto // namespace resolution rules to it in a scan over possible matching types in the TypeProvider. -func (p *planner) resolveTypeName(typeName string) (string, bool) { +func (p *planBuilder) resolveTypeName(typeName string) (string, bool) { for _, qualifiedTypeName := range p.container.ResolveCandidateNames(typeName) { if _, found := p.provider.FindStructType(qualifiedTypeName); found { return qualifiedTypeName, true @@ -646,7 +662,7 @@ func (p *planner) resolveTypeName(typeName string) (string, bool) { // - The target expression may only consist of ident and select expressions. // - The function is declared in the environment using its fully-qualified name. // - The fully-qualified function name matches the string serialized target value. -func (p *planner) resolveFunction(expr ast.Expr) (ast.Expr, string, string) { +func (p *planBuilder) resolveFunction(expr ast.Expr) (ast.Expr, string, string) { // Note: similar logic exists within the `checker/checker.go`. If making changes here // please consider the impact on checker.go and consolidate implementations or mirror code // as appropriate. @@ -687,7 +703,7 @@ func (p *planner) resolveFunction(expr ast.Expr) (ast.Expr, string, string) { // namespaced identifiers must be stripped, as all declarations already use fully-qualified // names. This stripping behavior is handled automatically by the ResolveCandidateNames // call. - return target, stripLeadingDot(fnName), "" + return target, strings.TrimPrefix(fnName, "."), "" } // Handle the situation where the function target actually indicates a qualified function name. @@ -710,7 +726,7 @@ func (p *planner) resolveFunction(expr ast.Expr) (ast.Expr, string, string) { // relativeAttr indicates that the attribute in this case acts as a qualifier and as such needs to // be observed to ensure that it's evaluation value is properly recorded for state tracking. -func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (InterpretableAttribute, error) { +func (p *planBuilder) relativeAttr(id int64, eval Interpretable, opt bool) (InterpretableAttribute, error) { eAttr, ok := eval.(InterpretableAttribute) if !ok { eAttr = &evalAttr{ @@ -733,7 +749,7 @@ func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (Interpre // toQualifiedName converts an expression AST into a qualified name if possible, with a boolean // 'found' value that indicates if the conversion is successful. -func (p *planner) toQualifiedName(operand ast.Expr) (string, bool) { +func (p *planBuilder) toQualifiedName(operand ast.Expr) (string, bool) { // If the checker identified the expression as an attribute by the type-checker, then it can't // possibly be part of qualified name in a namespace. _, isAttr := p.refMap[operand.ID()] @@ -759,9 +775,35 @@ func (p *planner) toQualifiedName(operand ast.Expr) (string, bool) { return "", false } -func stripLeadingDot(name string) string { - if strings.HasPrefix(name, ".") { - return name[1:] +func (p *planBuilder) pushLocalVars(names ...string) { + for _, name := range names { + if name == "" { + continue + } + if cnt, found := p.localVars[name]; found { + p.localVars[name] = cnt + 1 + } else { + p.localVars[name] = 1 + } + } +} + +func (p *planBuilder) popLocalVars(names ...string) { + for _, name := range names { + if name == "" { + continue + } + if cnt, found := p.localVars[name]; found { + if cnt == 1 { + delete(p.localVars, name) + } else { + p.localVars[name] = cnt - 1 + } + } } - return name +} + +func (p *planBuilder) isLocalVar(name string) bool { + _, found := p.localVars[name] + return found } diff --git a/type_collision_fix.diff b/type_collision_fix.diff new file mode 100644 index 000000000..d1b9d2014 --- /dev/null +++ b/type_collision_fix.diff @@ -0,0 +1,82 @@ +commit b5925cf0b39ff9ff652a6c545d7ceced58f5dd1a +Author: TristonianJones +Date: Tue Dec 16 15:15:35 2025 -0800 + + Remove the treatment of standard identifiers as variables for types + +diff --git a/cel/library.go b/cel/library.go +index 59a10e8..85e2aca 100644 +--- a/cel/library.go ++++ b/cel/library.go +@@ -182,7 +182,7 @@ func (lib *stdLibrary) CompileOptions() []EnvOption { + if err = lib.subset.Validate(); err != nil { + return nil, err + } +- e.variables = append(e.variables, stdlib.Types()...) ++ // e.variables = append(e.variables, stdlib.Types()...) + for _, fn := range funcs { + existing, found := e.functions[fn.Name()] + if found { +diff --git a/cel/testdata/standard_env.prompt.txt b/cel/testdata/standard_env.prompt.txt +index 18f0faf..ca73de1 100644 +--- a/cel/testdata/standard_env.prompt.txt ++++ b/cel/testdata/standard_env.prompt.txt +@@ -11,21 +11,6 @@ found in C++ or Java code. + + Only use the following variables, macros, and functions in expressions. + +-Variables: +- +-* bool is a type +-* bytes is a type +-* double is a type +-* google.protobuf.Duration is a type +-* google.protobuf.Timestamp is a type +-* int is a type +-* list is a type +-* map is a type +-* null_type is a type +-* string is a type +-* type is a type +-* uint is a type +- + Functions: + + * !_ - logically negate a boolean value. +diff --git a/checker/env.go b/checker/env.go +index 8e9aec8..dd738fb 100644 +--- a/checker/env.go ++++ b/checker/env.go +@@ -138,14 +138,14 @@ func (e *Env) LookupIdent(name string) *decls.VariableDecl { + } + + // Next try to import the name as a reference to a message type. +- if t, found := e.provider.FindStructType(candidate); found { +- return decls.NewVariable(candidate, t) +- } + if i, found := e.provider.FindIdent(candidate); found { + if t, ok := i.(*types.Type); ok { + return decls.NewVariable(candidate, types.NewTypeTypeWithParam(t)) + } + } ++ if t, found := e.provider.FindStructType(candidate); found { ++ return decls.NewVariable(candidate, t) ++ } + + // Next try to import this as an enum value by splitting the name in a type prefix and + // the enum inside. +diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go +index b664520..51d737b 100644 +--- a/interpreter/interpreter_test.go ++++ b/interpreter/interpreter_test.go +@@ -1041,6 +1041,10 @@ func testData(t testing.TB) []testCase { + }, + }, + }, ++ { ++ name: "type_dyn_equals_string", ++ expr: `type(dyn('')) == string`, ++ }, + { + name: "select_key", + expr: `m.strMap['val'] == 'string'