diff --git a/cel/library.go b/cel/library.go index 10e720cc..834291bd 100644 --- a/cel/library.go +++ b/cel/library.go @@ -78,6 +78,24 @@ func (l library) CompileOptions() []cel.EnvOption { //nolint:funlen,gocyclo l.uniqueMemberOverload(cel.StringType, l.uniqueScalar), l.uniqueMemberOverload(cel.BytesType, l.uniqueBytes), ), + cel.Function("getField", + cel.Overload( + "get_field_any_string", + []*cel.Type{cel.AnyType, cel.StringType}, + cel.AnyType, + cel.FunctionBinding(func(values ...ref.Val) ref.Val { + message, ok := values[0].(traits.Indexer) + if !ok { + return types.UnsupportedRefValConversionErr(values[0]) + } + fieldName, ok := values[1].Value().(string) + if !ok { + return types.UnsupportedRefValConversionErr(values[1]) + } + return message.Get(types.String(fieldName)) + }), + ), + ), cel.Function("isNan", cel.MemberOverload( "double_is_nan_bool", diff --git a/cel/library_test.go b/cel/library_test.go index 286dc5e8..e5ec2e4f 100644 --- a/cel/library_test.go +++ b/cel/library_test.go @@ -17,7 +17,10 @@ package cel import ( "testing" + "github.com/bufbuild/protovalidate-go/internal/gen/buf/validate/conformance/cases" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/interpreter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -26,7 +29,23 @@ import ( func TestCELLib(t *testing.T) { t.Parallel() - env, err := cel.NewEnv(cel.Lib(NewLibrary())) + testValue := cases.StringConst_builder{Val: "test_string"}.Build() + + activation, err := interpreter.NewActivation(map[string]any{ + "test": testValue, + }) + require.NoError(t, err) + + env, err := cel.NewEnv( + cel.Lib(NewLibrary()), + cel.Variable( + "test", + cel.ObjectType( + string(testValue.ProtoReflect().Descriptor().FullName()), + ), + ), + ) + require.NoError(t, err) t.Run("ext", func(t *testing.T) { @@ -34,7 +53,7 @@ func TestCELLib(t *testing.T) { tests := []struct { expr string - ex bool + ex any }{ {"0.0.isInf()", false}, {"0.0.isNan()", false}, @@ -197,6 +216,18 @@ func TestCELLib(t *testing.T) { "'foo@example.com '.isEmail()", false, }, + { + "getField(test, 'val')", + "test_string", + }, + { + "getField(test, 'lav')", + types.NewErrFromString("no such field"), + }, + { + "getField(0, 'val')", + types.NewErrFromString("unsupported conversion"), + }, } for _, tc := range tests { @@ -204,11 +235,15 @@ func TestCELLib(t *testing.T) { t.Run(test.expr, func(t *testing.T) { t.Parallel() prog := buildTestProgram(t, env, test.expr) - val, _, err := prog.Eval(interpreter.EmptyActivation()) - require.NoError(t, err) - isUnique, ok := val.Value().(bool) - require.True(t, ok) - assert.Equal(t, test.ex, isUnique) + val, _, err := prog.Eval(activation) + if refEx, ok := test.ex.(ref.Val); ok && types.IsError(refEx) { + refErr, ok := refEx.Value().(error) + require.True(t, ok) + assert.ErrorContains(t, err, refErr.Error()) + } else { + require.NoError(t, err) + assert.Equal(t, test.ex, val.Value()) + } }) } })