From c4a9bc058509650d7faeae50e333addf9cbf949f Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Mon, 22 Jun 2026 12:11:13 -0700 Subject: [PATCH] Add cost tracking to ext/math.go --- ext/README.md | 4 ++ ext/math.go | 61 +++++++++++++++++++----- ext/math_test.go | 118 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 11 deletions(-) diff --git a/ext/README.md b/ext/README.md index 6a7163de0..377c85ec5 100644 --- a/ext/README.md +++ b/ext/README.md @@ -66,6 +66,8 @@ intended; however, there is some chance for collision. ### Math.Greatest +**Introduced in version 0 (cost support in version 3)** + Returns the greatest valued number present in the arguments to the macro. Greatest is a variable argument count macro which must take at least one @@ -93,6 +95,8 @@ Examples: ### Math.Least +**Introduced in version 0 (cost support in version 3)** + Returns the least valued number present in the arguments to the macro. Least is a variable argument count macro which must take at least one diff --git a/ext/math.go b/ext/math.go index 6df8e3773..e67b205de 100644 --- a/ext/math.go +++ b/ext/math.go @@ -20,10 +20,12 @@ import ( "strings" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter" ) // Math returns a cel.EnvOption to configure namespaced math helper macros and @@ -339,9 +341,9 @@ import ( // // Examples: // -// math.sqrt(81) // returns 9.0 -// math.sqrt(985.25) // returns 31.388692231439016 -// math.sqrt(-15) // returns NaN +// math.sqrt(81) // returns 9.0 +// math.sqrt(985.25) // returns 31.388692231439016 +// math.sqrt(-15) // returns NaN func Math(options ...MathOption) cel.EnvOption { m := &mathLib{version: math.MaxUint32} for _, o := range options { @@ -580,12 +582,35 @@ func (lib *mathLib) CompileOptions() []cel.EnvOption { ), ) } + if lib.version >= 3 { + estimators := []checker.CostOption{ + checker.OverloadCostEstimate("math_@min_list_double", estimateMathListCost), + checker.OverloadCostEstimate("math_@min_list_int", estimateMathListCost), + checker.OverloadCostEstimate("math_@min_list_uint", estimateMathListCost), + checker.OverloadCostEstimate("math_@max_list_double", estimateMathListCost), + checker.OverloadCostEstimate("math_@max_list_int", estimateMathListCost), + checker.OverloadCostEstimate("math_@max_list_uint", estimateMathListCost), + } + opts = append(opts, cel.CostEstimatorOptions(estimators...)) + } return opts } // ProgramOptions implements the Library interface method. -func (*mathLib) ProgramOptions() []cel.ProgramOption { - return []cel.ProgramOption{} +func (lib *mathLib) ProgramOptions() []cel.ProgramOption { + var opts []cel.ProgramOption + if lib.version >= 3 { + trackers := []interpreter.CostTrackerOption{ + interpreter.OverloadCostTracker("math_@min_list_double", trackMathListCost), + interpreter.OverloadCostTracker("math_@min_list_int", trackMathListCost), + interpreter.OverloadCostTracker("math_@min_list_uint", trackMathListCost), + interpreter.OverloadCostTracker("math_@max_list_double", trackMathListCost), + interpreter.OverloadCostTracker("math_@max_list_int", trackMathListCost), + interpreter.OverloadCostTracker("math_@max_list_uint", trackMathListCost), + } + opts = append(opts, cel.CostTrackerOptions(trackers...)) + } + return opts } func mathLeast(meh cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { @@ -723,21 +748,19 @@ func sign(val ref.Val) ref.Val { } } - func sqrt(val ref.Val) ref.Val { switch v := val.(type) { case types.Double: - return types.Double(math.Sqrt(float64(v))) + return types.Double(math.Sqrt(float64(v))) case types.Int: - return types.Double(math.Sqrt(float64(v))) + return types.Double(math.Sqrt(float64(v))) case types.Uint: - return types.Double(math.Sqrt(float64(v))) + return types.Double(math.Sqrt(float64(v))) default: - return types.NewErr("no such overload: sqrt") + return types.NewErr("no such overload: sqrt") } } - func bitAndPairInt(first, second ref.Val) ref.Val { l := first.(types.Int) r := second.(types.Int) @@ -946,3 +969,19 @@ func maybeSuffixError(val ref.Val, suffix string) ref.Val { } return val } + +func estimateMathListCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if len(args) != 1 { + return nil + } + sz := estimateSize(estimator, args[0]) + cost := sz.MultiplyByCostFactor(1.0).Add(callCostEstimate) + resultSize := checker.FixedSizeEstimate(1) + return &checker.CallEstimate{CostEstimate: cost, ResultSize: &resultSize} +} + +func trackMathListCost(args []ref.Val, _ ref.Val) *uint64 { + sz := actualSize(args[0]) + cost := safeAdd(sz, callCost) + return &cost +} diff --git a/ext/math_test.go b/ext/math_test.go index 878954c75..8b47cae64 100644 --- a/ext/math_test.go +++ b/ext/math_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/types" ) @@ -667,3 +668,120 @@ func testMathEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { } return env } + +func testMathCostsEnv(t *testing.T, version int, opts ...cel.EnvOption) *cel.Env { + t.Helper() + var mathOpt cel.EnvOption + if version > 0 { + mathOpt = Math(MathVersion(uint32(version))) + } else { + mathOpt = Math() + } + baseOpts := []cel.EnvOption{ + mathOpt, + cel.EnableMacroCallTracking(), + } + env, err := cel.NewEnv(append(baseOpts, opts...)...) + if err != nil { + t.Fatalf("cel.NewEnv(Math()) failed: %v", err) + } + return env +} + +func TestMathCosts(t *testing.T) { + tests := []struct { + name string + expr string + vars []cel.EnvOption + in map[string]any + hints map[string]uint64 + estimatedCost checker.CostEstimate + actualCost uint64 + version int + }{ + { + name: "math_greatest_list_v2", + expr: "math.greatest(x) == 5", + vars: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.IntType)), + }, + in: map[string]any{ + "x": []int64{1, 2, 3, 4, 5}, + }, + hints: map[string]uint64{ + "x": 10, + }, + estimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + actualCost: 3, + version: 2, + }, + { + name: "math_greatest_list_v3", + expr: "math.greatest(x) == 5", + vars: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.IntType)), + }, + in: map[string]any{ + "x": []int64{1, 2, 3, 4, 5}, + }, + hints: map[string]uint64{ + "x": 10, + }, + estimatedCost: checker.CostEstimate{Min: 3, Max: 13}, + actualCost: 8, + version: 3, + }, + { + name: "math_least_list_v2", + expr: "math.least(x) == -3.0", + vars: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.DoubleType)), + }, + in: map[string]any{ + "x": []float64{-1.0, -2.0, -3.0}, + }, + hints: map[string]uint64{ + "x": 100, + }, + estimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + actualCost: 3, + version: 2, + }, + { + name: "math_least_list_v3", + expr: "math.least(x) == -3.0", + vars: []cel.EnvOption{ + cel.Variable("x", cel.ListType(cel.DoubleType)), + }, + in: map[string]any{ + "x": []float64{-1.0, -2.0, -3.0}, + }, + hints: map[string]uint64{ + "x": 100, + }, + estimatedCost: checker.CostEstimate{Min: 3, Max: 103}, + actualCost: 6, + version: 3, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + env := testMathCostsEnv(t, tc.version, tc.vars...) + var asts []*cel.Ast + pAst, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, pAst) + cAst, iss := env.Check(pAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) + } + testCheckCost(t, env, cAst, tc.hints, tc.estimatedCost) + asts = append(asts, cAst) + for _, ast := range asts { + testEvalWithCost(t, env, ast, tc.in, tc.actualCost) + } + }) + } +}