diff --git a/cel/env.go b/cel/env.go
index a7aa6db34..d3295b6a5 100644
--- a/cel/env.go
+++ b/cel/env.go
@@ -1052,9 +1052,10 @@ func (p *interopCELTypeProvider) FindStructFieldType(structType, fieldName strin
return nil, false
}
return &types.FieldType{
- Type: t,
- IsSet: ft.IsSet,
- GetFrom: ft.GetFrom,
+ Type: t,
+ IsSet: ft.IsSet,
+ GetFrom: ft.GetFrom,
+ IsJSONField: ft.IsJSONField,
}, true
}
return nil, false
diff --git a/checker/checker.go b/checker/checker.go
index d07d8e799..42d27a428 100644
--- a/checker/checker.go
+++ b/checker/checker.go
@@ -71,6 +71,11 @@ func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.E
// check() deletes some nodes while rewriting the AST. For example the Select operand is
// deleted when a variable reference is replaced with a Ident expression.
c.AST.ClearUnusedIDs()
+ if env.jsonFieldNames {
+ c.AST.SourceInfo().AddExtension(
+ ast.NewExtension("json_name", ast.NewExtensionVersion(1, 1), ast.ComponentRuntime),
+ )
+ }
return c.AST, errs
}
@@ -718,6 +723,9 @@ func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*
}
if ft, found := c.env.provider.FindStructFieldType(structType, fieldName); found {
+ if c.env.jsonFieldNames && !ft.IsJSONField {
+ c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName)
+ }
return ft.Type, found
}
diff --git a/checker/checker_test.go b/checker/checker_test.go
index 387c8bd75..623e1f709 100644
--- a/checker/checker_test.go
+++ b/checker/checker_test.go
@@ -2413,6 +2413,42 @@ _&&_(_==_(list~type(list(dyn))^list,
@result~bool^@result)~bool`,
outType: types.BoolType,
},
+ {
+ in: `TestAllTypes{?singleInt32: {}.?i}`,
+ container: "google.expr.proto2.test",
+ env: testEnv{optionalSyntax: true, jsonFieldNames: true},
+ out: `google.expr.proto2.test.TestAllTypes{
+ ?singleInt32:_?._(
+ {}~map(dyn, int),
+ "i"
+ )~optional_type(int)^select_optional_field
+ }~google.expr.proto2.test.TestAllTypes^google.expr.proto2.test.TestAllTypes`,
+ outType: types.NewObjectType(
+ "google.expr.proto2.test.TestAllTypes",
+ ),
+ },
+ {
+ in: `TestAllTypes{?singleInt32: {'i': 20}.?i}.singleInt32`,
+ container: "google.expr.proto2.test",
+ env: testEnv{optionalSyntax: true, jsonFieldNames: true},
+ out: `google.expr.proto2.test.TestAllTypes{
+ ?singleInt32:_?._(
+ {
+ "i"~string:20~int
+ }~map(string, int),
+ "i"
+ )~optional_type(int)^select_optional_field
+ }~google.expr.proto2.test.TestAllTypes^google.expr.proto2.test.TestAllTypes.singleInt32~int`,
+ outType: types.IntType,
+ },
+ {
+ in: `TestAllTypes{singleInt32: 1, single_bool: true}.singleInt32`,
+ container: "google.expr.proto2.test",
+ env: testEnv{optionalSyntax: true, jsonFieldNames: true},
+ err: `ERROR: :1:41: undefined field 'single_bool'
+ | TestAllTypes{singleInt32: 1, single_bool: true}.singleInt32
+ | ........................................^`,
+ },
}
}
@@ -2470,6 +2506,7 @@ type testEnv struct {
functions []*decls.FunctionDecl
variadicASTs bool
optionalSyntax bool
+ jsonFieldNames bool
}
func TestCheck(t *testing.T) {
@@ -2505,9 +2542,12 @@ func TestCheck(t *testing.T) {
t.Fatalf("Unexpected parse errors: %v", errors.ToDisplayString())
}
- reg, err := types.NewRegistry(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{})
+ reg, err := types.NewProtoRegistry(
+ types.JSONFieldNames(tc.env.jsonFieldNames),
+ types.ProtoTypes(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}),
+ )
if err != nil {
- t.Fatalf("types.NewRegistry() failed: %v", err)
+ t.Fatalf("types.NewProtoRegistry() failed: %v", err)
}
if tc.env.optionalSyntax {
if err := reg.RegisterType(types.OptionalType); err != nil {
@@ -2522,6 +2562,9 @@ func TestCheck(t *testing.T) {
if len(tc.opts) != 0 {
opts = tc.opts
}
+ if tc.env.jsonFieldNames {
+ opts = append(opts, JSONFieldNames(true))
+ }
env, err := NewEnv(cont, reg, opts...)
if err != nil {
t.Fatalf("NewEnv(cont, reg) failed: %v", err)
diff --git a/checker/env.go b/checker/env.go
index 16c8ae60f..477918c48 100644
--- a/checker/env.go
+++ b/checker/env.go
@@ -74,6 +74,7 @@ type Env struct {
declarations *Scopes
aggLitElemType aggregateLiteralElementType
filteredOverloadIDs map[string]struct{}
+ jsonFieldNames bool
}
// NewEnv returns a new *Env with the given parameters.
@@ -104,6 +105,7 @@ func NewEnv(container *containers.Container, provider types.Provider, opts ...Op
declarations: declarations,
aggLitElemType: aggLitElemType,
filteredOverloadIDs: filteredOverloadIDs,
+ jsonFieldNames: envOptions.jsonFieldNames,
}, nil
}
diff --git a/checker/options.go b/checker/options.go
index 0560c3813..af714323b 100644
--- a/checker/options.go
+++ b/checker/options.go
@@ -18,6 +18,7 @@ type options struct {
crossTypeNumericComparisons bool
homogeneousAggregateLiterals bool
validatedDeclarations *Scopes
+ jsonFieldNames bool
}
// Option is a functional option for configuring the type-checker
@@ -40,3 +41,11 @@ func ValidatedDeclarations(env *Env) Option {
return nil
}
}
+
+// JSONFieldNames enables the use of json names instead of the standard protobuf snake_case field names
+func JSONFieldNames(enabled bool) Option {
+ return func(opts *options) error {
+ opts.jsonFieldNames = enabled
+ return nil
+ }
+}
diff --git a/common/ast/ast.go b/common/ast/ast.go
index 3c5ee0c80..aae2a83e9 100644
--- a/common/ast/ast.go
+++ b/common/ast/ast.go
@@ -231,6 +231,11 @@ func CopySourceInfo(info *SourceInfo) *SourceInfo {
for id, call := range info.macroCalls {
callsCopy[id] = defaultFactory.CopyExpr(call)
}
+ var extCopy []Extension
+ if len(info.extensions) > 0 {
+ extCopy = make([]Extension, len(info.extensions))
+ copy(extCopy, info.extensions)
+ }
return &SourceInfo{
syntax: info.syntax,
desc: info.desc,
@@ -239,6 +244,7 @@ func CopySourceInfo(info *SourceInfo) *SourceInfo {
baseCol: info.baseCol,
offsetRanges: rangesCopy,
macroCalls: callsCopy,
+ extensions: extCopy,
}
}
@@ -252,6 +258,9 @@ type SourceInfo struct {
baseCol int32
offsetRanges map[int64]OffsetRange
macroCalls map[int64]Expr
+
+ // extensions indicate versioned optional features which affect the execution of one or more CEL component.
+ extensions []Extension
}
// RenumberIDs performs an in-place update of the expression IDs within the SourceInfo.
@@ -420,6 +429,23 @@ func (s *SourceInfo) ComputeOffsetAbsolute(line, col int32) int32 {
return offset + col
}
+// Extensions returns the set of extensions present in the source.
+func (s *SourceInfo) Extensions() []Extension {
+ var extensions []Extension
+ if s == nil {
+ return extensions
+ }
+ return s.extensions
+}
+
+// AddExtension adds an extension record into the SourceInfo.
+func (s *SourceInfo) AddExtension(ext Extension) {
+ if s == nil {
+ return
+ }
+ s.extensions = append(s.extensions, ext)
+}
+
// OffsetRange captures the start and stop positions of a section of text in the input expression.
type OffsetRange struct {
Start int32
@@ -489,6 +515,53 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool {
return true
}
+// NewExtension creates an Extension to be recorded on the SourceInfo.
+func NewExtension(id string, version ExtensionVersion, components ...ExtensionComponent) Extension {
+ return Extension{
+ ID: id,
+ Version: version,
+ Components: components,
+ }
+}
+
+// Extension represents a versioned, optional feature present in the AST that affects CEL component behavior.
+type Extension struct {
+ // ID indicates the unique name of the extension.
+ ID string
+ // Version indicates the major / minor version.
+ Version ExtensionVersion
+ // Components enumerates the CEL components affected by the feature.
+ Components []ExtensionComponent
+}
+
+// NewExtensionVersion creates a new extension version with a major, minor version.
+func NewExtensionVersion(major, minor int64) ExtensionVersion {
+ return ExtensionVersion{Major: major, Minor: minor}
+}
+
+// ExtensionVersion represents a semantic version with a major and minor number.
+type ExtensionVersion struct {
+ // Major version of the extension.
+ // All versions with the same major number are expected to be compatible with all minor version changes.
+ Major int64
+
+ // Minor version of the extension which indicates that some small non-semantic change has been made to
+ // the extension.
+ Minor int64
+}
+
+// ExtensionComponent indicates which CEL component is affected.
+type ExtensionComponent int
+
+const (
+ // ComponentParser means the feature affects expression parsing.
+ ComponentParser ExtensionComponent = iota + 1
+ // ComponentTypeChecker means the feature affects type-checking.
+ ComponentTypeChecker
+ // ComponentRuntime alters program planning or evaluation of the AST.
+ ComponentRuntime
+)
+
type maxIDVisitor struct {
maxID int64
*baseVisitor
diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go
index 0a86c3936..17c4c82c2 100644
--- a/common/ast/ast_test.go
+++ b/common/ast/ast_test.go
@@ -20,6 +20,7 @@ import (
"reflect"
"testing"
+ "github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
@@ -83,6 +84,55 @@ func TestASTCopy(t *testing.T) {
}
}
+func TestASTJsonNames(t *testing.T) {
+ tests := []string{
+ `google.expr.proto3.test.TestAllTypes{}`,
+ `google.expr.proto3.test.TestAllTypes{repeatedInt32: [1, 2]}`,
+ `google.expr.proto3.test.TestAllTypes{singleInt32: 2}.singleInt32 == 2`,
+ }
+
+ for _, tst := range tests {
+ checked := mustTypeCheck(t, tst, checker.JSONFieldNames(true), types.JSONFieldNames(true))
+ copyChecked := ast.Copy(checked)
+ if !reflect.DeepEqual(copyChecked.Expr(), checked.Expr()) {
+ t.Errorf("Copy() got expr %v, wanted %v", copyChecked.Expr(), checked.Expr())
+ }
+ if !reflect.DeepEqual(copyChecked.SourceInfo(), checked.SourceInfo()) {
+ t.Errorf("Copy() got source info %v, wanted %v", copyChecked.SourceInfo(), checked.SourceInfo())
+ }
+ copyParsed := ast.Copy(ast.NewAST(checked.Expr(), checked.SourceInfo()))
+ if !reflect.DeepEqual(copyParsed.Expr(), checked.Expr()) {
+ t.Errorf("Copy() got expr %v, wanted %v", copyParsed.Expr(), checked.Expr())
+ }
+ if !reflect.DeepEqual(copyParsed.SourceInfo(), checked.SourceInfo()) {
+ t.Errorf("Copy() got source info %v, wanted %v", copyParsed.SourceInfo(), checked.SourceInfo())
+ }
+ checkedPB, err := ast.ToProto(checked)
+ if err != nil {
+ t.Errorf("ast.ToProto() failed: %v", err)
+ }
+ copyCheckedPB, err := ast.ToProto(copyChecked)
+ if err != nil {
+ t.Errorf("ast.ToProto() failed: %v", err)
+ }
+ if !proto.Equal(checkedPB, copyCheckedPB) {
+ t.Errorf("Copy() produced different proto results, got %v, wanted %v",
+ prototext.Format(checkedPB), prototext.Format(copyCheckedPB))
+ }
+ checkedRoundtrip, err := ast.ToAST(checkedPB)
+ if err != nil {
+ t.Errorf("ast.ToAST() failed: %v", err)
+ }
+ same := reflect.DeepEqual(checked.Expr(), checkedRoundtrip.Expr()) &&
+ reflect.DeepEqual(checked.ReferenceMap(), checkedRoundtrip.ReferenceMap()) &&
+ reflect.DeepEqual(checked.TypeMap(), checkedRoundtrip.TypeMap()) &&
+ reflect.DeepEqual(checked.SourceInfo().MacroCalls(), checkedRoundtrip.SourceInfo().MacroCalls())
+ if !same {
+ t.Errorf("Roundtrip got %v, wanted %v", checkedRoundtrip, checked)
+ }
+ }
+}
+
func TestASTNilSafety(t *testing.T) {
ex, err := ast.ProtoToExpr(nil)
if err != nil {
@@ -184,6 +234,9 @@ func TestSourceInfoNilSafety(t *testing.T) {
if len(testInfo.MacroCalls()) != 0 {
t.Errorf("MacroCalls() got %v, wanted empty map", testInfo.MacroCalls())
}
+ if len(testInfo.Extensions()) != 0 {
+ t.Errorf("Extensions() got %v, wanted empty list", testInfo.Extensions())
+ }
if call, found := testInfo.GetMacroCall(0); found {
t.Errorf("GetMacroCall(0) got %v, wanted not found", call)
}
diff --git a/common/ast/conversion.go b/common/ast/conversion.go
index 435d8f654..380f8c118 100644
--- a/common/ast/conversion.go
+++ b/common/ast/conversion.go
@@ -27,6 +27,19 @@ import (
structpb "google.golang.org/protobuf/types/known/structpb"
)
+var (
+ pbComponentMap = map[exprpb.SourceInfo_Extension_Component]ExtensionComponent{
+ exprpb.SourceInfo_Extension_COMPONENT_PARSER: ComponentParser,
+ exprpb.SourceInfo_Extension_COMPONENT_TYPE_CHECKER: ComponentTypeChecker,
+ exprpb.SourceInfo_Extension_COMPONENT_RUNTIME: ComponentRuntime,
+ }
+ componentPBMap = map[ExtensionComponent]exprpb.SourceInfo_Extension_Component{
+ ComponentParser: exprpb.SourceInfo_Extension_COMPONENT_PARSER,
+ ComponentTypeChecker: exprpb.SourceInfo_Extension_COMPONENT_TYPE_CHECKER,
+ ComponentRuntime: exprpb.SourceInfo_Extension_COMPONENT_RUNTIME,
+ }
+)
+
// ToProto converts an AST to a CheckedExpr protobouf.
func ToProto(ast *AST) (*exprpb.CheckedExpr, error) {
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap()))
@@ -534,6 +547,25 @@ func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) {
}
sourceInfo.MacroCalls[id] = call
}
+ for _, ext := range info.Extensions() {
+ var components []exprpb.SourceInfo_Extension_Component
+ for _, c := range ext.Components {
+ comp, found := componentPBMap[c]
+ if found {
+ components = append(components, comp)
+ }
+ }
+ ver := &exprpb.SourceInfo_Extension_Version{
+ Major: ext.Version.Major,
+ Minor: ext.Version.Minor,
+ }
+ pbExt := &exprpb.SourceInfo_Extension{
+ Id: ext.ID,
+ Version: ver,
+ AffectedComponents: components,
+ }
+ sourceInfo.Extensions = append(sourceInfo.Extensions, pbExt)
+ }
return sourceInfo, nil
}
@@ -556,6 +588,23 @@ func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) {
}
sourceInfo.SetMacroCall(id, call)
}
+ for _, pbExt := range info.GetExtensions() {
+ var components []ExtensionComponent
+ for _, c := range pbExt.GetAffectedComponents() {
+ comp, found := pbComponentMap[*c.Enum()]
+ if found {
+ components = append(components, comp)
+ }
+ }
+ sourceInfo.AddExtension(NewExtension(
+ pbExt.GetId(),
+ NewExtensionVersion(
+ pbExt.GetVersion().GetMajor(),
+ pbExt.GetVersion().GetMinor(),
+ ),
+ components...,
+ ))
+ }
return sourceInfo, nil
}
diff --git a/common/ast/conversion_test.go b/common/ast/conversion_test.go
index 86b00e3e3..ecd50830b 100644
--- a/common/ast/conversion_test.go
+++ b/common/ast/conversion_test.go
@@ -38,10 +38,12 @@ import (
func TestConvertAST(t *testing.T) {
fac := ast.NewExprFactory()
tests := []struct {
+ name string
goAST *ast.AST
pbAST *exprpb.CheckedExpr
}{
{
+ name: "simple ast",
goAST: ast.NewCheckedAST(ast.NewAST(nil, nil),
map[int64]*types.Type{
1: types.BoolType,
@@ -71,6 +73,7 @@ func TestConvertAST(t *testing.T) {
},
},
{
+ name: "comprehension ast",
goAST: ast.NewAST(
fac.NewComprehensionTwoVar(1,
fac.NewIdent(2, "data"),
@@ -179,11 +182,74 @@ func TestConvertAST(t *testing.T) {
ReferenceMap: map[int64]*exprpb.Reference{},
},
},
+ {
+ name: "json names ast",
+ goAST: ast.NewCheckedAST(
+ ast.NewAST(
+ fac.NewCall(2, overloads.LogicalNot, fac.NewIdent(1, "value")),
+ sourceWithExtension(
+ ast.NewSourceInfo(common.NewTextSource("!value")),
+ ast.NewExtension("json_name", ast.NewExtensionVersion(1, 1), ast.ComponentRuntime),
+ ),
+ ),
+ map[int64]*types.Type{
+ 1: types.BoolType,
+ 2: types.BoolType,
+ },
+ map[int64]*ast.ReferenceInfo{
+ 1: ast.NewFunctionReference(overloads.LogicalNot),
+ 2: ast.NewIdentReference("value", nil),
+ },
+ ),
+ pbAST: &exprpb.CheckedExpr{
+ Expr: &exprpb.Expr{
+ Id: 2,
+ ExprKind: &exprpb.Expr_CallExpr{
+ CallExpr: &exprpb.Expr_Call{
+ Function: "!_",
+ Args: []*exprpb.Expr{
+ {
+ Id: 1,
+ ExprKind: &exprpb.Expr_IdentExpr{
+ IdentExpr: &exprpb.Expr_Ident{
+ Name: "value",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ SourceInfo: &exprpb.SourceInfo{
+ Location: "",
+ Extensions: []*exprpb.SourceInfo_Extension{
+ {
+ Id: "json_name",
+ Version: &exprpb.SourceInfo_Extension_Version{
+ Major: 1,
+ Minor: 1,
+ },
+ AffectedComponents: []exprpb.SourceInfo_Extension_Component{
+ exprpb.SourceInfo_Extension_COMPONENT_RUNTIME,
+ },
+ },
+ },
+ },
+ TypeMap: map[int64]*exprpb.Type{
+ 1: chkdecls.Bool,
+ 2: chkdecls.Bool,
+ },
+ ReferenceMap: map[int64]*exprpb.Reference{
+ 1: {OverloadId: []string{overloads.LogicalNot}},
+ 2: {Name: "value"},
+ },
+ },
+ },
}
- for i, tst := range tests {
+ for _, tst := range tests {
tc := tst
- t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+ t.Run(tc.name, func(t *testing.T) {
goAST := tc.goAST
pbAST := tc.pbAST
checkedAST, err := ast.ToAST(pbAST)
@@ -194,6 +260,10 @@ func TestConvertAST(t *testing.T) {
!reflect.DeepEqual(checkedAST.TypeMap(), goAST.TypeMap()) {
t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, goAST)
}
+ if !reflect.DeepEqual(checkedAST.SourceInfo().Extensions(), goAST.SourceInfo().Extensions()) {
+ t.Errorf("conversion to AST did not preserve SourceInfo extensions. got %v, wanted %v",
+ checkedAST.SourceInfo().Extensions(), goAST.SourceInfo().Extensions())
+ }
if len(checkedAST.ReferenceMap()) > 2 {
if !checkedAST.ReferenceMap()[1].Equals(goAST.ReferenceMap()[1]) ||
!checkedAST.ReferenceMap()[2].Equals(goAST.ReferenceMap()[2]) {
@@ -656,3 +726,8 @@ func TestConstantToValError(t *testing.T) {
t.Errorf("ConstantToVal() got %v, wanted error", out)
}
}
+
+func sourceWithExtension(info *ast.SourceInfo, ext ast.Extension) *ast.SourceInfo {
+ info.AddExtension(ext)
+ return info
+}
diff --git a/common/ast/navigable_test.go b/common/ast/navigable_test.go
index 5afb23ab7..bd8fa8ac1 100644
--- a/common/ast/navigable_test.go
+++ b/common/ast/navigable_test.go
@@ -18,8 +18,6 @@ import (
"reflect"
"testing"
- "google.golang.org/protobuf/proto"
-
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
@@ -556,7 +554,7 @@ func TestNavigableSelectExpr_TestOnly(t *testing.T) {
}
}
-func mustTypeCheck(t testing.TB, expr string) *ast.AST {
+func mustTypeCheck(t testing.TB, expr string, opts ...any) *ast.AST {
t.Helper()
p, err := parser.NewParser(
parser.Macros(parser.AllMacros...),
@@ -570,8 +568,21 @@ func mustTypeCheck(t testing.TB, expr string) *ast.AST {
if len(iss.GetErrors()) != 0 {
t.Fatalf("Parse(%s) failed: %s", expr, iss.ToDisplayString())
}
- reg := newTestRegistry(t, &proto3pb.TestAllTypes{})
- env := newTestEnv(t, containers.DefaultContainer, reg)
+ regOpts := []types.RegistryOption{}
+ chkOpts := []checker.Option{}
+ for _, opt := range opts {
+ switch v := opt.(type) {
+ case types.RegistryOption:
+ regOpts = append(regOpts, v)
+ case checker.Option:
+ chkOpts = append(chkOpts, v)
+ default:
+ t.Fatalf("mustTypeCheck() failed with invalid option type: %T", v)
+ }
+ }
+ regOpts = append(regOpts, types.ProtoTypes(&proto3pb.TestAllTypes{}))
+ reg := newTestRegistry(t, regOpts...)
+ env := newTestEnv(t, containers.DefaultContainer, reg, chkOpts...)
checked, iss := checker.Check(parsed, exprSrc, env)
if len(iss.GetErrors()) != 0 {
t.Fatalf("Check(%s) failed: %s", expr, iss.ToDisplayString())
@@ -579,18 +590,20 @@ func mustTypeCheck(t testing.TB, expr string) *ast.AST {
return checked
}
-func newTestRegistry(t testing.TB, msgs ...proto.Message) *types.Registry {
+func newTestRegistry(t testing.TB, opts ...types.RegistryOption) *types.Registry {
t.Helper()
- reg, err := types.NewRegistry(msgs...)
+ reg, err := types.NewProtoRegistry(opts...)
if err != nil {
- t.Fatalf("types.NewRegistry(%v) failed: %v", msgs, err)
+ t.Fatalf("types.NewProtoRegistry() failed: %v", err)
}
return reg
}
-func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry) *checker.Env {
+func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry, opts ...checker.Option) *checker.Env {
t.Helper()
- env, err := checker.NewEnv(cont, reg, checker.CrossTypeNumericComparisons(true))
+ chkOpts := []checker.Option{checker.CrossTypeNumericComparisons(true)}
+ chkOpts = append(chkOpts, opts...)
+ env, err := checker.NewEnv(cont, reg, chkOpts...)
if err != nil {
t.Fatalf("checker.NewEnv(%v, %v) failed: %v", cont, reg, err)
}
diff --git a/common/types/map_test.go b/common/types/map_test.go
index a45a56c85..5b9898c27 100644
--- a/common/types/map_test.go
+++ b/common/types/map_test.go
@@ -44,7 +44,7 @@ type testStruct struct {
}
func TestMapContains(t *testing.T) {
- reg := newTestRegistry(t, &proto3pb.TestAllTypes{})
+ reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{}))
reflectMap := reg.NativeToValue(map[any]any{
int64(1): "hello",
uint64(2): "world",
@@ -582,7 +582,7 @@ func TestMapIsZeroValue(t *testing.T) {
"hello": "world",
},
}
- reg := newTestRegistry(t, msg)
+ reg := newTestRegistry(t, ProtoTypes(msg))
obj := reg.NativeToValue(msg).(traits.Indexer)
tests := []struct {
@@ -749,7 +749,7 @@ func TestProtoMap(t *testing.T) {
"welcome": "back",
}
msg := &proto3pb.TestAllTypes{MapStringString: strMap}
- reg := newTestRegistry(t, msg)
+ reg := newTestRegistry(t, ProtoTypes(msg))
obj := reg.NativeToValue(msg).(traits.Indexer)
// Test a simple proto map of string string.
@@ -850,7 +850,7 @@ func TestProtoMapGet(t *testing.T) {
"welcome": "back",
}
msg := &proto3pb.TestAllTypes{MapStringString: strMap}
- reg := newTestRegistry(t, msg)
+ reg := newTestRegistry(t, ProtoTypes(msg))
obj := reg.NativeToValue(msg).(traits.Indexer)
field := obj.Get(String("map_string_string"))
mapVal, ok := field.(traits.Mapper)
@@ -890,7 +890,7 @@ func TestProtoMapConvertToNative(t *testing.T) {
"welcome": "back",
}
msg := &proto3pb.TestAllTypes{MapStringString: strMap}
- reg := newTestRegistry(t, msg)
+ reg := newTestRegistry(t, ProtoTypes(msg))
obj := reg.NativeToValue(msg).(traits.Indexer)
// Test a simple proto map of string string.
field := obj.Get(String("map_string_string"))
@@ -974,7 +974,7 @@ func TestProtoMapConvertToNative_NestedProto(t *testing.T) {
},
}
msg := &proto3pb.TestAllTypes{MapInt64NestedType: nestedTypeMap}
- reg := newTestRegistry(t, msg)
+ reg := newTestRegistry(t, ProtoTypes(msg))
obj := reg.NativeToValue(msg).(traits.Indexer)
// Test a simple proto map of string string.
field := obj.Get(String("map_int64_nested_type"))
diff --git a/common/types/object_test.go b/common/types/object_test.go
index 10d05333e..1f89cd165 100644
--- a/common/types/object_test.go
+++ b/common/types/object_test.go
@@ -49,75 +49,134 @@ func TestNewProtoObject(t *testing.T) {
}
func TestProtoObjectConvertToNative(t *testing.T) {
- reg := newTestRegistry(t, &exprpb.Expr{})
msg := &exprpb.ParsedExpr{
+ Expr: &exprpb.Expr{
+ Id: 1,
+ ExprKind: &exprpb.Expr_ConstExpr{
+ ConstExpr: &exprpb.Constant{
+ ConstantKind: &exprpb.Constant_BoolValue{
+ BoolValue: true,
+ },
+ },
+ },
+ },
SourceInfo: &exprpb.SourceInfo{
LineOffsets: []int32{1, 2, 3}}}
- objVal := reg.NativeToValue(msg)
- // Proto Message
- val, err := objVal.ConvertToNative(reflect.TypeOf(&exprpb.ParsedExpr{}))
- if err != nil {
- t.Error(err)
- }
- if !proto.Equal(val.(proto.Message), msg) {
- t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), msg)
+ tests := []struct {
+ opts []RegistryOption
+ fieldMap func(reg *Registry) map[string]ref.Val
+ outValue map[string]any
+ }{
+ {
+ opts: []RegistryOption{ProtoTypes(&exprpb.Expr{}), JSONFieldNames(true)},
+ fieldMap: func(reg *Registry) map[string]ref.Val {
+ return map[string]ref.Val{
+ "expr": reg.NativeToValue(msg.GetExpr()),
+ "sourceInfo": reg.NativeToValue(msg.GetSourceInfo()),
+ }
+ },
+ outValue: map[string]any{
+ "expr": map[string]any{
+ // The id is encoded to a string because it is int64 type, though
+ // this is not exactly what's documented in protojson serialization
+ // and so could signal a bug in the json-handling for int values
+ // as the serialization to int or string is based on the value precision.
+ "id": "1",
+ "constExpr": map[string]any{
+ "boolValue": true,
+ },
+ },
+ "sourceInfo": map[string]any{
+ "lineOffsets": []any{1.0, 2.0, 3.0},
+ },
+ },
+ },
+ {
+ opts: []RegistryOption{ProtoTypes(&exprpb.Expr{}), JSONFieldNames(false)},
+ fieldMap: func(reg *Registry) map[string]ref.Val {
+ return map[string]ref.Val{
+ "expr": reg.NativeToValue(msg.GetExpr()),
+ "source_info": reg.NativeToValue(msg.GetSourceInfo()),
+ }
+ },
+ outValue: map[string]any{
+ "expr": map[string]any{
+ "id": "1",
+ "constExpr": map[string]any{
+ "boolValue": true,
+ },
+ },
+ "sourceInfo": map[string]any{
+ "lineOffsets": []any{1.0, 2.0, 3.0},
+ },
+ },
+ },
}
- // Dynamic protobuf
- dynPB := reg.NewValue(
- string(msg.ProtoReflect().Descriptor().FullName()),
- map[string]ref.Val{
- "source_info": reg.NativeToValue(msg.GetSourceInfo()),
- })
- if IsError(dynPB) {
- t.Fatalf("reg.NewValue() failed: %v", dynPB)
- }
- dynVal := reg.NativeToValue(dynPB)
- val, err = dynVal.ConvertToNative(reflect.TypeOf(msg))
- if err != nil {
- t.Fatalf("dynVal.ConvertToNative() failed: %v", err)
- }
- if !proto.Equal(val.(proto.Message), msg) {
- t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), msg)
- }
+ for _, tst := range tests {
+ reg := newTestRegistry(t, tst.opts...)
+ objVal := reg.NativeToValue(msg)
- // google.protobuf.Any
- anyVal, err := objVal.ConvertToNative(anyValueType)
- if err != nil {
- t.Fatalf("objVal.ConvertToNative() failed: %v", err)
- }
- anyMsg := anyVal.(*anypb.Any)
- unpackedAny, err := anyMsg.UnmarshalNew()
- if err != nil {
- t.Fatalf("UnmarshalNew() failed: %v", err)
- }
- if !proto.Equal(unpackedAny, objVal.Value().(proto.Message)) {
- t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), unpackedAny)
- }
+ // Proto Message
+ val, err := objVal.ConvertToNative(reflect.TypeOf(&exprpb.ParsedExpr{}))
+ if err != nil {
+ t.Error(err)
+ }
+ if !proto.Equal(val.(proto.Message), msg) {
+ t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), msg)
+ }
- // JSON
- jsonVal, err := objVal.ConvertToNative(JSONValueType)
- if err != nil {
- t.Fatalf("objVal.ConvertToNative(%v) failed: %v", JSONValueType, err)
- }
- jsonBytes, err := protojson.Marshal(jsonVal.(proto.Message))
- jsonTxt := string(jsonBytes)
- if err != nil {
- t.Fatalf("protojson.Marshal(%v) failed: %v", jsonVal, err)
- }
- outMap := map[string]any{}
- err = json.Unmarshal(jsonBytes, &outMap)
- if err != nil {
- t.Fatalf("json.Unmarshal(%q) failed: %v", jsonTxt, err)
- }
- want := map[string]any{
- "sourceInfo": map[string]any{
- "lineOffsets": []any{1.0, 2.0, 3.0},
- },
- }
- if !reflect.DeepEqual(outMap, want) {
- t.Errorf("got json '%v', expected %v", outMap, want)
+ // Dynamic protobuf
+ dynPB := reg.NewValue(
+ string(msg.ProtoReflect().Descriptor().FullName()),
+ tst.fieldMap(reg),
+ )
+ if IsError(dynPB) {
+ t.Fatalf("reg.NewValue() failed: %v", dynPB)
+ }
+ dynVal := reg.NativeToValue(dynPB)
+ val, err = dynVal.ConvertToNative(reflect.TypeOf(msg))
+ if err != nil {
+ t.Fatalf("dynVal.ConvertToNative() failed: %v", err)
+ }
+ if !proto.Equal(val.(proto.Message), msg) {
+ t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), msg)
+ }
+
+ // google.protobuf.Any
+ anyVal, err := objVal.ConvertToNative(anyValueType)
+ if err != nil {
+ t.Fatalf("objVal.ConvertToNative() failed: %v", err)
+ }
+ anyMsg := anyVal.(*anypb.Any)
+ unpackedAny, err := anyMsg.UnmarshalNew()
+ if err != nil {
+ t.Fatalf("UnmarshalNew() failed: %v", err)
+ }
+ if !proto.Equal(unpackedAny, objVal.Value().(proto.Message)) {
+ t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), unpackedAny)
+ }
+
+ // JSON
+ jsonVal, err := objVal.ConvertToNative(JSONValueType)
+ if err != nil {
+ t.Fatalf("objVal.ConvertToNative(%v) failed: %v", JSONValueType, err)
+ }
+ jsonBytes, err := protojson.Marshal(jsonVal.(proto.Message))
+ jsonTxt := string(jsonBytes)
+ if err != nil {
+ t.Fatalf("protojson.Marshal(%v) failed: %v", jsonVal, err)
+ }
+ outMap := map[string]any{}
+ err = json.Unmarshal(jsonBytes, &outMap)
+ if err != nil {
+ t.Fatalf("json.Unmarshal(%q) failed: %v", jsonTxt, err)
+ }
+ want := tst.outValue
+ if !reflect.DeepEqual(outMap, want) {
+ t.Errorf("got json '%v', expected %v", outMap, want)
+ }
}
}
@@ -127,7 +186,7 @@ func TestProtoObjectIsSet(t *testing.T) {
LineOffsets: []int32{1, 2, 3},
},
}
- reg := newTestRegistry(t, msg)
+ reg := newTestRegistry(t, ProtoTypes(msg))
objVal := reg.NativeToValue(msg).(*protoObj)
if objVal.IsSet(String("source_info")) != True {
t.Error("got 'source_info' not set, wanted set")
@@ -144,7 +203,7 @@ func TestProtoObjectIsSet(t *testing.T) {
}
func TestProtoObjectIsZeroValue(t *testing.T) {
- reg := newTestRegistry(t, &exprpb.ParsedExpr{})
+ reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{}))
emptyObj := reg.NativeToValue(&exprpb.ParsedExpr{})
pb, ok := emptyObj.(traits.Zeroer)
if !ok {
@@ -166,7 +225,7 @@ func TestProtoObjectGet(t *testing.T) {
LineOffsets: []int32{1, 2, 3},
},
}
- reg := newTestRegistry(t, msg)
+ reg := newTestRegistry(t, ProtoTypes(msg))
objVal := reg.NativeToValue(msg).(*protoObj)
if objVal.Get(String("source_info")).Equal(reg.NativeToValue(msg.GetSourceInfo())) != True {
t.Error("could not get 'source_info'")
@@ -188,7 +247,7 @@ func TestProtoObjectConvertToType(t *testing.T) {
LineOffsets: []int32{1, 2, 3},
},
}
- reg := newTestRegistry(t, msg)
+ reg := newTestRegistry(t, ProtoTypes(msg))
objVal := reg.NativeToValue(msg)
tv := objVal.Type().(ref.Val)
if objVal.ConvertToType(TypeType).Equal(tv) != True {
diff --git a/common/types/pb/file.go b/common/types/pb/file.go
index e323afb1d..3a8bdf0b2 100644
--- a/common/types/pb/file.go
+++ b/common/types/pb/file.go
@@ -32,7 +32,7 @@ func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDe
}
types := make(map[string]*TypeDescription)
for name, msgType := range metadata.msgTypes {
- types[name] = newTypeDescription(name, msgType, pbdb.extensions)
+ types[name] = newTypeDescription(name, msgType, pbdb)
}
fileExtMap := make(extensionMap)
for typeName, extensions := range metadata.msgExtensionMap {
@@ -42,12 +42,13 @@ func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDe
}
for _, ext := range extensions {
extDesc := dynamicpb.NewExtensionType(ext).TypeDescriptor()
- messageExtMap[string(ext.FullName())] = newFieldDescription(extDesc)
+ messageExtMap[string(ext.FullName())] = newFieldDescription(extDesc, pbdb.jsonFieldNames)
}
fileExtMap[typeName] = messageExtMap
}
return &FileDescription{
name: fileDesc.Path(),
+ desc: fileDesc,
types: types,
enums: enums,
}, fileExtMap
@@ -56,6 +57,7 @@ func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDe
// FileDescription holds a map of all types and enum values declared within a proto file.
type FileDescription struct {
name string
+ desc protoreflect.FileDescriptor
types map[string]*TypeDescription
enums map[string]*EnumValueDescription
}
@@ -68,6 +70,7 @@ func (fd *FileDescription) Copy(pbdb *Db) *FileDescription {
}
return &FileDescription{
name: fd.name,
+ desc: fd.desc,
types: typesCopy,
enums: fd.enums,
}
@@ -78,6 +81,11 @@ func (fd *FileDescription) GetName() string {
return fd.name
}
+// FileDescriptor returns the proto file descriptor associated with the file representation.
+func (fd *FileDescription) FileDescriptor() protoreflect.FileDescriptor {
+ return fd.desc
+}
+
// GetEnumDescription returns an EnumDescription for a qualified enum value
// name declared within the .proto file.
func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumValueDescription, bool) {
diff --git a/common/types/pb/file_test.go b/common/types/pb/file_test.go
index 6a6a5a0f9..a27d097fb 100644
--- a/common/types/pb/file_test.go
+++ b/common/types/pb/file_test.go
@@ -76,6 +76,19 @@ func TestFileDescriptionGetExtensions(t *testing.T) {
}
}
+func TestFileDescriptionJSONFieldNames(t *testing.T) {
+ pbdb := NewDb(JSONFieldNames(true))
+ msg := &proto2pb.TestAllTypes{}
+ fd, err := pbdb.RegisterMessage(msg)
+ if err != nil {
+ t.Fatalf("pbdb.RegisterMessage() failed: %v", err)
+ }
+ fileDesc := msg.ProtoReflect().Descriptor().ParentFile()
+ if fd.FileDescriptor() != fileDesc {
+ t.Errorf("got %v, wanted %v file descriptor", fd.FileDescriptor(), fileDesc)
+ }
+}
+
func TestFileDescriptionGetTypes(t *testing.T) {
pbdb := NewDb()
fd, err := pbdb.RegisterMessage(&proto3pb.TestAllTypes{})
diff --git a/common/types/pb/pb.go b/common/types/pb/pb.go
index eadebcb04..c6fdfc695 100644
--- a/common/types/pb/pb.go
+++ b/common/types/pb/pb.go
@@ -42,6 +42,9 @@ type Db struct {
files []*FileDescription
// extensions contains the mapping between a given type name, extension name and its FieldDescription
extensions map[string]map[string]*FieldDescription
+
+ // jsonFieldNames indicates whether json-style names are supported as proto field names.
+ jsonFieldNames bool
}
// extensionsMap is a type alias to a map[typeName]map[extensionName]*FieldDescription
@@ -81,13 +84,27 @@ func Merge(dstPB, srcPB proto.Message) error {
return nil
}
+// DbOption modifies feature flags enabled on the proto database.
+type DbOption func(*Db) *Db
+
+// JSONFieldNames configures the Db to support proto field accesses by their JSON names.
+func JSONFieldNames(enabled bool) DbOption {
+ return func(db *Db) *Db {
+ db.jsonFieldNames = enabled
+ return db
+ }
+}
+
// NewDb creates a new `pb.Db` with an empty type name to file description map.
-func NewDb() *Db {
+func NewDb(opts ...DbOption) *Db {
pbdb := &Db{
revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{},
extensions: make(extensionMap),
}
+ for _, o := range opts {
+ pbdb = o(pbdb)
+ }
// The FileDescription objects in the default db contain lazily initialized TypeDescription
// values which may point to the state contained in the DefaultDb irrespective of this shallow
// copy; however, the type graph for a field is idempotently computed, and is guaranteed to
@@ -100,9 +117,15 @@ func NewDb() *Db {
return pbdb
}
+// JSONFieldNames indicates whether the database is configured for proto field accesses by JSON names.
+func (pbdb *Db) JSONFieldNames() bool {
+ return pbdb.jsonFieldNames
+}
+
// Copy creates a copy of the current database with its own internal descriptor mapping.
func (pbdb *Db) Copy() *Db {
copy := NewDb()
+ copy.jsonFieldNames = pbdb.jsonFieldNames
for _, fd := range pbdb.files {
hasFile := false
for _, fd2 := range copy.files {
diff --git a/common/types/pb/pb_test.go b/common/types/pb/pb_test.go
index 54d39477b..aef1c02b4 100644
--- a/common/types/pb/pb_test.go
+++ b/common/types/pb/pb_test.go
@@ -29,6 +29,69 @@ import (
tpb "google.golang.org/protobuf/types/known/timestamppb"
)
+func TestDbJSONFieldNames(t *testing.T) {
+ pbdb := NewDb(JSONFieldNames(true))
+ if pbdb.JSONFieldNames() != true {
+ t.Errorf("pbdb.JSONFieldNames() got %v, wanted true", pbdb.JSONFieldNames())
+ }
+ fd, err := pbdb.RegisterMessage(&proto2pb.TestAllTypes{})
+ if err != nil {
+ t.Fatalf("pbdb.RegisterMessage() failed: %v", err)
+ }
+ td, found := fd.GetTypeDescription("google.expr.proto2.test.TestAllTypes")
+ if !found {
+ t.Fatal("fd.GetTypeDescription() not found")
+ }
+ var fieldNames []string
+ for fieldName, f := range td.FieldMap() {
+ fieldNames = append(fieldNames, fieldName)
+ if f.jsonFieldName != true {
+ t.Error("f.jsonFieldName did not propagate")
+ }
+ }
+ // Note that 'group' type names don't have camelCase representations.
+ wantFields := []string{"singleInt32", "repeatedInt64", "nestedgroup"}
+ for _, want := range wantFields {
+ found := false
+ for _, field := range fieldNames {
+ if field == want {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("%v field name not found", want)
+ }
+ }
+ copied := pbdb.Copy()
+ if copied.JSONFieldNames() != true {
+ t.Errorf("copied.JSONFieldNames() got %v, wanted true", copied.JSONFieldNames())
+ }
+ td, found = copied.DescribeType("google.expr.proto2.test.TestAllTypes")
+ if !found {
+ t.Fatal("copied.DescribeType() not found")
+ }
+ fieldNames = []string{}
+ for fieldName, f := range td.FieldMap() {
+ fieldNames = append(fieldNames, fieldName)
+ if f.jsonFieldName != true {
+ t.Error("f.jsonFieldName did not propagate")
+ }
+ }
+ for _, want := range wantFields {
+ found := false
+ for _, field := range fieldNames {
+ if field == want {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("%v field name not found", want)
+ }
+ }
+}
+
func TestDbCopy(t *testing.T) {
clone := DefaultDb.Copy()
if !reflect.DeepEqual(clone, DefaultDb) {
diff --git a/common/types/pb/type.go b/common/types/pb/type.go
index 171494f07..18564642c 100644
--- a/common/types/pb/type.go
+++ b/common/types/pb/type.go
@@ -40,53 +40,68 @@ type description interface {
// newTypeDescription produces a TypeDescription value for the fully-qualified proto type name
// with a given descriptor.
-func newTypeDescription(typeName string, desc protoreflect.MessageDescriptor, extensions extensionMap) *TypeDescription {
+func newTypeDescription(typeName string, desc protoreflect.MessageDescriptor, pbdb *Db) *TypeDescription {
msgType := dynamicpb.NewMessageType(desc)
msgZero := dynamicpb.NewMessage(desc)
fieldMap := map[string]*FieldDescription{}
+ jsonFieldMap := map[string]*FieldDescription{}
fields := desc.Fields()
for i := 0; i < fields.Len(); i++ {
f := fields.Get(i)
- fieldMap[string(f.Name())] = newFieldDescription(f)
+ fd := newFieldDescription(f, pbdb.jsonFieldNames)
+ fieldMap[fd.Name()] = fd
+ if pbdb.jsonFieldNames {
+ jsonFieldMap[fd.JSONName()] = fd
+ }
}
return &TypeDescription{
- typeName: typeName,
- desc: desc,
- msgType: msgType,
- fieldMap: fieldMap,
- extensions: extensions,
- reflectType: reflectTypeOf(msgZero),
- zeroMsg: zeroValueOf(msgZero),
+ typeName: typeName,
+ desc: desc,
+ msgType: msgType,
+ fieldMap: fieldMap,
+ jsonFieldMap: jsonFieldMap,
+ extensions: pbdb.extensions,
+ reflectType: reflectTypeOf(msgZero),
+ zeroMsg: zeroValueOf(msgZero),
+ jsonFieldNames: pbdb.jsonFieldNames,
}
}
// TypeDescription is a collection of type metadata relevant to expression
// checking and evaluation.
type TypeDescription struct {
- typeName string
- desc protoreflect.MessageDescriptor
- msgType protoreflect.MessageType
- fieldMap map[string]*FieldDescription
- extensions extensionMap
- reflectType reflect.Type
- zeroMsg proto.Message
+ typeName string
+ desc protoreflect.MessageDescriptor
+ msgType protoreflect.MessageType
+ fieldMap map[string]*FieldDescription
+ jsonFieldMap map[string]*FieldDescription
+ extensions extensionMap
+ reflectType reflect.Type
+ zeroMsg proto.Message
+ // jsonFieldNames indicates if the type's fields are accessible via their JSON names.
+ jsonFieldNames bool
}
// Copy copies the type description with updated references to the Db.
func (td *TypeDescription) Copy(pbdb *Db) *TypeDescription {
return &TypeDescription{
- typeName: td.typeName,
- desc: td.desc,
- msgType: td.msgType,
- fieldMap: td.fieldMap,
- extensions: pbdb.extensions,
- reflectType: td.reflectType,
- zeroMsg: td.zeroMsg,
+ typeName: td.typeName,
+ desc: td.desc,
+ msgType: td.msgType,
+ fieldMap: td.fieldMap,
+ jsonFieldMap: td.jsonFieldMap,
+ extensions: pbdb.extensions,
+ reflectType: td.reflectType,
+ zeroMsg: td.zeroMsg,
+ jsonFieldNames: td.jsonFieldNames,
}
}
// FieldMap returns a string field name to FieldDescription map.
func (td *TypeDescription) FieldMap() map[string]*FieldDescription {
+ if td.jsonFieldNames {
+ return td.jsonFieldMap
+ }
return td.fieldMap
}
@@ -97,11 +112,15 @@ func (td *TypeDescription) FieldByName(name string) (*FieldDescription, bool) {
return fd, true
}
extFieldMap, found := td.extensions[td.typeName]
- if !found {
- return nil, false
+ if found {
+ fd, found = extFieldMap[name]
+ return fd, found
+ }
+ if td.jsonFieldNames {
+ fd, found = td.jsonFieldMap[name]
+ return fd, found
}
- fd, found = extFieldMap[name]
- return fd, found
+ return nil, false
}
// MaybeUnwrap accepts a proto message as input and unwraps it to a primitive CEL type if possible.
@@ -132,7 +151,7 @@ func (td *TypeDescription) Zero() proto.Message {
}
// newFieldDescription creates a new field description from a protoreflect.FieldDescriptor.
-func newFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescription {
+func newFieldDescription(fieldDesc protoreflect.FieldDescriptor, jsonFieldNames bool) *FieldDescription {
var reflectType reflect.Type
var zeroMsg proto.Message
switch fieldDesc.Kind() {
@@ -168,15 +187,16 @@ func newFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescripti
}
var keyType, valType *FieldDescription
if fieldDesc.IsMap() {
- keyType = newFieldDescription(fieldDesc.MapKey())
- valType = newFieldDescription(fieldDesc.MapValue())
+ keyType = newFieldDescription(fieldDesc.MapKey(), jsonFieldNames)
+ valType = newFieldDescription(fieldDesc.MapValue(), jsonFieldNames)
}
return &FieldDescription{
- desc: fieldDesc,
- KeyType: keyType,
- ValueType: valType,
- reflectType: reflectType,
- zeroMsg: zeroValueOf(zeroMsg),
+ desc: fieldDesc,
+ KeyType: keyType,
+ ValueType: valType,
+ reflectType: reflectType,
+ zeroMsg: zeroValueOf(zeroMsg),
+ jsonFieldName: jsonFieldNames,
}
}
@@ -187,9 +207,10 @@ type FieldDescription struct {
// ValueType holds the value FieldDescription for map fields.
ValueType *FieldDescription
- desc protoreflect.FieldDescriptor
- reflectType reflect.Type
- zeroMsg proto.Message
+ desc protoreflect.FieldDescriptor
+ reflectType reflect.Type
+ zeroMsg proto.Message
+ jsonFieldName bool
}
// CheckedType returns the type-definition used at type-check time.
@@ -321,11 +342,20 @@ func (fd *FieldDescription) MaybeUnwrapDynamic(msg protoreflect.Message) (any, b
return unwrapDynamic(fd, msg)
}
-// Name returns the CamelCase name of the field within the proto-based struct.
+// Name returns the snake_case name of the field within the proto-based struct.
func (fd *FieldDescription) Name() string {
return string(fd.desc.Name())
}
+// JSONName returns the JSON name of the field, if present.
+func (fd *FieldDescription) JSONName() string {
+ jsonName := fd.desc.JSONName()
+ if len(jsonName) != 0 {
+ return jsonName
+ }
+ return string(fd.desc.Name())
+}
+
// ProtoKind returns the protobuf reflected kind of the field.
func (fd *FieldDescription) ProtoKind() protoreflect.Kind {
return fd.desc.Kind()
diff --git a/common/types/pb/type_test.go b/common/types/pb/type_test.go
index bf2ad58d6..fed1a3d0d 100644
--- a/common/types/pb/type_test.go
+++ b/common/types/pb/type_test.go
@@ -54,6 +54,38 @@ func TestTypeDescription(t *testing.T) {
}
}
+func TestTypeDescriptionJSONFieldNames(t *testing.T) {
+ pbdb := NewDb(JSONFieldNames(true))
+ msg := &proto2pb.TestAllTypes{}
+ msgType := string(msg.ProtoReflect().Descriptor().FullName())
+ _, err := pbdb.RegisterMessage(msg)
+ if err != nil {
+ t.Fatalf("pbdb.RegisterMessage() failed: %v", err)
+ }
+ td, found := pbdb.DescribeType(msgType)
+ if !found {
+ t.Fatalf("pbdb.DescribeType(%s) not found", msgType)
+ }
+ fd, found := td.FieldByName("singleBoolWrapper")
+ if !found {
+ t.Fatal("td.FieldByName(singleBoolWrapper) failed")
+ }
+ if fd.JSONName() != "singleBoolWrapper" {
+ t.Fatalf("fd.JSONName() does not return the correct json name: %s", fd.JSONName())
+ }
+ if fd.Name() != "single_bool_wrapper" {
+ t.Fatalf("fd.Name() does not return correct proto name: %s", fd.Name())
+ }
+ enumName := "google.expr.proto2.test.TestAllTypes.NestedEnum.BAR"
+ en, found := pbdb.DescribeEnum(enumName)
+ if !found {
+ t.Fatalf("pbdb.DescribeEnum(%s) not found", enumName)
+ }
+ if en.Value() != 1 && en.Name() != enumName {
+ t.Errorf("got %v, wanted %s: %d", en, enumName, 1)
+ }
+}
+
func TestTypeDescriptionGroupFields(t *testing.T) {
pbdb := NewDb()
msg := &proto2pb.TestAllTypes{}
@@ -98,6 +130,26 @@ func TestTypeDescriptionFieldMap(t *testing.T) {
}
}
+func TestTypeDescriptionJSONFieldMap(t *testing.T) {
+ pbdb := NewDb(JSONFieldNames(true))
+ msg := &proto3pb.TestAllTypes{}
+ pbdb.RegisterMessage(msg)
+ td, found := pbdb.DescribeType("google.expr.proto3.test.TestAllTypes")
+ if !found {
+ t.Fatalf("pbdb.DescribeType(%v) not found", msg)
+ }
+ fd, found := td.FieldMap()["singleNestedMessage"]
+ if !found {
+ t.Fatal("singleNestedMessage not found")
+ }
+ if fd.Name() != "single_nested_message" {
+ t.Fatalf("fd.Name() got %s, wanted 'single_nested_message'", fd.Name())
+ }
+ if fd.JSONName() != "singleNestedMessage" {
+ t.Fatalf("fd.JSONName() got %s, wanted 'singleNestedMessage'", fd.JSONName())
+ }
+}
+
func TestFieldDescription(t *testing.T) {
pbdb := NewDb()
msg := proto3pb.NestedTestAllTypes{}
diff --git a/common/types/provider.go b/common/types/provider.go
index 936a4e28b..ebfd66dcb 100644
--- a/common/types/provider.go
+++ b/common/types/provider.go
@@ -81,6 +81,9 @@ type FieldType struct {
// GetFrom retrieves the field value on the input object, if set.
GetFrom ref.FieldGetter
+
+ // IsJSONField
+ IsJSONField bool
}
// Registry provides type information for a set of registered types.
@@ -93,11 +96,50 @@ type Registry struct {
// provider which can create new instances of the provided message or any
// message that proto depends upon in its FileDescriptor.
func NewRegistry(types ...proto.Message) (*Registry, error) {
- p := &Registry{
+ return NewProtoRegistry(ProtoTypes(types...))
+}
+
+// RegistryOption configures the behavior of the registry.
+type RegistryOption func(r *Registry) (*Registry, error)
+
+// JSONFieldNames configures JSON field name support within the protobuf types in the registry.
+func JSONFieldNames(enabled bool) RegistryOption {
+ return func(r *Registry) (*Registry, error) {
+ if enabled != r.pbdb.JSONFieldNames() {
+ newDB := pb.NewDb(pb.JSONFieldNames(enabled))
+ files := r.pbdb.FileDescriptions()
+ for _, fd := range files {
+ _, err := newDB.RegisterDescriptor(fd.FileDescriptor())
+ if err != nil {
+ return nil, err
+ }
+ }
+ r.pbdb = newDB
+ }
+ return r, nil
+ }
+}
+
+// ProtoTypes creates a RegistryOption which registers the individual proto messages with the registry.
+func ProtoTypes(types ...proto.Message) RegistryOption {
+ return func(r *Registry) (*Registry, error) {
+ for _, msgType := range types {
+ err := r.RegisterMessage(msgType)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return r, nil
+ }
+}
+
+// NewProtoRegistry creates a proto-based registry with a set of configurable options.
+func NewProtoRegistry(opts ...RegistryOption) (*Registry, error) {
+ r := &Registry{
revTypeMap: make(map[string]*Type),
pbdb: pb.NewDb(),
}
- err := p.RegisterType(
+ err := r.RegisterType(
BoolType,
BytesType,
DoubleType,
@@ -114,19 +156,19 @@ func NewRegistry(types ...proto.Message) (*Registry, error) {
return nil, err
}
// This block ensures that the well-known protobuf types are registered by default.
- for _, fd := range p.pbdb.FileDescriptions() {
- err = p.registerAllTypes(fd)
+ for _, fd := range r.pbdb.FileDescriptions() {
+ err = r.registerAllTypes(fd)
if err != nil {
return nil, err
}
}
- for _, msgType := range types {
- err = p.RegisterMessage(msgType)
+ for _, opt := range opts {
+ r, err = opt(r)
if err != nil {
return nil, err
}
}
- return p, nil
+ return r, nil
}
// NewEmptyRegistry returns a registry which is completely unconfigured.
@@ -172,9 +214,11 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType,
return nil, false
}
return &ref.FieldType{
- Type: field.CheckedType(),
- IsSet: field.IsSet,
- GetFrom: field.GetFrom}, true
+ Type: field.CheckedType(),
+ IsSet: field.IsSet,
+ GetFrom: field.GetFrom,
+ IsJSONField: p.pbdb.JSONFieldNames() && fieldName == field.JSONName(),
+ }, true
}
// FindStructFieldNames returns the set of field names for the given struct type,
@@ -206,9 +250,11 @@ func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType
return nil, false
}
return &FieldType{
- Type: fieldDescToCELType(field),
- IsSet: field.IsSet,
- GetFrom: field.GetFrom}, true
+ Type: fieldDescToCELType(field),
+ IsSet: field.IsSet,
+ GetFrom: field.GetFrom,
+ IsJSONField: p.pbdb.JSONFieldNames() && fieldName == field.JSONName(),
+ }, true
}
// FindIdent takes a qualified identifier name and returns a ref.Val if one exists.
@@ -268,9 +314,8 @@ func (p *Registry) NewValue(structType string, fields map[string]ref.Val) ref.Va
return NewErr("unknown type '%s'", structType)
}
msg := td.New()
- fieldMap := td.FieldMap()
for name, value := range fields {
- field, found := fieldMap[name]
+ field, found := td.FieldByName(name)
if !found {
return NewErr("no such field: %s", name)
}
diff --git a/common/types/provider_test.go b/common/types/provider_test.go
index 2576a5740..9189df449 100644
--- a/common/types/provider_test.go
+++ b/common/types/provider_test.go
@@ -134,10 +134,10 @@ func TestRegistryFindStructType(t *testing.T) {
}
func TestRegistryFindStructFieldNames(t *testing.T) {
- reg := newTestRegistry(t, &exprpb.Decl{}, &exprpb.Reference{})
tests := []struct {
- typeName string
- fields []string
+ typeName string
+ fields []string
+ jsonFieldNames bool
}{
{
typeName: "google.api.expr.v1alpha1.Reference",
@@ -151,11 +151,19 @@ func TestRegistryFindStructFieldNames(t *testing.T) {
typeName: "invalid.TypeName",
fields: []string{},
},
+ {
+ typeName: "google.api.expr.v1alpha1.Reference",
+ fields: []string{"name", "overloadId", "value"},
+ jsonFieldNames: true,
+ },
}
for _, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) {
+ reg := newTestRegistry(t,
+ ProtoTypes(&exprpb.Decl{}, &exprpb.Reference{}),
+ JSONFieldNames(tc.jsonFieldNames))
fields, _ := reg.FindStructFieldNames(tc.typeName)
sort.Strings(fields)
sort.Strings(tc.fields)
@@ -167,16 +175,12 @@ func TestRegistryFindStructFieldNames(t *testing.T) {
}
func TestRegistryFindStructFieldType(t *testing.T) {
- reg := newTestRegistry(t)
- err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile())
- if err != nil {
- t.Fatalf("RegisterDescriptor() failed: %v", err)
- }
msgTypeName := ".google.expr.proto3.test.TestAllTypes"
tests := []struct {
- typeName string
- field string
- found bool
+ typeName string
+ field string
+ found bool
+ jsonFieldNames bool
}{
{
typeName: msgTypeName,
@@ -238,11 +242,28 @@ func TestRegistryFindStructFieldType(t *testing.T) {
field: "map_string_string",
found: false,
},
+ {
+ typeName: msgTypeName,
+ field: "mapStringString",
+ found: true,
+ jsonFieldNames: true,
+ },
+ {
+ typeName: msgTypeName,
+ field: "map_string_string",
+ found: true,
+ jsonFieldNames: true,
+ },
}
for _, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%s.%s", tc.typeName, tc.field), func(t *testing.T) {
+ reg := newTestRegistry(t, JSONFieldNames(tc.jsonFieldNames))
+ err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile())
+ if err != nil {
+ t.Fatalf("RegisterDescriptor() failed: %v", err)
+ }
// When the field is expected to be found, test parity of the results
if tc.found {
refField, found := reg.FindFieldType(tc.typeName, tc.field)
@@ -278,7 +299,7 @@ func TestRegistryFindStructFieldType(t *testing.T) {
}
func TestRegistryNewValue(t *testing.T) {
- reg := newTestRegistry(t, &proto3pb.TestAllTypes{}, &exprpb.SourceInfo{})
+ reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{}, &exprpb.SourceInfo{}))
tests := []struct {
typeName string
fields map[string]ref.Val
@@ -406,7 +427,7 @@ func TestRegistryNewValue(t *testing.T) {
}
func TestRegistryNewValueErrors(t *testing.T) {
- reg := newTestRegistry(t, &proto3pb.TestAllTypes{}, &exprpb.SourceInfo{})
+ reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{}, &exprpb.SourceInfo{}))
tests := []struct {
typeName string
fields map[string]ref.Val
@@ -483,7 +504,7 @@ func TestRegistryNewValueErrors(t *testing.T) {
}
func TestRegistryGetters(t *testing.T) {
- reg := newTestRegistry(t, &exprpb.ParsedExpr{})
+ reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{}))
if sourceInfo := reg.NewValue(
"google.api.expr.v1alpha1.SourceInfo",
map[string]ref.Val{
@@ -519,7 +540,7 @@ func TestRegistryGetters(t *testing.T) {
}
func TestConvertToNative(t *testing.T) {
- reg := newTestRegistry(t, &exprpb.ParsedExpr{})
+ reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{}))
// Core type conversion tests.
expectValueToNative(t, True, true)
@@ -584,7 +605,7 @@ func TestConvertToNative(t *testing.T) {
}
func TestNativeToValue_Any(t *testing.T) {
- reg := newTestRegistry(t, &exprpb.ParsedExpr{})
+ reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{}))
// NullValue
anyValue, err := NullValue.ConvertToNative(anyValueType)
if err != nil {
@@ -645,7 +666,7 @@ func TestNativeToValue_Any(t *testing.T) {
}
func TestNativeToValue_Json(t *testing.T) {
- reg := newTestRegistry(t, &exprpb.ParsedExpr{})
+ reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{}))
// Json primitive conversion test.
expectNativeToValue(t, structpb.NewBoolValue(false), False)
expectNativeToValue(t, structpb.NewNumberValue(1.1), Double(1.1))
@@ -835,7 +856,7 @@ func expectValueToNative(t *testing.T, in ref.Val, out any) {
func expectNativeToValue(t *testing.T, in any, out ref.Val) {
t.Helper()
- reg := newTestRegistry(t, &exprpb.ParsedExpr{})
+ reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{}))
if val := reg.NativeToValue(in); IsError(val) {
t.Error(val)
} else {
@@ -918,11 +939,11 @@ type testFloat32 float32
type testFloat64 float64
type testString string
-func newTestRegistry(t *testing.T, types ...proto.Message) *Registry {
+func newTestRegistry(t *testing.T, opts ...RegistryOption) *Registry {
t.Helper()
- reg, err := NewRegistry(types...)
+ reg, err := NewProtoRegistry(opts...)
if err != nil {
- t.Fatalf("NewRegistry(%v) failed: %v", types, err)
+ t.Fatalf("NewProtoRegistry() failed: %v", err)
}
return reg
}
diff --git a/common/types/ref/provider.go b/common/types/ref/provider.go
index b9820023d..ed5ab0662 100644
--- a/common/types/ref/provider.go
+++ b/common/types/ref/provider.go
@@ -93,6 +93,9 @@ type FieldType struct {
// GetFrom retrieves the field value on the input object, if set.
GetFrom FieldGetter
+
+ // IsJSONFIeld indicates that the field was accessed via its JSON name.
+ IsJSONField bool
}
// FieldTester is used to test field presence on an input object.