diff --git a/cel/env.go b/cel/env.go index a7aa6db34..d3295b6a5 100644 --- a/cel/env.go +++ b/cel/env.go @@ -1052,9 +1052,10 @@ func (p *interopCELTypeProvider) FindStructFieldType(structType, fieldName strin return nil, false } return &types.FieldType{ - Type: t, - IsSet: ft.IsSet, - GetFrom: ft.GetFrom, + Type: t, + IsSet: ft.IsSet, + GetFrom: ft.GetFrom, + IsJSONField: ft.IsJSONField, }, true } return nil, false diff --git a/checker/checker.go b/checker/checker.go index d07d8e799..42d27a428 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -71,6 +71,11 @@ func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.E // check() deletes some nodes while rewriting the AST. For example the Select operand is // deleted when a variable reference is replaced with a Ident expression. c.AST.ClearUnusedIDs() + if env.jsonFieldNames { + c.AST.SourceInfo().AddExtension( + ast.NewExtension("json_name", ast.NewExtensionVersion(1, 1), ast.ComponentRuntime), + ) + } return c.AST, errs } @@ -718,6 +723,9 @@ func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (* } if ft, found := c.env.provider.FindStructFieldType(structType, fieldName); found { + if c.env.jsonFieldNames && !ft.IsJSONField { + c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName) + } return ft.Type, found } diff --git a/checker/checker_test.go b/checker/checker_test.go index 387c8bd75..623e1f709 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -2413,6 +2413,42 @@ _&&_(_==_(list~type(list(dyn))^list, @result~bool^@result)~bool`, outType: types.BoolType, }, + { + in: `TestAllTypes{?singleInt32: {}.?i}`, + container: "google.expr.proto2.test", + env: testEnv{optionalSyntax: true, jsonFieldNames: true}, + out: `google.expr.proto2.test.TestAllTypes{ + ?singleInt32:_?._( + {}~map(dyn, int), + "i" + )~optional_type(int)^select_optional_field + }~google.expr.proto2.test.TestAllTypes^google.expr.proto2.test.TestAllTypes`, + outType: types.NewObjectType( + "google.expr.proto2.test.TestAllTypes", + ), + }, + { + in: `TestAllTypes{?singleInt32: {'i': 20}.?i}.singleInt32`, + container: "google.expr.proto2.test", + env: testEnv{optionalSyntax: true, jsonFieldNames: true}, + out: `google.expr.proto2.test.TestAllTypes{ + ?singleInt32:_?._( + { + "i"~string:20~int + }~map(string, int), + "i" + )~optional_type(int)^select_optional_field + }~google.expr.proto2.test.TestAllTypes^google.expr.proto2.test.TestAllTypes.singleInt32~int`, + outType: types.IntType, + }, + { + in: `TestAllTypes{singleInt32: 1, single_bool: true}.singleInt32`, + container: "google.expr.proto2.test", + env: testEnv{optionalSyntax: true, jsonFieldNames: true}, + err: `ERROR: :1:41: undefined field 'single_bool' + | TestAllTypes{singleInt32: 1, single_bool: true}.singleInt32 + | ........................................^`, + }, } } @@ -2470,6 +2506,7 @@ type testEnv struct { functions []*decls.FunctionDecl variadicASTs bool optionalSyntax bool + jsonFieldNames bool } func TestCheck(t *testing.T) { @@ -2505,9 +2542,12 @@ func TestCheck(t *testing.T) { t.Fatalf("Unexpected parse errors: %v", errors.ToDisplayString()) } - reg, err := types.NewRegistry(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}) + reg, err := types.NewProtoRegistry( + types.JSONFieldNames(tc.env.jsonFieldNames), + types.ProtoTypes(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}), + ) if err != nil { - t.Fatalf("types.NewRegistry() failed: %v", err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } if tc.env.optionalSyntax { if err := reg.RegisterType(types.OptionalType); err != nil { @@ -2522,6 +2562,9 @@ func TestCheck(t *testing.T) { if len(tc.opts) != 0 { opts = tc.opts } + if tc.env.jsonFieldNames { + opts = append(opts, JSONFieldNames(true)) + } env, err := NewEnv(cont, reg, opts...) if err != nil { t.Fatalf("NewEnv(cont, reg) failed: %v", err) diff --git a/checker/env.go b/checker/env.go index 16c8ae60f..477918c48 100644 --- a/checker/env.go +++ b/checker/env.go @@ -74,6 +74,7 @@ type Env struct { declarations *Scopes aggLitElemType aggregateLiteralElementType filteredOverloadIDs map[string]struct{} + jsonFieldNames bool } // NewEnv returns a new *Env with the given parameters. @@ -104,6 +105,7 @@ func NewEnv(container *containers.Container, provider types.Provider, opts ...Op declarations: declarations, aggLitElemType: aggLitElemType, filteredOverloadIDs: filteredOverloadIDs, + jsonFieldNames: envOptions.jsonFieldNames, }, nil } diff --git a/checker/options.go b/checker/options.go index 0560c3813..af714323b 100644 --- a/checker/options.go +++ b/checker/options.go @@ -18,6 +18,7 @@ type options struct { crossTypeNumericComparisons bool homogeneousAggregateLiterals bool validatedDeclarations *Scopes + jsonFieldNames bool } // Option is a functional option for configuring the type-checker @@ -40,3 +41,11 @@ func ValidatedDeclarations(env *Env) Option { return nil } } + +// JSONFieldNames enables the use of json names instead of the standard protobuf snake_case field names +func JSONFieldNames(enabled bool) Option { + return func(opts *options) error { + opts.jsonFieldNames = enabled + return nil + } +} diff --git a/common/ast/ast.go b/common/ast/ast.go index 3c5ee0c80..aae2a83e9 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -231,6 +231,11 @@ func CopySourceInfo(info *SourceInfo) *SourceInfo { for id, call := range info.macroCalls { callsCopy[id] = defaultFactory.CopyExpr(call) } + var extCopy []Extension + if len(info.extensions) > 0 { + extCopy = make([]Extension, len(info.extensions)) + copy(extCopy, info.extensions) + } return &SourceInfo{ syntax: info.syntax, desc: info.desc, @@ -239,6 +244,7 @@ func CopySourceInfo(info *SourceInfo) *SourceInfo { baseCol: info.baseCol, offsetRanges: rangesCopy, macroCalls: callsCopy, + extensions: extCopy, } } @@ -252,6 +258,9 @@ type SourceInfo struct { baseCol int32 offsetRanges map[int64]OffsetRange macroCalls map[int64]Expr + + // extensions indicate versioned optional features which affect the execution of one or more CEL component. + extensions []Extension } // RenumberIDs performs an in-place update of the expression IDs within the SourceInfo. @@ -420,6 +429,23 @@ func (s *SourceInfo) ComputeOffsetAbsolute(line, col int32) int32 { return offset + col } +// Extensions returns the set of extensions present in the source. +func (s *SourceInfo) Extensions() []Extension { + var extensions []Extension + if s == nil { + return extensions + } + return s.extensions +} + +// AddExtension adds an extension record into the SourceInfo. +func (s *SourceInfo) AddExtension(ext Extension) { + if s == nil { + return + } + s.extensions = append(s.extensions, ext) +} + // OffsetRange captures the start and stop positions of a section of text in the input expression. type OffsetRange struct { Start int32 @@ -489,6 +515,53 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool { return true } +// NewExtension creates an Extension to be recorded on the SourceInfo. +func NewExtension(id string, version ExtensionVersion, components ...ExtensionComponent) Extension { + return Extension{ + ID: id, + Version: version, + Components: components, + } +} + +// Extension represents a versioned, optional feature present in the AST that affects CEL component behavior. +type Extension struct { + // ID indicates the unique name of the extension. + ID string + // Version indicates the major / minor version. + Version ExtensionVersion + // Components enumerates the CEL components affected by the feature. + Components []ExtensionComponent +} + +// NewExtensionVersion creates a new extension version with a major, minor version. +func NewExtensionVersion(major, minor int64) ExtensionVersion { + return ExtensionVersion{Major: major, Minor: minor} +} + +// ExtensionVersion represents a semantic version with a major and minor number. +type ExtensionVersion struct { + // Major version of the extension. + // All versions with the same major number are expected to be compatible with all minor version changes. + Major int64 + + // Minor version of the extension which indicates that some small non-semantic change has been made to + // the extension. + Minor int64 +} + +// ExtensionComponent indicates which CEL component is affected. +type ExtensionComponent int + +const ( + // ComponentParser means the feature affects expression parsing. + ComponentParser ExtensionComponent = iota + 1 + // ComponentTypeChecker means the feature affects type-checking. + ComponentTypeChecker + // ComponentRuntime alters program planning or evaluation of the AST. + ComponentRuntime +) + type maxIDVisitor struct { maxID int64 *baseVisitor diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go index 0a86c3936..17c4c82c2 100644 --- a/common/ast/ast_test.go +++ b/common/ast/ast_test.go @@ -20,6 +20,7 @@ import ( "reflect" "testing" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/overloads" @@ -83,6 +84,55 @@ func TestASTCopy(t *testing.T) { } } +func TestASTJsonNames(t *testing.T) { + tests := []string{ + `google.expr.proto3.test.TestAllTypes{}`, + `google.expr.proto3.test.TestAllTypes{repeatedInt32: [1, 2]}`, + `google.expr.proto3.test.TestAllTypes{singleInt32: 2}.singleInt32 == 2`, + } + + for _, tst := range tests { + checked := mustTypeCheck(t, tst, checker.JSONFieldNames(true), types.JSONFieldNames(true)) + copyChecked := ast.Copy(checked) + if !reflect.DeepEqual(copyChecked.Expr(), checked.Expr()) { + t.Errorf("Copy() got expr %v, wanted %v", copyChecked.Expr(), checked.Expr()) + } + if !reflect.DeepEqual(copyChecked.SourceInfo(), checked.SourceInfo()) { + t.Errorf("Copy() got source info %v, wanted %v", copyChecked.SourceInfo(), checked.SourceInfo()) + } + copyParsed := ast.Copy(ast.NewAST(checked.Expr(), checked.SourceInfo())) + if !reflect.DeepEqual(copyParsed.Expr(), checked.Expr()) { + t.Errorf("Copy() got expr %v, wanted %v", copyParsed.Expr(), checked.Expr()) + } + if !reflect.DeepEqual(copyParsed.SourceInfo(), checked.SourceInfo()) { + t.Errorf("Copy() got source info %v, wanted %v", copyParsed.SourceInfo(), checked.SourceInfo()) + } + checkedPB, err := ast.ToProto(checked) + if err != nil { + t.Errorf("ast.ToProto() failed: %v", err) + } + copyCheckedPB, err := ast.ToProto(copyChecked) + if err != nil { + t.Errorf("ast.ToProto() failed: %v", err) + } + if !proto.Equal(checkedPB, copyCheckedPB) { + t.Errorf("Copy() produced different proto results, got %v, wanted %v", + prototext.Format(checkedPB), prototext.Format(copyCheckedPB)) + } + checkedRoundtrip, err := ast.ToAST(checkedPB) + if err != nil { + t.Errorf("ast.ToAST() failed: %v", err) + } + same := reflect.DeepEqual(checked.Expr(), checkedRoundtrip.Expr()) && + reflect.DeepEqual(checked.ReferenceMap(), checkedRoundtrip.ReferenceMap()) && + reflect.DeepEqual(checked.TypeMap(), checkedRoundtrip.TypeMap()) && + reflect.DeepEqual(checked.SourceInfo().MacroCalls(), checkedRoundtrip.SourceInfo().MacroCalls()) + if !same { + t.Errorf("Roundtrip got %v, wanted %v", checkedRoundtrip, checked) + } + } +} + func TestASTNilSafety(t *testing.T) { ex, err := ast.ProtoToExpr(nil) if err != nil { @@ -184,6 +234,9 @@ func TestSourceInfoNilSafety(t *testing.T) { if len(testInfo.MacroCalls()) != 0 { t.Errorf("MacroCalls() got %v, wanted empty map", testInfo.MacroCalls()) } + if len(testInfo.Extensions()) != 0 { + t.Errorf("Extensions() got %v, wanted empty list", testInfo.Extensions()) + } if call, found := testInfo.GetMacroCall(0); found { t.Errorf("GetMacroCall(0) got %v, wanted not found", call) } diff --git a/common/ast/conversion.go b/common/ast/conversion.go index 435d8f654..380f8c118 100644 --- a/common/ast/conversion.go +++ b/common/ast/conversion.go @@ -27,6 +27,19 @@ import ( structpb "google.golang.org/protobuf/types/known/structpb" ) +var ( + pbComponentMap = map[exprpb.SourceInfo_Extension_Component]ExtensionComponent{ + exprpb.SourceInfo_Extension_COMPONENT_PARSER: ComponentParser, + exprpb.SourceInfo_Extension_COMPONENT_TYPE_CHECKER: ComponentTypeChecker, + exprpb.SourceInfo_Extension_COMPONENT_RUNTIME: ComponentRuntime, + } + componentPBMap = map[ExtensionComponent]exprpb.SourceInfo_Extension_Component{ + ComponentParser: exprpb.SourceInfo_Extension_COMPONENT_PARSER, + ComponentTypeChecker: exprpb.SourceInfo_Extension_COMPONENT_TYPE_CHECKER, + ComponentRuntime: exprpb.SourceInfo_Extension_COMPONENT_RUNTIME, + } +) + // ToProto converts an AST to a CheckedExpr protobouf. func ToProto(ast *AST) (*exprpb.CheckedExpr, error) { refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap())) @@ -534,6 +547,25 @@ func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) { } sourceInfo.MacroCalls[id] = call } + for _, ext := range info.Extensions() { + var components []exprpb.SourceInfo_Extension_Component + for _, c := range ext.Components { + comp, found := componentPBMap[c] + if found { + components = append(components, comp) + } + } + ver := &exprpb.SourceInfo_Extension_Version{ + Major: ext.Version.Major, + Minor: ext.Version.Minor, + } + pbExt := &exprpb.SourceInfo_Extension{ + Id: ext.ID, + Version: ver, + AffectedComponents: components, + } + sourceInfo.Extensions = append(sourceInfo.Extensions, pbExt) + } return sourceInfo, nil } @@ -556,6 +588,23 @@ func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) { } sourceInfo.SetMacroCall(id, call) } + for _, pbExt := range info.GetExtensions() { + var components []ExtensionComponent + for _, c := range pbExt.GetAffectedComponents() { + comp, found := pbComponentMap[*c.Enum()] + if found { + components = append(components, comp) + } + } + sourceInfo.AddExtension(NewExtension( + pbExt.GetId(), + NewExtensionVersion( + pbExt.GetVersion().GetMajor(), + pbExt.GetVersion().GetMinor(), + ), + components..., + )) + } return sourceInfo, nil } diff --git a/common/ast/conversion_test.go b/common/ast/conversion_test.go index 86b00e3e3..ecd50830b 100644 --- a/common/ast/conversion_test.go +++ b/common/ast/conversion_test.go @@ -38,10 +38,12 @@ import ( func TestConvertAST(t *testing.T) { fac := ast.NewExprFactory() tests := []struct { + name string goAST *ast.AST pbAST *exprpb.CheckedExpr }{ { + name: "simple ast", goAST: ast.NewCheckedAST(ast.NewAST(nil, nil), map[int64]*types.Type{ 1: types.BoolType, @@ -71,6 +73,7 @@ func TestConvertAST(t *testing.T) { }, }, { + name: "comprehension ast", goAST: ast.NewAST( fac.NewComprehensionTwoVar(1, fac.NewIdent(2, "data"), @@ -179,11 +182,74 @@ func TestConvertAST(t *testing.T) { ReferenceMap: map[int64]*exprpb.Reference{}, }, }, + { + name: "json names ast", + goAST: ast.NewCheckedAST( + ast.NewAST( + fac.NewCall(2, overloads.LogicalNot, fac.NewIdent(1, "value")), + sourceWithExtension( + ast.NewSourceInfo(common.NewTextSource("!value")), + ast.NewExtension("json_name", ast.NewExtensionVersion(1, 1), ast.ComponentRuntime), + ), + ), + map[int64]*types.Type{ + 1: types.BoolType, + 2: types.BoolType, + }, + map[int64]*ast.ReferenceInfo{ + 1: ast.NewFunctionReference(overloads.LogicalNot), + 2: ast.NewIdentReference("value", nil), + }, + ), + pbAST: &exprpb.CheckedExpr{ + Expr: &exprpb.Expr{ + Id: 2, + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: "!_", + Args: []*exprpb.Expr{ + { + Id: 1, + ExprKind: &exprpb.Expr_IdentExpr{ + IdentExpr: &exprpb.Expr_Ident{ + Name: "value", + }, + }, + }, + }, + }, + }, + }, + SourceInfo: &exprpb.SourceInfo{ + Location: "", + Extensions: []*exprpb.SourceInfo_Extension{ + { + Id: "json_name", + Version: &exprpb.SourceInfo_Extension_Version{ + Major: 1, + Minor: 1, + }, + AffectedComponents: []exprpb.SourceInfo_Extension_Component{ + exprpb.SourceInfo_Extension_COMPONENT_RUNTIME, + }, + }, + }, + }, + TypeMap: map[int64]*exprpb.Type{ + 1: chkdecls.Bool, + 2: chkdecls.Bool, + }, + ReferenceMap: map[int64]*exprpb.Reference{ + 1: {OverloadId: []string{overloads.LogicalNot}}, + 2: {Name: "value"}, + }, + }, + }, } - for i, tst := range tests { + for _, tst := range tests { tc := tst - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { goAST := tc.goAST pbAST := tc.pbAST checkedAST, err := ast.ToAST(pbAST) @@ -194,6 +260,10 @@ func TestConvertAST(t *testing.T) { !reflect.DeepEqual(checkedAST.TypeMap(), goAST.TypeMap()) { t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, goAST) } + if !reflect.DeepEqual(checkedAST.SourceInfo().Extensions(), goAST.SourceInfo().Extensions()) { + t.Errorf("conversion to AST did not preserve SourceInfo extensions. got %v, wanted %v", + checkedAST.SourceInfo().Extensions(), goAST.SourceInfo().Extensions()) + } if len(checkedAST.ReferenceMap()) > 2 { if !checkedAST.ReferenceMap()[1].Equals(goAST.ReferenceMap()[1]) || !checkedAST.ReferenceMap()[2].Equals(goAST.ReferenceMap()[2]) { @@ -656,3 +726,8 @@ func TestConstantToValError(t *testing.T) { t.Errorf("ConstantToVal() got %v, wanted error", out) } } + +func sourceWithExtension(info *ast.SourceInfo, ext ast.Extension) *ast.SourceInfo { + info.AddExtension(ext) + return info +} diff --git a/common/ast/navigable_test.go b/common/ast/navigable_test.go index 5afb23ab7..bd8fa8ac1 100644 --- a/common/ast/navigable_test.go +++ b/common/ast/navigable_test.go @@ -18,8 +18,6 @@ import ( "reflect" "testing" - "google.golang.org/protobuf/proto" - "github.com/google/cel-go/checker" "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" @@ -556,7 +554,7 @@ func TestNavigableSelectExpr_TestOnly(t *testing.T) { } } -func mustTypeCheck(t testing.TB, expr string) *ast.AST { +func mustTypeCheck(t testing.TB, expr string, opts ...any) *ast.AST { t.Helper() p, err := parser.NewParser( parser.Macros(parser.AllMacros...), @@ -570,8 +568,21 @@ func mustTypeCheck(t testing.TB, expr string) *ast.AST { if len(iss.GetErrors()) != 0 { t.Fatalf("Parse(%s) failed: %s", expr, iss.ToDisplayString()) } - reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) - env := newTestEnv(t, containers.DefaultContainer, reg) + regOpts := []types.RegistryOption{} + chkOpts := []checker.Option{} + for _, opt := range opts { + switch v := opt.(type) { + case types.RegistryOption: + regOpts = append(regOpts, v) + case checker.Option: + chkOpts = append(chkOpts, v) + default: + t.Fatalf("mustTypeCheck() failed with invalid option type: %T", v) + } + } + regOpts = append(regOpts, types.ProtoTypes(&proto3pb.TestAllTypes{})) + reg := newTestRegistry(t, regOpts...) + env := newTestEnv(t, containers.DefaultContainer, reg, chkOpts...) checked, iss := checker.Check(parsed, exprSrc, env) if len(iss.GetErrors()) != 0 { t.Fatalf("Check(%s) failed: %s", expr, iss.ToDisplayString()) @@ -579,18 +590,20 @@ func mustTypeCheck(t testing.TB, expr string) *ast.AST { return checked } -func newTestRegistry(t testing.TB, msgs ...proto.Message) *types.Registry { +func newTestRegistry(t testing.TB, opts ...types.RegistryOption) *types.Registry { t.Helper() - reg, err := types.NewRegistry(msgs...) + reg, err := types.NewProtoRegistry(opts...) if err != nil { - t.Fatalf("types.NewRegistry(%v) failed: %v", msgs, err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } return reg } -func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry) *checker.Env { +func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry, opts ...checker.Option) *checker.Env { t.Helper() - env, err := checker.NewEnv(cont, reg, checker.CrossTypeNumericComparisons(true)) + chkOpts := []checker.Option{checker.CrossTypeNumericComparisons(true)} + chkOpts = append(chkOpts, opts...) + env, err := checker.NewEnv(cont, reg, chkOpts...) if err != nil { t.Fatalf("checker.NewEnv(%v, %v) failed: %v", cont, reg, err) } diff --git a/common/types/map_test.go b/common/types/map_test.go index a45a56c85..5b9898c27 100644 --- a/common/types/map_test.go +++ b/common/types/map_test.go @@ -44,7 +44,7 @@ type testStruct struct { } func TestMapContains(t *testing.T) { - reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) + reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{})) reflectMap := reg.NativeToValue(map[any]any{ int64(1): "hello", uint64(2): "world", @@ -582,7 +582,7 @@ func TestMapIsZeroValue(t *testing.T) { "hello": "world", }, } - reg := newTestRegistry(t, msg) + reg := newTestRegistry(t, ProtoTypes(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) tests := []struct { @@ -749,7 +749,7 @@ func TestProtoMap(t *testing.T) { "welcome": "back", } msg := &proto3pb.TestAllTypes{MapStringString: strMap} - reg := newTestRegistry(t, msg) + reg := newTestRegistry(t, ProtoTypes(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) // Test a simple proto map of string string. @@ -850,7 +850,7 @@ func TestProtoMapGet(t *testing.T) { "welcome": "back", } msg := &proto3pb.TestAllTypes{MapStringString: strMap} - reg := newTestRegistry(t, msg) + reg := newTestRegistry(t, ProtoTypes(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) field := obj.Get(String("map_string_string")) mapVal, ok := field.(traits.Mapper) @@ -890,7 +890,7 @@ func TestProtoMapConvertToNative(t *testing.T) { "welcome": "back", } msg := &proto3pb.TestAllTypes{MapStringString: strMap} - reg := newTestRegistry(t, msg) + reg := newTestRegistry(t, ProtoTypes(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) // Test a simple proto map of string string. field := obj.Get(String("map_string_string")) @@ -974,7 +974,7 @@ func TestProtoMapConvertToNative_NestedProto(t *testing.T) { }, } msg := &proto3pb.TestAllTypes{MapInt64NestedType: nestedTypeMap} - reg := newTestRegistry(t, msg) + reg := newTestRegistry(t, ProtoTypes(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) // Test a simple proto map of string string. field := obj.Get(String("map_int64_nested_type")) diff --git a/common/types/object_test.go b/common/types/object_test.go index 10d05333e..1f89cd165 100644 --- a/common/types/object_test.go +++ b/common/types/object_test.go @@ -49,75 +49,134 @@ func TestNewProtoObject(t *testing.T) { } func TestProtoObjectConvertToNative(t *testing.T) { - reg := newTestRegistry(t, &exprpb.Expr{}) msg := &exprpb.ParsedExpr{ + Expr: &exprpb.Expr{ + Id: 1, + ExprKind: &exprpb.Expr_ConstExpr{ + ConstExpr: &exprpb.Constant{ + ConstantKind: &exprpb.Constant_BoolValue{ + BoolValue: true, + }, + }, + }, + }, SourceInfo: &exprpb.SourceInfo{ LineOffsets: []int32{1, 2, 3}}} - objVal := reg.NativeToValue(msg) - // Proto Message - val, err := objVal.ConvertToNative(reflect.TypeOf(&exprpb.ParsedExpr{})) - if err != nil { - t.Error(err) - } - if !proto.Equal(val.(proto.Message), msg) { - t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), msg) + tests := []struct { + opts []RegistryOption + fieldMap func(reg *Registry) map[string]ref.Val + outValue map[string]any + }{ + { + opts: []RegistryOption{ProtoTypes(&exprpb.Expr{}), JSONFieldNames(true)}, + fieldMap: func(reg *Registry) map[string]ref.Val { + return map[string]ref.Val{ + "expr": reg.NativeToValue(msg.GetExpr()), + "sourceInfo": reg.NativeToValue(msg.GetSourceInfo()), + } + }, + outValue: map[string]any{ + "expr": map[string]any{ + // The id is encoded to a string because it is int64 type, though + // this is not exactly what's documented in protojson serialization + // and so could signal a bug in the json-handling for int values + // as the serialization to int or string is based on the value precision. + "id": "1", + "constExpr": map[string]any{ + "boolValue": true, + }, + }, + "sourceInfo": map[string]any{ + "lineOffsets": []any{1.0, 2.0, 3.0}, + }, + }, + }, + { + opts: []RegistryOption{ProtoTypes(&exprpb.Expr{}), JSONFieldNames(false)}, + fieldMap: func(reg *Registry) map[string]ref.Val { + return map[string]ref.Val{ + "expr": reg.NativeToValue(msg.GetExpr()), + "source_info": reg.NativeToValue(msg.GetSourceInfo()), + } + }, + outValue: map[string]any{ + "expr": map[string]any{ + "id": "1", + "constExpr": map[string]any{ + "boolValue": true, + }, + }, + "sourceInfo": map[string]any{ + "lineOffsets": []any{1.0, 2.0, 3.0}, + }, + }, + }, } - // Dynamic protobuf - dynPB := reg.NewValue( - string(msg.ProtoReflect().Descriptor().FullName()), - map[string]ref.Val{ - "source_info": reg.NativeToValue(msg.GetSourceInfo()), - }) - if IsError(dynPB) { - t.Fatalf("reg.NewValue() failed: %v", dynPB) - } - dynVal := reg.NativeToValue(dynPB) - val, err = dynVal.ConvertToNative(reflect.TypeOf(msg)) - if err != nil { - t.Fatalf("dynVal.ConvertToNative() failed: %v", err) - } - if !proto.Equal(val.(proto.Message), msg) { - t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), msg) - } + for _, tst := range tests { + reg := newTestRegistry(t, tst.opts...) + objVal := reg.NativeToValue(msg) - // google.protobuf.Any - anyVal, err := objVal.ConvertToNative(anyValueType) - if err != nil { - t.Fatalf("objVal.ConvertToNative() failed: %v", err) - } - anyMsg := anyVal.(*anypb.Any) - unpackedAny, err := anyMsg.UnmarshalNew() - if err != nil { - t.Fatalf("UnmarshalNew() failed: %v", err) - } - if !proto.Equal(unpackedAny, objVal.Value().(proto.Message)) { - t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), unpackedAny) - } + // Proto Message + val, err := objVal.ConvertToNative(reflect.TypeOf(&exprpb.ParsedExpr{})) + if err != nil { + t.Error(err) + } + if !proto.Equal(val.(proto.Message), msg) { + t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), msg) + } - // JSON - jsonVal, err := objVal.ConvertToNative(JSONValueType) - if err != nil { - t.Fatalf("objVal.ConvertToNative(%v) failed: %v", JSONValueType, err) - } - jsonBytes, err := protojson.Marshal(jsonVal.(proto.Message)) - jsonTxt := string(jsonBytes) - if err != nil { - t.Fatalf("protojson.Marshal(%v) failed: %v", jsonVal, err) - } - outMap := map[string]any{} - err = json.Unmarshal(jsonBytes, &outMap) - if err != nil { - t.Fatalf("json.Unmarshal(%q) failed: %v", jsonTxt, err) - } - want := map[string]any{ - "sourceInfo": map[string]any{ - "lineOffsets": []any{1.0, 2.0, 3.0}, - }, - } - if !reflect.DeepEqual(outMap, want) { - t.Errorf("got json '%v', expected %v", outMap, want) + // Dynamic protobuf + dynPB := reg.NewValue( + string(msg.ProtoReflect().Descriptor().FullName()), + tst.fieldMap(reg), + ) + if IsError(dynPB) { + t.Fatalf("reg.NewValue() failed: %v", dynPB) + } + dynVal := reg.NativeToValue(dynPB) + val, err = dynVal.ConvertToNative(reflect.TypeOf(msg)) + if err != nil { + t.Fatalf("dynVal.ConvertToNative() failed: %v", err) + } + if !proto.Equal(val.(proto.Message), msg) { + t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), msg) + } + + // google.protobuf.Any + anyVal, err := objVal.ConvertToNative(anyValueType) + if err != nil { + t.Fatalf("objVal.ConvertToNative() failed: %v", err) + } + anyMsg := anyVal.(*anypb.Any) + unpackedAny, err := anyMsg.UnmarshalNew() + if err != nil { + t.Fatalf("UnmarshalNew() failed: %v", err) + } + if !proto.Equal(unpackedAny, objVal.Value().(proto.Message)) { + t.Errorf("Messages were not equal, expect '%v', got '%v'", objVal.Value(), unpackedAny) + } + + // JSON + jsonVal, err := objVal.ConvertToNative(JSONValueType) + if err != nil { + t.Fatalf("objVal.ConvertToNative(%v) failed: %v", JSONValueType, err) + } + jsonBytes, err := protojson.Marshal(jsonVal.(proto.Message)) + jsonTxt := string(jsonBytes) + if err != nil { + t.Fatalf("protojson.Marshal(%v) failed: %v", jsonVal, err) + } + outMap := map[string]any{} + err = json.Unmarshal(jsonBytes, &outMap) + if err != nil { + t.Fatalf("json.Unmarshal(%q) failed: %v", jsonTxt, err) + } + want := tst.outValue + if !reflect.DeepEqual(outMap, want) { + t.Errorf("got json '%v', expected %v", outMap, want) + } } } @@ -127,7 +186,7 @@ func TestProtoObjectIsSet(t *testing.T) { LineOffsets: []int32{1, 2, 3}, }, } - reg := newTestRegistry(t, msg) + reg := newTestRegistry(t, ProtoTypes(msg)) objVal := reg.NativeToValue(msg).(*protoObj) if objVal.IsSet(String("source_info")) != True { t.Error("got 'source_info' not set, wanted set") @@ -144,7 +203,7 @@ func TestProtoObjectIsSet(t *testing.T) { } func TestProtoObjectIsZeroValue(t *testing.T) { - reg := newTestRegistry(t, &exprpb.ParsedExpr{}) + reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) emptyObj := reg.NativeToValue(&exprpb.ParsedExpr{}) pb, ok := emptyObj.(traits.Zeroer) if !ok { @@ -166,7 +225,7 @@ func TestProtoObjectGet(t *testing.T) { LineOffsets: []int32{1, 2, 3}, }, } - reg := newTestRegistry(t, msg) + reg := newTestRegistry(t, ProtoTypes(msg)) objVal := reg.NativeToValue(msg).(*protoObj) if objVal.Get(String("source_info")).Equal(reg.NativeToValue(msg.GetSourceInfo())) != True { t.Error("could not get 'source_info'") @@ -188,7 +247,7 @@ func TestProtoObjectConvertToType(t *testing.T) { LineOffsets: []int32{1, 2, 3}, }, } - reg := newTestRegistry(t, msg) + reg := newTestRegistry(t, ProtoTypes(msg)) objVal := reg.NativeToValue(msg) tv := objVal.Type().(ref.Val) if objVal.ConvertToType(TypeType).Equal(tv) != True { diff --git a/common/types/pb/file.go b/common/types/pb/file.go index e323afb1d..3a8bdf0b2 100644 --- a/common/types/pb/file.go +++ b/common/types/pb/file.go @@ -32,7 +32,7 @@ func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDe } types := make(map[string]*TypeDescription) for name, msgType := range metadata.msgTypes { - types[name] = newTypeDescription(name, msgType, pbdb.extensions) + types[name] = newTypeDescription(name, msgType, pbdb) } fileExtMap := make(extensionMap) for typeName, extensions := range metadata.msgExtensionMap { @@ -42,12 +42,13 @@ func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDe } for _, ext := range extensions { extDesc := dynamicpb.NewExtensionType(ext).TypeDescriptor() - messageExtMap[string(ext.FullName())] = newFieldDescription(extDesc) + messageExtMap[string(ext.FullName())] = newFieldDescription(extDesc, pbdb.jsonFieldNames) } fileExtMap[typeName] = messageExtMap } return &FileDescription{ name: fileDesc.Path(), + desc: fileDesc, types: types, enums: enums, }, fileExtMap @@ -56,6 +57,7 @@ func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDe // FileDescription holds a map of all types and enum values declared within a proto file. type FileDescription struct { name string + desc protoreflect.FileDescriptor types map[string]*TypeDescription enums map[string]*EnumValueDescription } @@ -68,6 +70,7 @@ func (fd *FileDescription) Copy(pbdb *Db) *FileDescription { } return &FileDescription{ name: fd.name, + desc: fd.desc, types: typesCopy, enums: fd.enums, } @@ -78,6 +81,11 @@ func (fd *FileDescription) GetName() string { return fd.name } +// FileDescriptor returns the proto file descriptor associated with the file representation. +func (fd *FileDescription) FileDescriptor() protoreflect.FileDescriptor { + return fd.desc +} + // GetEnumDescription returns an EnumDescription for a qualified enum value // name declared within the .proto file. func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumValueDescription, bool) { diff --git a/common/types/pb/file_test.go b/common/types/pb/file_test.go index 6a6a5a0f9..a27d097fb 100644 --- a/common/types/pb/file_test.go +++ b/common/types/pb/file_test.go @@ -76,6 +76,19 @@ func TestFileDescriptionGetExtensions(t *testing.T) { } } +func TestFileDescriptionJSONFieldNames(t *testing.T) { + pbdb := NewDb(JSONFieldNames(true)) + msg := &proto2pb.TestAllTypes{} + fd, err := pbdb.RegisterMessage(msg) + if err != nil { + t.Fatalf("pbdb.RegisterMessage() failed: %v", err) + } + fileDesc := msg.ProtoReflect().Descriptor().ParentFile() + if fd.FileDescriptor() != fileDesc { + t.Errorf("got %v, wanted %v file descriptor", fd.FileDescriptor(), fileDesc) + } +} + func TestFileDescriptionGetTypes(t *testing.T) { pbdb := NewDb() fd, err := pbdb.RegisterMessage(&proto3pb.TestAllTypes{}) diff --git a/common/types/pb/pb.go b/common/types/pb/pb.go index eadebcb04..c6fdfc695 100644 --- a/common/types/pb/pb.go +++ b/common/types/pb/pb.go @@ -42,6 +42,9 @@ type Db struct { files []*FileDescription // extensions contains the mapping between a given type name, extension name and its FieldDescription extensions map[string]map[string]*FieldDescription + + // jsonFieldNames indicates whether json-style names are supported as proto field names. + jsonFieldNames bool } // extensionsMap is a type alias to a map[typeName]map[extensionName]*FieldDescription @@ -81,13 +84,27 @@ func Merge(dstPB, srcPB proto.Message) error { return nil } +// DbOption modifies feature flags enabled on the proto database. +type DbOption func(*Db) *Db + +// JSONFieldNames configures the Db to support proto field accesses by their JSON names. +func JSONFieldNames(enabled bool) DbOption { + return func(db *Db) *Db { + db.jsonFieldNames = enabled + return db + } +} + // NewDb creates a new `pb.Db` with an empty type name to file description map. -func NewDb() *Db { +func NewDb(opts ...DbOption) *Db { pbdb := &Db{ revFileDescriptorMap: make(map[string]*FileDescription), files: []*FileDescription{}, extensions: make(extensionMap), } + for _, o := range opts { + pbdb = o(pbdb) + } // The FileDescription objects in the default db contain lazily initialized TypeDescription // values which may point to the state contained in the DefaultDb irrespective of this shallow // copy; however, the type graph for a field is idempotently computed, and is guaranteed to @@ -100,9 +117,15 @@ func NewDb() *Db { return pbdb } +// JSONFieldNames indicates whether the database is configured for proto field accesses by JSON names. +func (pbdb *Db) JSONFieldNames() bool { + return pbdb.jsonFieldNames +} + // Copy creates a copy of the current database with its own internal descriptor mapping. func (pbdb *Db) Copy() *Db { copy := NewDb() + copy.jsonFieldNames = pbdb.jsonFieldNames for _, fd := range pbdb.files { hasFile := false for _, fd2 := range copy.files { diff --git a/common/types/pb/pb_test.go b/common/types/pb/pb_test.go index 54d39477b..aef1c02b4 100644 --- a/common/types/pb/pb_test.go +++ b/common/types/pb/pb_test.go @@ -29,6 +29,69 @@ import ( tpb "google.golang.org/protobuf/types/known/timestamppb" ) +func TestDbJSONFieldNames(t *testing.T) { + pbdb := NewDb(JSONFieldNames(true)) + if pbdb.JSONFieldNames() != true { + t.Errorf("pbdb.JSONFieldNames() got %v, wanted true", pbdb.JSONFieldNames()) + } + fd, err := pbdb.RegisterMessage(&proto2pb.TestAllTypes{}) + if err != nil { + t.Fatalf("pbdb.RegisterMessage() failed: %v", err) + } + td, found := fd.GetTypeDescription("google.expr.proto2.test.TestAllTypes") + if !found { + t.Fatal("fd.GetTypeDescription() not found") + } + var fieldNames []string + for fieldName, f := range td.FieldMap() { + fieldNames = append(fieldNames, fieldName) + if f.jsonFieldName != true { + t.Error("f.jsonFieldName did not propagate") + } + } + // Note that 'group' type names don't have camelCase representations. + wantFields := []string{"singleInt32", "repeatedInt64", "nestedgroup"} + for _, want := range wantFields { + found := false + for _, field := range fieldNames { + if field == want { + found = true + break + } + } + if !found { + t.Errorf("%v field name not found", want) + } + } + copied := pbdb.Copy() + if copied.JSONFieldNames() != true { + t.Errorf("copied.JSONFieldNames() got %v, wanted true", copied.JSONFieldNames()) + } + td, found = copied.DescribeType("google.expr.proto2.test.TestAllTypes") + if !found { + t.Fatal("copied.DescribeType() not found") + } + fieldNames = []string{} + for fieldName, f := range td.FieldMap() { + fieldNames = append(fieldNames, fieldName) + if f.jsonFieldName != true { + t.Error("f.jsonFieldName did not propagate") + } + } + for _, want := range wantFields { + found := false + for _, field := range fieldNames { + if field == want { + found = true + break + } + } + if !found { + t.Errorf("%v field name not found", want) + } + } +} + func TestDbCopy(t *testing.T) { clone := DefaultDb.Copy() if !reflect.DeepEqual(clone, DefaultDb) { diff --git a/common/types/pb/type.go b/common/types/pb/type.go index 171494f07..18564642c 100644 --- a/common/types/pb/type.go +++ b/common/types/pb/type.go @@ -40,53 +40,68 @@ type description interface { // newTypeDescription produces a TypeDescription value for the fully-qualified proto type name // with a given descriptor. -func newTypeDescription(typeName string, desc protoreflect.MessageDescriptor, extensions extensionMap) *TypeDescription { +func newTypeDescription(typeName string, desc protoreflect.MessageDescriptor, pbdb *Db) *TypeDescription { msgType := dynamicpb.NewMessageType(desc) msgZero := dynamicpb.NewMessage(desc) fieldMap := map[string]*FieldDescription{} + jsonFieldMap := map[string]*FieldDescription{} fields := desc.Fields() for i := 0; i < fields.Len(); i++ { f := fields.Get(i) - fieldMap[string(f.Name())] = newFieldDescription(f) + fd := newFieldDescription(f, pbdb.jsonFieldNames) + fieldMap[fd.Name()] = fd + if pbdb.jsonFieldNames { + jsonFieldMap[fd.JSONName()] = fd + } } return &TypeDescription{ - typeName: typeName, - desc: desc, - msgType: msgType, - fieldMap: fieldMap, - extensions: extensions, - reflectType: reflectTypeOf(msgZero), - zeroMsg: zeroValueOf(msgZero), + typeName: typeName, + desc: desc, + msgType: msgType, + fieldMap: fieldMap, + jsonFieldMap: jsonFieldMap, + extensions: pbdb.extensions, + reflectType: reflectTypeOf(msgZero), + zeroMsg: zeroValueOf(msgZero), + jsonFieldNames: pbdb.jsonFieldNames, } } // TypeDescription is a collection of type metadata relevant to expression // checking and evaluation. type TypeDescription struct { - typeName string - desc protoreflect.MessageDescriptor - msgType protoreflect.MessageType - fieldMap map[string]*FieldDescription - extensions extensionMap - reflectType reflect.Type - zeroMsg proto.Message + typeName string + desc protoreflect.MessageDescriptor + msgType protoreflect.MessageType + fieldMap map[string]*FieldDescription + jsonFieldMap map[string]*FieldDescription + extensions extensionMap + reflectType reflect.Type + zeroMsg proto.Message + // jsonFieldNames indicates if the type's fields are accessible via their JSON names. + jsonFieldNames bool } // Copy copies the type description with updated references to the Db. func (td *TypeDescription) Copy(pbdb *Db) *TypeDescription { return &TypeDescription{ - typeName: td.typeName, - desc: td.desc, - msgType: td.msgType, - fieldMap: td.fieldMap, - extensions: pbdb.extensions, - reflectType: td.reflectType, - zeroMsg: td.zeroMsg, + typeName: td.typeName, + desc: td.desc, + msgType: td.msgType, + fieldMap: td.fieldMap, + jsonFieldMap: td.jsonFieldMap, + extensions: pbdb.extensions, + reflectType: td.reflectType, + zeroMsg: td.zeroMsg, + jsonFieldNames: td.jsonFieldNames, } } // FieldMap returns a string field name to FieldDescription map. func (td *TypeDescription) FieldMap() map[string]*FieldDescription { + if td.jsonFieldNames { + return td.jsonFieldMap + } return td.fieldMap } @@ -97,11 +112,15 @@ func (td *TypeDescription) FieldByName(name string) (*FieldDescription, bool) { return fd, true } extFieldMap, found := td.extensions[td.typeName] - if !found { - return nil, false + if found { + fd, found = extFieldMap[name] + return fd, found + } + if td.jsonFieldNames { + fd, found = td.jsonFieldMap[name] + return fd, found } - fd, found = extFieldMap[name] - return fd, found + return nil, false } // MaybeUnwrap accepts a proto message as input and unwraps it to a primitive CEL type if possible. @@ -132,7 +151,7 @@ func (td *TypeDescription) Zero() proto.Message { } // newFieldDescription creates a new field description from a protoreflect.FieldDescriptor. -func newFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescription { +func newFieldDescription(fieldDesc protoreflect.FieldDescriptor, jsonFieldNames bool) *FieldDescription { var reflectType reflect.Type var zeroMsg proto.Message switch fieldDesc.Kind() { @@ -168,15 +187,16 @@ func newFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescripti } var keyType, valType *FieldDescription if fieldDesc.IsMap() { - keyType = newFieldDescription(fieldDesc.MapKey()) - valType = newFieldDescription(fieldDesc.MapValue()) + keyType = newFieldDescription(fieldDesc.MapKey(), jsonFieldNames) + valType = newFieldDescription(fieldDesc.MapValue(), jsonFieldNames) } return &FieldDescription{ - desc: fieldDesc, - KeyType: keyType, - ValueType: valType, - reflectType: reflectType, - zeroMsg: zeroValueOf(zeroMsg), + desc: fieldDesc, + KeyType: keyType, + ValueType: valType, + reflectType: reflectType, + zeroMsg: zeroValueOf(zeroMsg), + jsonFieldName: jsonFieldNames, } } @@ -187,9 +207,10 @@ type FieldDescription struct { // ValueType holds the value FieldDescription for map fields. ValueType *FieldDescription - desc protoreflect.FieldDescriptor - reflectType reflect.Type - zeroMsg proto.Message + desc protoreflect.FieldDescriptor + reflectType reflect.Type + zeroMsg proto.Message + jsonFieldName bool } // CheckedType returns the type-definition used at type-check time. @@ -321,11 +342,20 @@ func (fd *FieldDescription) MaybeUnwrapDynamic(msg protoreflect.Message) (any, b return unwrapDynamic(fd, msg) } -// Name returns the CamelCase name of the field within the proto-based struct. +// Name returns the snake_case name of the field within the proto-based struct. func (fd *FieldDescription) Name() string { return string(fd.desc.Name()) } +// JSONName returns the JSON name of the field, if present. +func (fd *FieldDescription) JSONName() string { + jsonName := fd.desc.JSONName() + if len(jsonName) != 0 { + return jsonName + } + return string(fd.desc.Name()) +} + // ProtoKind returns the protobuf reflected kind of the field. func (fd *FieldDescription) ProtoKind() protoreflect.Kind { return fd.desc.Kind() diff --git a/common/types/pb/type_test.go b/common/types/pb/type_test.go index bf2ad58d6..fed1a3d0d 100644 --- a/common/types/pb/type_test.go +++ b/common/types/pb/type_test.go @@ -54,6 +54,38 @@ func TestTypeDescription(t *testing.T) { } } +func TestTypeDescriptionJSONFieldNames(t *testing.T) { + pbdb := NewDb(JSONFieldNames(true)) + msg := &proto2pb.TestAllTypes{} + msgType := string(msg.ProtoReflect().Descriptor().FullName()) + _, err := pbdb.RegisterMessage(msg) + if err != nil { + t.Fatalf("pbdb.RegisterMessage() failed: %v", err) + } + td, found := pbdb.DescribeType(msgType) + if !found { + t.Fatalf("pbdb.DescribeType(%s) not found", msgType) + } + fd, found := td.FieldByName("singleBoolWrapper") + if !found { + t.Fatal("td.FieldByName(singleBoolWrapper) failed") + } + if fd.JSONName() != "singleBoolWrapper" { + t.Fatalf("fd.JSONName() does not return the correct json name: %s", fd.JSONName()) + } + if fd.Name() != "single_bool_wrapper" { + t.Fatalf("fd.Name() does not return correct proto name: %s", fd.Name()) + } + enumName := "google.expr.proto2.test.TestAllTypes.NestedEnum.BAR" + en, found := pbdb.DescribeEnum(enumName) + if !found { + t.Fatalf("pbdb.DescribeEnum(%s) not found", enumName) + } + if en.Value() != 1 && en.Name() != enumName { + t.Errorf("got %v, wanted %s: %d", en, enumName, 1) + } +} + func TestTypeDescriptionGroupFields(t *testing.T) { pbdb := NewDb() msg := &proto2pb.TestAllTypes{} @@ -98,6 +130,26 @@ func TestTypeDescriptionFieldMap(t *testing.T) { } } +func TestTypeDescriptionJSONFieldMap(t *testing.T) { + pbdb := NewDb(JSONFieldNames(true)) + msg := &proto3pb.TestAllTypes{} + pbdb.RegisterMessage(msg) + td, found := pbdb.DescribeType("google.expr.proto3.test.TestAllTypes") + if !found { + t.Fatalf("pbdb.DescribeType(%v) not found", msg) + } + fd, found := td.FieldMap()["singleNestedMessage"] + if !found { + t.Fatal("singleNestedMessage not found") + } + if fd.Name() != "single_nested_message" { + t.Fatalf("fd.Name() got %s, wanted 'single_nested_message'", fd.Name()) + } + if fd.JSONName() != "singleNestedMessage" { + t.Fatalf("fd.JSONName() got %s, wanted 'singleNestedMessage'", fd.JSONName()) + } +} + func TestFieldDescription(t *testing.T) { pbdb := NewDb() msg := proto3pb.NestedTestAllTypes{} diff --git a/common/types/provider.go b/common/types/provider.go index 936a4e28b..ebfd66dcb 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -81,6 +81,9 @@ type FieldType struct { // GetFrom retrieves the field value on the input object, if set. GetFrom ref.FieldGetter + + // IsJSONField + IsJSONField bool } // Registry provides type information for a set of registered types. @@ -93,11 +96,50 @@ type Registry struct { // provider which can create new instances of the provided message or any // message that proto depends upon in its FileDescriptor. func NewRegistry(types ...proto.Message) (*Registry, error) { - p := &Registry{ + return NewProtoRegistry(ProtoTypes(types...)) +} + +// RegistryOption configures the behavior of the registry. +type RegistryOption func(r *Registry) (*Registry, error) + +// JSONFieldNames configures JSON field name support within the protobuf types in the registry. +func JSONFieldNames(enabled bool) RegistryOption { + return func(r *Registry) (*Registry, error) { + if enabled != r.pbdb.JSONFieldNames() { + newDB := pb.NewDb(pb.JSONFieldNames(enabled)) + files := r.pbdb.FileDescriptions() + for _, fd := range files { + _, err := newDB.RegisterDescriptor(fd.FileDescriptor()) + if err != nil { + return nil, err + } + } + r.pbdb = newDB + } + return r, nil + } +} + +// ProtoTypes creates a RegistryOption which registers the individual proto messages with the registry. +func ProtoTypes(types ...proto.Message) RegistryOption { + return func(r *Registry) (*Registry, error) { + for _, msgType := range types { + err := r.RegisterMessage(msgType) + if err != nil { + return nil, err + } + } + return r, nil + } +} + +// NewProtoRegistry creates a proto-based registry with a set of configurable options. +func NewProtoRegistry(opts ...RegistryOption) (*Registry, error) { + r := &Registry{ revTypeMap: make(map[string]*Type), pbdb: pb.NewDb(), } - err := p.RegisterType( + err := r.RegisterType( BoolType, BytesType, DoubleType, @@ -114,19 +156,19 @@ func NewRegistry(types ...proto.Message) (*Registry, error) { return nil, err } // This block ensures that the well-known protobuf types are registered by default. - for _, fd := range p.pbdb.FileDescriptions() { - err = p.registerAllTypes(fd) + for _, fd := range r.pbdb.FileDescriptions() { + err = r.registerAllTypes(fd) if err != nil { return nil, err } } - for _, msgType := range types { - err = p.RegisterMessage(msgType) + for _, opt := range opts { + r, err = opt(r) if err != nil { return nil, err } } - return p, nil + return r, nil } // NewEmptyRegistry returns a registry which is completely unconfigured. @@ -172,9 +214,11 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType, return nil, false } return &ref.FieldType{ - Type: field.CheckedType(), - IsSet: field.IsSet, - GetFrom: field.GetFrom}, true + Type: field.CheckedType(), + IsSet: field.IsSet, + GetFrom: field.GetFrom, + IsJSONField: p.pbdb.JSONFieldNames() && fieldName == field.JSONName(), + }, true } // FindStructFieldNames returns the set of field names for the given struct type, @@ -206,9 +250,11 @@ func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType return nil, false } return &FieldType{ - Type: fieldDescToCELType(field), - IsSet: field.IsSet, - GetFrom: field.GetFrom}, true + Type: fieldDescToCELType(field), + IsSet: field.IsSet, + GetFrom: field.GetFrom, + IsJSONField: p.pbdb.JSONFieldNames() && fieldName == field.JSONName(), + }, true } // FindIdent takes a qualified identifier name and returns a ref.Val if one exists. @@ -268,9 +314,8 @@ func (p *Registry) NewValue(structType string, fields map[string]ref.Val) ref.Va return NewErr("unknown type '%s'", structType) } msg := td.New() - fieldMap := td.FieldMap() for name, value := range fields { - field, found := fieldMap[name] + field, found := td.FieldByName(name) if !found { return NewErr("no such field: %s", name) } diff --git a/common/types/provider_test.go b/common/types/provider_test.go index 2576a5740..9189df449 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -134,10 +134,10 @@ func TestRegistryFindStructType(t *testing.T) { } func TestRegistryFindStructFieldNames(t *testing.T) { - reg := newTestRegistry(t, &exprpb.Decl{}, &exprpb.Reference{}) tests := []struct { - typeName string - fields []string + typeName string + fields []string + jsonFieldNames bool }{ { typeName: "google.api.expr.v1alpha1.Reference", @@ -151,11 +151,19 @@ func TestRegistryFindStructFieldNames(t *testing.T) { typeName: "invalid.TypeName", fields: []string{}, }, + { + typeName: "google.api.expr.v1alpha1.Reference", + fields: []string{"name", "overloadId", "value"}, + jsonFieldNames: true, + }, } for _, tst := range tests { tc := tst t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) { + reg := newTestRegistry(t, + ProtoTypes(&exprpb.Decl{}, &exprpb.Reference{}), + JSONFieldNames(tc.jsonFieldNames)) fields, _ := reg.FindStructFieldNames(tc.typeName) sort.Strings(fields) sort.Strings(tc.fields) @@ -167,16 +175,12 @@ func TestRegistryFindStructFieldNames(t *testing.T) { } func TestRegistryFindStructFieldType(t *testing.T) { - reg := newTestRegistry(t) - err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile()) - if err != nil { - t.Fatalf("RegisterDescriptor() failed: %v", err) - } msgTypeName := ".google.expr.proto3.test.TestAllTypes" tests := []struct { - typeName string - field string - found bool + typeName string + field string + found bool + jsonFieldNames bool }{ { typeName: msgTypeName, @@ -238,11 +242,28 @@ func TestRegistryFindStructFieldType(t *testing.T) { field: "map_string_string", found: false, }, + { + typeName: msgTypeName, + field: "mapStringString", + found: true, + jsonFieldNames: true, + }, + { + typeName: msgTypeName, + field: "map_string_string", + found: true, + jsonFieldNames: true, + }, } for _, tst := range tests { tc := tst t.Run(fmt.Sprintf("%s.%s", tc.typeName, tc.field), func(t *testing.T) { + reg := newTestRegistry(t, JSONFieldNames(tc.jsonFieldNames)) + err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile()) + if err != nil { + t.Fatalf("RegisterDescriptor() failed: %v", err) + } // When the field is expected to be found, test parity of the results if tc.found { refField, found := reg.FindFieldType(tc.typeName, tc.field) @@ -278,7 +299,7 @@ func TestRegistryFindStructFieldType(t *testing.T) { } func TestRegistryNewValue(t *testing.T) { - reg := newTestRegistry(t, &proto3pb.TestAllTypes{}, &exprpb.SourceInfo{}) + reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{}, &exprpb.SourceInfo{})) tests := []struct { typeName string fields map[string]ref.Val @@ -406,7 +427,7 @@ func TestRegistryNewValue(t *testing.T) { } func TestRegistryNewValueErrors(t *testing.T) { - reg := newTestRegistry(t, &proto3pb.TestAllTypes{}, &exprpb.SourceInfo{}) + reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{}, &exprpb.SourceInfo{})) tests := []struct { typeName string fields map[string]ref.Val @@ -483,7 +504,7 @@ func TestRegistryNewValueErrors(t *testing.T) { } func TestRegistryGetters(t *testing.T) { - reg := newTestRegistry(t, &exprpb.ParsedExpr{}) + reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) if sourceInfo := reg.NewValue( "google.api.expr.v1alpha1.SourceInfo", map[string]ref.Val{ @@ -519,7 +540,7 @@ func TestRegistryGetters(t *testing.T) { } func TestConvertToNative(t *testing.T) { - reg := newTestRegistry(t, &exprpb.ParsedExpr{}) + reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) // Core type conversion tests. expectValueToNative(t, True, true) @@ -584,7 +605,7 @@ func TestConvertToNative(t *testing.T) { } func TestNativeToValue_Any(t *testing.T) { - reg := newTestRegistry(t, &exprpb.ParsedExpr{}) + reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) // NullValue anyValue, err := NullValue.ConvertToNative(anyValueType) if err != nil { @@ -645,7 +666,7 @@ func TestNativeToValue_Any(t *testing.T) { } func TestNativeToValue_Json(t *testing.T) { - reg := newTestRegistry(t, &exprpb.ParsedExpr{}) + reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) // Json primitive conversion test. expectNativeToValue(t, structpb.NewBoolValue(false), False) expectNativeToValue(t, structpb.NewNumberValue(1.1), Double(1.1)) @@ -835,7 +856,7 @@ func expectValueToNative(t *testing.T, in ref.Val, out any) { func expectNativeToValue(t *testing.T, in any, out ref.Val) { t.Helper() - reg := newTestRegistry(t, &exprpb.ParsedExpr{}) + reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) if val := reg.NativeToValue(in); IsError(val) { t.Error(val) } else { @@ -918,11 +939,11 @@ type testFloat32 float32 type testFloat64 float64 type testString string -func newTestRegistry(t *testing.T, types ...proto.Message) *Registry { +func newTestRegistry(t *testing.T, opts ...RegistryOption) *Registry { t.Helper() - reg, err := NewRegistry(types...) + reg, err := NewProtoRegistry(opts...) if err != nil { - t.Fatalf("NewRegistry(%v) failed: %v", types, err) + t.Fatalf("NewProtoRegistry() failed: %v", err) } return reg } diff --git a/common/types/ref/provider.go b/common/types/ref/provider.go index b9820023d..ed5ab0662 100644 --- a/common/types/ref/provider.go +++ b/common/types/ref/provider.go @@ -93,6 +93,9 @@ type FieldType struct { // GetFrom retrieves the field value on the input object, if set. GetFrom FieldGetter + + // IsJSONFIeld indicates that the field was accessed via its JSON name. + IsJSONField bool } // FieldTester is used to test field presence on an input object.