Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -1052,9 +1052,10 @@ func (p *interopCELTypeProvider) FindStructFieldType(structType, fieldName strin
return nil, false
}
return &types.FieldType{
Type: t,
IsSet: ft.IsSet,
GetFrom: ft.GetFrom,
Type: t,
IsSet: ft.IsSet,
GetFrom: ft.GetFrom,
IsJSONField: ft.IsJSONField,
}, true
}
return nil, false
Expand Down
8 changes: 8 additions & 0 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.E
// check() deletes some nodes while rewriting the AST. For example the Select operand is
// deleted when a variable reference is replaced with a Ident expression.
c.AST.ClearUnusedIDs()
if env.jsonFieldNames {
c.AST.SourceInfo().AddExtension(
ast.NewExtension("json_name", ast.NewExtensionVersion(1, 1), ast.ComponentRuntime),
)
}
return c.AST, errs
}

Expand Down Expand Up @@ -718,6 +723,9 @@ func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*
}

if ft, found := c.env.provider.FindStructFieldType(structType, fieldName); found {
if c.env.jsonFieldNames && !ft.IsJSONField {
c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName)
}
return ft.Type, found
}

Expand Down
47 changes: 45 additions & 2 deletions checker/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2413,6 +2413,42 @@ _&&_(_==_(list~type(list(dyn))^list,
@result~bool^@result)~bool`,
outType: types.BoolType,
},
{
in: `TestAllTypes{?singleInt32: {}.?i}`,
container: "google.expr.proto2.test",
env: testEnv{optionalSyntax: true, jsonFieldNames: true},
out: `google.expr.proto2.test.TestAllTypes{
?singleInt32:_?._(
{}~map(dyn, int),
"i"
)~optional_type(int)^select_optional_field
}~google.expr.proto2.test.TestAllTypes^google.expr.proto2.test.TestAllTypes`,
outType: types.NewObjectType(
"google.expr.proto2.test.TestAllTypes",
),
},
{
in: `TestAllTypes{?singleInt32: {'i': 20}.?i}.singleInt32`,
container: "google.expr.proto2.test",
env: testEnv{optionalSyntax: true, jsonFieldNames: true},
out: `google.expr.proto2.test.TestAllTypes{
?singleInt32:_?._(
{
"i"~string:20~int
}~map(string, int),
"i"
)~optional_type(int)^select_optional_field
}~google.expr.proto2.test.TestAllTypes^google.expr.proto2.test.TestAllTypes.singleInt32~int`,
outType: types.IntType,
},
{
in: `TestAllTypes{singleInt32: 1, single_bool: true}.singleInt32`,
container: "google.expr.proto2.test",
env: testEnv{optionalSyntax: true, jsonFieldNames: true},
err: `ERROR: <input>:1:41: undefined field 'single_bool'
| TestAllTypes{singleInt32: 1, single_bool: true}.singleInt32
| ........................................^`,
},
}
}

Expand Down Expand Up @@ -2470,6 +2506,7 @@ type testEnv struct {
functions []*decls.FunctionDecl
variadicASTs bool
optionalSyntax bool
jsonFieldNames bool
}

func TestCheck(t *testing.T) {
Expand Down Expand Up @@ -2505,9 +2542,12 @@ func TestCheck(t *testing.T) {
t.Fatalf("Unexpected parse errors: %v", errors.ToDisplayString())
}

reg, err := types.NewRegistry(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{})
reg, err := types.NewProtoRegistry(
types.JSONFieldNames(tc.env.jsonFieldNames),
types.ProtoTypes(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}),
)
if err != nil {
t.Fatalf("types.NewRegistry() failed: %v", err)
t.Fatalf("types.NewProtoRegistry() failed: %v", err)
}
if tc.env.optionalSyntax {
if err := reg.RegisterType(types.OptionalType); err != nil {
Expand All @@ -2522,6 +2562,9 @@ func TestCheck(t *testing.T) {
if len(tc.opts) != 0 {
opts = tc.opts
}
if tc.env.jsonFieldNames {
opts = append(opts, JSONFieldNames(true))
}
env, err := NewEnv(cont, reg, opts...)
if err != nil {
t.Fatalf("NewEnv(cont, reg) failed: %v", err)
Expand Down
2 changes: 2 additions & 0 deletions checker/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ type Env struct {
declarations *Scopes
aggLitElemType aggregateLiteralElementType
filteredOverloadIDs map[string]struct{}
jsonFieldNames bool
}

// NewEnv returns a new *Env with the given parameters.
Expand Down Expand Up @@ -104,6 +105,7 @@ func NewEnv(container *containers.Container, provider types.Provider, opts ...Op
declarations: declarations,
aggLitElemType: aggLitElemType,
filteredOverloadIDs: filteredOverloadIDs,
jsonFieldNames: envOptions.jsonFieldNames,
}, nil
}

Expand Down
9 changes: 9 additions & 0 deletions checker/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type options struct {
crossTypeNumericComparisons bool
homogeneousAggregateLiterals bool
validatedDeclarations *Scopes
jsonFieldNames bool
}

// Option is a functional option for configuring the type-checker
Expand All @@ -40,3 +41,11 @@ func ValidatedDeclarations(env *Env) Option {
return nil
}
}

// JSONFieldNames enables the use of json names instead of the standard protobuf snake_case field names
func JSONFieldNames(enabled bool) Option {
return func(opts *options) error {
opts.jsonFieldNames = enabled
return nil
}
}
73 changes: 73 additions & 0 deletions common/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ func CopySourceInfo(info *SourceInfo) *SourceInfo {
for id, call := range info.macroCalls {
callsCopy[id] = defaultFactory.CopyExpr(call)
}
var extCopy []Extension
if len(info.extensions) > 0 {
extCopy = make([]Extension, len(info.extensions))
copy(extCopy, info.extensions)
}
return &SourceInfo{
syntax: info.syntax,
desc: info.desc,
Expand All @@ -239,6 +244,7 @@ func CopySourceInfo(info *SourceInfo) *SourceInfo {
baseCol: info.baseCol,
offsetRanges: rangesCopy,
macroCalls: callsCopy,
extensions: extCopy,
}
}

Expand All @@ -252,6 +258,9 @@ type SourceInfo struct {
baseCol int32
offsetRanges map[int64]OffsetRange
macroCalls map[int64]Expr

// extensions indicate versioned optional features which affect the execution of one or more CEL component.
extensions []Extension
}

// RenumberIDs performs an in-place update of the expression IDs within the SourceInfo.
Expand Down Expand Up @@ -420,6 +429,23 @@ func (s *SourceInfo) ComputeOffsetAbsolute(line, col int32) int32 {
return offset + col
}

// Extensions returns the set of extensions present in the source.
func (s *SourceInfo) Extensions() []Extension {
var extensions []Extension
if s == nil {
return extensions
}
return s.extensions
}

// AddExtension adds an extension record into the SourceInfo.
func (s *SourceInfo) AddExtension(ext Extension) {
if s == nil {
return
}
s.extensions = append(s.extensions, ext)
}

// OffsetRange captures the start and stop positions of a section of text in the input expression.
type OffsetRange struct {
Start int32
Expand Down Expand Up @@ -489,6 +515,53 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool {
return true
}

// NewExtension creates an Extension to be recorded on the SourceInfo.
func NewExtension(id string, version ExtensionVersion, components ...ExtensionComponent) Extension {
return Extension{
ID: id,
Version: version,
Components: components,
}
}

// Extension represents a versioned, optional feature present in the AST that affects CEL component behavior.
type Extension struct {
// ID indicates the unique name of the extension.
ID string
// Version indicates the major / minor version.
Version ExtensionVersion
// Components enumerates the CEL components affected by the feature.
Components []ExtensionComponent
}

// NewExtensionVersion creates a new extension version with a major, minor version.
func NewExtensionVersion(major, minor int64) ExtensionVersion {
return ExtensionVersion{Major: major, Minor: minor}
}

// ExtensionVersion represents a semantic version with a major and minor number.
type ExtensionVersion struct {
// Major version of the extension.
// All versions with the same major number are expected to be compatible with all minor version changes.
Major int64

// Minor version of the extension which indicates that some small non-semantic change has been made to
// the extension.
Minor int64
}

// ExtensionComponent indicates which CEL component is affected.
type ExtensionComponent int

const (
// ComponentParser means the feature affects expression parsing.
ComponentParser ExtensionComponent = iota + 1
// ComponentTypeChecker means the feature affects type-checking.
ComponentTypeChecker
// ComponentRuntime alters program planning or evaluation of the AST.
ComponentRuntime
)

type maxIDVisitor struct {
maxID int64
*baseVisitor
Expand Down
53 changes: 53 additions & 0 deletions common/ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"reflect"
"testing"

"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
Expand Down Expand Up @@ -83,6 +84,55 @@ func TestASTCopy(t *testing.T) {
}
}

func TestASTJsonNames(t *testing.T) {
tests := []string{
`google.expr.proto3.test.TestAllTypes{}`,
`google.expr.proto3.test.TestAllTypes{repeatedInt32: [1, 2]}`,
`google.expr.proto3.test.TestAllTypes{singleInt32: 2}.singleInt32 == 2`,
}

for _, tst := range tests {
checked := mustTypeCheck(t, tst, checker.JSONFieldNames(true), types.JSONFieldNames(true))
copyChecked := ast.Copy(checked)
if !reflect.DeepEqual(copyChecked.Expr(), checked.Expr()) {
t.Errorf("Copy() got expr %v, wanted %v", copyChecked.Expr(), checked.Expr())
}
if !reflect.DeepEqual(copyChecked.SourceInfo(), checked.SourceInfo()) {
t.Errorf("Copy() got source info %v, wanted %v", copyChecked.SourceInfo(), checked.SourceInfo())
}
copyParsed := ast.Copy(ast.NewAST(checked.Expr(), checked.SourceInfo()))
if !reflect.DeepEqual(copyParsed.Expr(), checked.Expr()) {
t.Errorf("Copy() got expr %v, wanted %v", copyParsed.Expr(), checked.Expr())
}
if !reflect.DeepEqual(copyParsed.SourceInfo(), checked.SourceInfo()) {
t.Errorf("Copy() got source info %v, wanted %v", copyParsed.SourceInfo(), checked.SourceInfo())
}
checkedPB, err := ast.ToProto(checked)
if err != nil {
t.Errorf("ast.ToProto() failed: %v", err)
}
copyCheckedPB, err := ast.ToProto(copyChecked)
if err != nil {
t.Errorf("ast.ToProto() failed: %v", err)
}
if !proto.Equal(checkedPB, copyCheckedPB) {
t.Errorf("Copy() produced different proto results, got %v, wanted %v",
prototext.Format(checkedPB), prototext.Format(copyCheckedPB))
}
checkedRoundtrip, err := ast.ToAST(checkedPB)
if err != nil {
t.Errorf("ast.ToAST() failed: %v", err)
}
same := reflect.DeepEqual(checked.Expr(), checkedRoundtrip.Expr()) &&
reflect.DeepEqual(checked.ReferenceMap(), checkedRoundtrip.ReferenceMap()) &&
reflect.DeepEqual(checked.TypeMap(), checkedRoundtrip.TypeMap()) &&
reflect.DeepEqual(checked.SourceInfo().MacroCalls(), checkedRoundtrip.SourceInfo().MacroCalls())
if !same {
t.Errorf("Roundtrip got %v, wanted %v", checkedRoundtrip, checked)
}
}
}

func TestASTNilSafety(t *testing.T) {
ex, err := ast.ProtoToExpr(nil)
if err != nil {
Expand Down Expand Up @@ -184,6 +234,9 @@ func TestSourceInfoNilSafety(t *testing.T) {
if len(testInfo.MacroCalls()) != 0 {
t.Errorf("MacroCalls() got %v, wanted empty map", testInfo.MacroCalls())
}
if len(testInfo.Extensions()) != 0 {
t.Errorf("Extensions() got %v, wanted empty list", testInfo.Extensions())
}
if call, found := testInfo.GetMacroCall(0); found {
t.Errorf("GetMacroCall(0) got %v, wanted not found", call)
}
Expand Down
49 changes: 49 additions & 0 deletions common/ast/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ import (
structpb "google.golang.org/protobuf/types/known/structpb"
)

var (
pbComponentMap = map[exprpb.SourceInfo_Extension_Component]ExtensionComponent{
exprpb.SourceInfo_Extension_COMPONENT_PARSER: ComponentParser,
exprpb.SourceInfo_Extension_COMPONENT_TYPE_CHECKER: ComponentTypeChecker,
exprpb.SourceInfo_Extension_COMPONENT_RUNTIME: ComponentRuntime,
}
componentPBMap = map[ExtensionComponent]exprpb.SourceInfo_Extension_Component{
ComponentParser: exprpb.SourceInfo_Extension_COMPONENT_PARSER,
ComponentTypeChecker: exprpb.SourceInfo_Extension_COMPONENT_TYPE_CHECKER,
ComponentRuntime: exprpb.SourceInfo_Extension_COMPONENT_RUNTIME,
}
)

// ToProto converts an AST to a CheckedExpr protobouf.
func ToProto(ast *AST) (*exprpb.CheckedExpr, error) {
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap()))
Expand Down Expand Up @@ -534,6 +547,25 @@ func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) {
}
sourceInfo.MacroCalls[id] = call
}
for _, ext := range info.Extensions() {
var components []exprpb.SourceInfo_Extension_Component
for _, c := range ext.Components {
comp, found := componentPBMap[c]
if found {
components = append(components, comp)
}
}
ver := &exprpb.SourceInfo_Extension_Version{
Major: ext.Version.Major,
Minor: ext.Version.Minor,
}
pbExt := &exprpb.SourceInfo_Extension{
Id: ext.ID,
Version: ver,
AffectedComponents: components,
}
sourceInfo.Extensions = append(sourceInfo.Extensions, pbExt)
}
return sourceInfo, nil
}

Expand All @@ -556,6 +588,23 @@ func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) {
}
sourceInfo.SetMacroCall(id, call)
}
for _, pbExt := range info.GetExtensions() {
var components []ExtensionComponent
for _, c := range pbExt.GetAffectedComponents() {
comp, found := pbComponentMap[*c.Enum()]
if found {
components = append(components, comp)
}
}
sourceInfo.AddExtension(NewExtension(
pbExt.GetId(),
NewExtensionVersion(
pbExt.GetVersion().GetMajor(),
pbExt.GetVersion().GetMinor(),
),
components...,
))
}
return sourceInfo, nil
}

Expand Down
Loading