From f014944d685e301a26198fe5931acb3101788ca4 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 3 Mar 2026 15:34:18 -0800 Subject: [PATCH] Simplify recursive planning limit checks. Instead of planning recursively up to the limit, the behavior now is to check if the AST is small enough to fit in the limit and either fail planning or swap back to the stack machine implementation. This simplifies program planning and avoids some unpredictable behavior around wrapping recursive sub-programs in stack machine steps (high overhead in deep but unbalanced ASTs). PiperOrigin-RevId: 878154486 --- eval/compiler/BUILD | 2 + eval/compiler/flat_expr_builder.cc | 296 +++++++++--------- eval/compiler/flat_expr_builder_extensions.h | 23 ++ eval/public/cel_options.cc | 1 + eval/public/cel_options.h | 14 +- .../expression_builder_benchmark_test.cc | 68 +++- runtime/runtime_options.h | 7 + 7 files changed, 253 insertions(+), 158 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index e82b0ce13..6d462bd12 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -33,6 +33,7 @@ cc_library( "//base:data", "//common:expr", "//common:native_type", + "//common:navigable_ast", "//common:value", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", @@ -104,6 +105,7 @@ cc_library( "//common:ast", "//common:ast_traverse", "//common:ast_visitor", + "//common:ast_visitor_base", "//common:constant", "//common:expr", "//common:kind", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a0fd427bd..5b01fe4da 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -54,6 +54,7 @@ #include "common/ast.h" #include "common/ast_traverse.h" #include "common/ast_visitor.h" +#include "common/ast_visitor_base.h" #include "common/constant.h" #include "common/expr.h" #include "common/kind.h" @@ -109,6 +110,13 @@ constexpr absl::string_view kBlock = "cel.@block"; // Forward declare to resolve circular 
dependency for short_circuiting visitors. class FlatExprVisitor; +// Error code for failed recursive program building. Generally indicates an +// optimization doesn't support recursive programs. +absl::Status FailedRecursivePlanning() { + return absl::InternalError( + "failed to build recursive program. check for unsupported optimizations"); +} + // Helper for bookkeeping variables mapped to indexes. class IndexManager { public: @@ -577,6 +585,10 @@ class FlatExprVisitor : public cel::AstVisitor { } } + void SetPlanRecursiveProgram() { plan_recursive_program_ = true; } + + bool PlanRecursiveProgram() const { return plan_recursive_program_; } + void PreVisitExpr(const cel::Expr& expr) override { ValidateOrError(!absl::holds_alternative(expr.kind()), "Invalid empty expression"); @@ -947,8 +959,7 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { SetProgressStatusError(absl::InternalError( @@ -1064,21 +1075,13 @@ class FlatExprVisitor : public cel::AstVisitor { } } + // Returns the maximum recursion depth of the current program if it is + // eligible for recursion, or nullopt if it is not. absl::optional RecursionEligible() { - if (program_builder_.current() == nullptr) { + if (!plan_recursive_program_ || program_builder_.current() == nullptr) { return absl::nullopt; } - absl::optional depth = - program_builder_.current()->RecursiveDependencyDepth(); - if (!depth.has_value()) { - // one or more of the dependencies isn't eligible. 
- return depth; - } - if (options_.max_recursion_depth < 0 || - *depth < options_.max_recursion_depth) { - return depth; - } - return absl::nullopt; + return program_builder_.current()->RecursiveDependencyDepth(); } std::vector> @@ -1089,10 +1092,7 @@ class FlatExprVisitor : public cel::AstVisitor { return program_builder_.current()->ExtractRecursiveDependencies(); } - void MaybeMakeTernaryRecursive(const cel::Expr* expr) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); @@ -1107,26 +1107,16 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { + if (condition_plan == nullptr || !condition_plan->IsRecursive() || + left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, condition_plan->recursive_program().depth); - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, condition_plan->recursive_program().depth, + left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep( CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, @@ -1136,10 +1126,7 @@ class FlatExprVisitor : 
public cel::AstVisitor { max_depth + 1); } - void MaybeMakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { if (expr->call_expr().args().size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); @@ -1151,21 +1138,14 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); if (is_or) { SetRecursiveStep( @@ -1182,11 +1162,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } - void MaybeMakeOptionalShortcircuitRecursive(const cel::Expr* expr, - bool is_or_value) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( @@ -1199,21 +1175,13 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if 
(left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep(CreateDirectOptionalOrStep( expr->id(), left_plan->ExtractRecursiveProgram().step, @@ -1225,7 +1193,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeBindRecursive(const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!plan_recursive_program_) { return; } @@ -1233,16 +1201,12 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } int result_depth = result_plan->recursive_program().depth; - if (options_.max_recursion_depth > 0 && - result_depth >= options_.max_recursion_depth) { - return; - } - auto program = result_plan->ExtractRecursiveProgram(); SetRecursiveStep( CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), @@ -1252,42 +1216,26 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeComprehensionRecursive( const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t iter_slot, size_t iter2_slot, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!plan_recursive_program_) { return; } auto* accu_plan = 
program_builder_.GetSubexpression(&comprehension->accu_init()); - - if (accu_plan == nullptr || !accu_plan->IsRecursive()) { - return; - } - auto* range_plan = program_builder_.GetSubexpression(&comprehension->iter_range()); - - if (range_plan == nullptr || !range_plan->IsRecursive()) { - return; - } - auto* loop_plan = program_builder_.GetSubexpression(&comprehension->loop_step()); - - if (loop_plan == nullptr || !loop_plan->IsRecursive()) { - return; - } - auto* condition_plan = program_builder_.GetSubexpression(&comprehension->loop_condition()); - - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { - return; - } - auto* result_plan = program_builder_.GetSubexpression(&comprehension->result()); - - if (result_plan == nullptr || !result_plan->IsRecursive()) { + if (accu_plan == nullptr || !accu_plan->IsRecursive() || + range_plan == nullptr || !range_plan->IsRecursive() || + loop_plan == nullptr || !loop_plan->IsRecursive() || + condition_plan == nullptr || !condition_plan->IsRecursive() || + result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } @@ -1298,11 +1246,6 @@ class FlatExprVisitor : public cel::AstVisitor { max_depth = std::max(max_depth, condition_plan->recursive_program().depth); max_depth = std::max(max_depth, result_plan->recursive_program().depth); - if (options_.max_recursion_depth > 0 && - max_depth >= options_.max_recursion_depth) { - return; - } - auto step = CreateDirectComprehensionStep( iter_slot, iter2_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, @@ -1566,7 +1509,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_list_append) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (plan_recursive_program_) { SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); return; } @@ -1579,8 +1522,7 @@ class FlatExprVisitor : 
public cel::AstVisitor { } } } - absl::optional depth = RecursionEligible(); - if (depth.has_value()) { + if (absl::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { SetProgressStatusError(absl::InternalError( @@ -1614,8 +1556,7 @@ class FlatExprVisitor : public cel::AstVisitor { std::vector fields = std::move(status_or_resolved_fields.value().second); - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { SetProgressStatusError(absl::InternalError( @@ -1646,7 +1587,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_map_insert) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (plan_recursive_program_) { SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); return; } @@ -1656,8 +1597,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { SetProgressStatusError(absl::InternalError( @@ -1696,8 +1636,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto lazy_overloads = resolver_.FindLazyOverloads( function, call_expr->has_target(), num_args, expr->id()); if (!lazy_overloads.empty()) { - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep(CreateDirectLazyFunctionStep( expr->id(), *call_expr, std::move(args), @@ -1727,8 +1666,9 @@ class FlatExprVisitor : public cel::AstVisitor { return; } } - auto 
recursion_depth = RecursionEligible(); - if (recursion_depth.has_value()) { + + if (auto recursion_depth = RecursionEligible(); + recursion_depth.has_value()) { // Nonnull while active -- nullptr indicates logic error elsewhere in the // builder. ABSL_DCHECK(program_builder_.current() != nullptr); @@ -1980,17 +1920,17 @@ class FlatExprVisitor : public cel::AstVisitor { IssueCollector& issue_collector_; ProgramBuilder& program_builder_; - PlannerContext extension_context_; + PlannerContext& extension_context_; IndexManager index_manager_; bool enable_optional_types_; absl::optional block_; + bool plan_recursive_program_ = false; }; FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( const cel::Expr& expr, const cel::CallExpr& call_expr) { ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex); - auto depth = RecursionEligible(); if (!ValidateOrError( (call_expr.args().size() == 2 && !call_expr.has_target()) || // TODO(uncreated-issue/79): A few clients use the index operator with a @@ -2000,7 +1940,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2027,9 +1967,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2046,15 +1984,13 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( const cel::Expr& expr, const cel::CallExpr& call_expr) { - auto depth = RecursionEligible(); - if 
(!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), "unexpected number of args for builtin " "not_strictly_false operator")) { return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { SetProgressStatusError( @@ -2155,9 +2091,8 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( "unexpected number of args for builtin equality operator")) { return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2182,8 +2117,7 @@ FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2221,6 +2155,9 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && arg_num == 0 && (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { // If first branch evaluation result is enough to determine output, @@ -2248,6 +2185,9 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && (cond_ == BinaryCond::kOptionalOr || cond_ == BinaryCond::kOptionalOrValue)) { // If first branch evaluation result is enough to determine output, @@ -2275,6 +2215,28 @@ void 
BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/false); + break; + case BinaryCond::kOr: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/true); + break; + case BinaryCond::kOptionalOr: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/false); + break; + case BinaryCond::kOptionalOrValue: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/true); + break; + default: + ABSL_UNREACHABLE(); + } + return; + } + switch (cond_) { case BinaryCond::kAnd: visitor_->AddStep(CreateAndStep(expr->id())); @@ -2298,26 +2260,6 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { visitor_->SetProgressStatusError( jump_step_.set_target(visitor_->GetCurrentIndex())); } - // Handle maybe replacing the subprogram with a recursive version. This needs - // to happen after the jump step is updated (though it may get overwritten). - switch (cond_) { - case BinaryCond::kAnd: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/false); - break; - case BinaryCond::kOr: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/true); - break; - case BinaryCond::kOptionalOr: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/false); - break; - case BinaryCond::kOptionalOrValue: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/true); - break; - default: - ABSL_UNREACHABLE(); - } } void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2327,6 +2269,9 @@ void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } // Ternary operator "_?_:_" requires a special handing. 
// In contrary to regular function call, its execution affects the control // flow of the overall CEL expression. @@ -2380,6 +2325,10 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } // Determine and set jump offset in jump instruction. if (visitor_->ValidateOrError( error_jump_.exists(), @@ -2393,7 +2342,6 @@ void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { visitor_->SetProgressStatusError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } - visitor_->MaybeMakeTernaryRecursive(expr); } void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2403,8 +2351,11 @@ void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void ExhaustiveTernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } visitor_->AddStep(CreateTernaryStep(expr->id())); - visitor_->MaybeMakeTernaryRecursive(expr); } void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { @@ -2417,6 +2368,9 @@ void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { absl::Status ComprehensionVisitor::PostVisitArgDefault( cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return absl::OkStatus(); + } switch (arg_num) { case cel::ITER_RANGE: { init_step_pos_ = visitor_->GetCurrentIndex(); @@ -2491,6 +2445,9 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( void ComprehensionVisitor::PostVisitArgTrivial(cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } switch (arg_num) { case cel::ITER_RANGE: { break; @@ -2548,6 +2505,24 @@ std::vector FlattenExpressionTable( return subexpression_indexes; } +int EstimateMaxRecursionDepth(const Ast& ast, int max_recursion_depth) { 
+ class Visitor : public cel::AstVisitorBase { + public: + void PreVisitExpr(const cel::Expr& expr) override { current_depth_++; } + void PostVisitExpr(const cel::Expr& expr) override { current_depth_--; } + int current_depth_ = 0; + } visitor; + int max_depth = 0; + auto traversal = cel::AstTraversal::Create(ast.root_expr()); + while (traversal.Step(visitor)) { + max_depth = std::max(max_depth, visitor.current_depth_); + if (max_depth > max_recursion_depth) { + return max_depth; + } + } + return max_depth; +} + } // namespace absl::StatusOr FlatExprBuilder::CreateExpressionImpl( @@ -2590,6 +2565,27 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( issue_collector, program_builder, extension_context, enable_optional_types_); + if (options_.max_recursion_depth == -1 || options_.max_recursion_depth > 0) { + if (!ast) { + return absl::InternalError("AST is null after transforms"); + } + + // Avoid scanning the AST if we don't need to with a max_recursion_depth + // of -1. + if (options_.max_recursion_depth < 0) { + visitor.SetPlanRecursiveProgram(); + } else if (int max_depth = EstimateMaxRecursionDepth( + *ast, options_.max_recursion_depth); + max_depth > options_.max_recursion_depth) { + if (!options_.enable_recursive_planning_fail_over) { + return absl::InvalidArgumentError( + "expression exceeds maximum recursion depth."); + } + } else { + visitor.SetPlanRecursiveProgram(); + } + } + cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; AstTraverse(ast->root_expr(), visitor, opts); diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index 21e37b2a8..15b005659 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -40,6 +40,7 @@ #include "base/type_provider.h" #include "common/expr.h" #include "common/native_type.h" +#include "common/navigable_ast.h" #include "common/type_reflector.h" #include "eval/compiler/resolver.h" #include 
"eval/eval/direct_expression_step.h" @@ -419,6 +420,26 @@ class PlannerContext { : environment_->MutableMessageFactory(); } + void set_ast(const cel::Ast* ast) { ast_ = ast; } + + const cel::Ast* ast() const { return ast_; } + + // Returns a navigable AST for the expression. + // + // This is only valid while the AST is being planned and should not be + // persisted or accessed after the expression is built. + // + // During the AST transforms, the AST structure may change and require + // an explicit refresh via clear_navigable_ast(). + const cel::NavigableAst& navigable_ast() { + if (!navigable_ast_ && ast_ != nullptr) { + navigable_ast_ = cel::NavigableAst::Build(ast_->root_expr()); + } + return navigable_ast_; + } + + void clear_navigable_ast() { navigable_ast_ = cel::NavigableAst(); } + private: const std::shared_ptr environment_; const Resolver& resolver_; @@ -429,6 +450,8 @@ class PlannerContext { std::shared_ptr& arena_; const bool explicit_arena_; const std::shared_ptr message_factory_; + const cel::Ast* ast_ = nullptr; + cel::NavigableAst navigable_ast_; }; // Interface for Ast Transforms. diff --git a/eval/public/cel_options.cc b/eval/public/cel_options.cc index 938b5e96f..a0a5d8dab 100644 --- a/eval/public/cel_options.cc +++ b/eval/public/cel_options.cc @@ -40,6 +40,7 @@ cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options) { options.enable_empty_wrapper_null_unboxing, options.enable_lazy_bind_initialization, options.max_recursion_depth, + options.enable_recursive_planning_fail_over, options.enable_recursive_tracing, options.enable_fast_builtins}; } diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 779839583..d340458ad 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -179,11 +179,23 @@ struct InterpreterOptions { // expression. // // This does not account for re-entrant evaluation in a client's extension - // function. + // function (i.e. 
a CEL function that calls Evaluate on another CEL program) + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. // // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. int max_recursion_depth = 0; + // If true, the planner will switch to the heap-based stack machine for any + // program that exceeds the `max_recursion_depth`. + // + // If false, the planner will return an error if the maximum recursion depth + // is exceeded. + bool enable_recursive_planning_fail_over = false; + // Enable tracing support for recursively planned programs. // // Unlike the stack machine implementation, supporting tracing can affect diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index c26a7cd5c..f1b84e287 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -50,8 +50,24 @@ using google::api::expr::parser::Parse; enum BenchmarkParam : int { kDefault = 0, kFoldConstants = 1, + kRecursivePlanning = 2, + kRecursivePlanningWithConstantFolding = 3, }; +absl::string_view LabelForParam(BenchmarkParam param) { + switch (param) { + case BenchmarkParam::kDefault: + return "default"; + case BenchmarkParam::kFoldConstants: + return "fold_constants"; + case BenchmarkParam::kRecursivePlanning: + return "recursive_planning"; + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + return "recursive_planning_with_constant_folding"; + } + return "unknown"; +} + void BM_RegisterBuiltins(benchmark::State& state) { for (auto _ : state) { auto builder = CreateCelExpressionBuilder(); @@ -64,21 +80,34 @@ BENCHMARK(BM_RegisterBuiltins); InterpreterOptions OptionsForParam(BenchmarkParam param, google::protobuf::Arena& arena) { InterpreterOptions options; - switch (param) { case BenchmarkParam::kFoldConstants: + case 
BenchmarkParam::kRecursivePlanningWithConstantFolding: options.constant_arena = &arena; options.constant_folding = true; break; case BenchmarkParam::kDefault: + case BenchmarkParam::kRecursivePlanning: options.constant_folding = false; break; } + switch (param) { + case BenchmarkParam::kRecursivePlanning: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + options.enable_recursive_planning_fail_over = true; + options.max_recursion_depth = 48; + break; + case BenchmarkParam::kDefault: + case BenchmarkParam::kFoldConstants: + options.max_recursion_depth = 0; + break; + } return options; } void BM_SymbolicPolicy(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && @@ -105,7 +134,9 @@ void BM_SymbolicPolicy(benchmark::State& state) { BENCHMARK(BM_SymbolicPolicy) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); absl::StatusOr> MakeBuilderForEnums( absl::string_view container, absl::string_view enum_type, @@ -209,6 +240,7 @@ BENCHMARK(BM_EnumResolution256Candidate)->ThreadRange(1, 32); void BM_NestedComprehension(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( [4, 5, 6].all(x, [1, 2, 3].all(y, x > y) && [7, 8, 9].all(z, x < z)) @@ -231,10 +263,13 @@ void BM_NestedComprehension(benchmark::State& state) { BENCHMARK(BM_NestedComprehension) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_Comparisons(benchmark::State& state) { auto 
param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( v11 < v12 && v12 < v13 @@ -260,7 +295,9 @@ void BM_Comparisons(benchmark::State& state) { BENCHMARK(BM_Comparisons) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_ComparisonsConcurrent(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( @@ -290,6 +327,8 @@ BENCHMARK(BM_ComparisonsConcurrent)->ThreadRange(1, 32); void RegexPrecompilationBench(bool enabled, benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(absl::StrCat(LabelForParam(param), "_", + enabled ? "enabled" : "disabled")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( input_str.matches(r'192\.168\.' + '[0-9]{1,3}' + r'\.' + '[0-9]{1,3}') || @@ -325,7 +364,9 @@ void BM_RegexPrecompilationDisabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationDisabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_RegexPrecompilationEnabled(benchmark::State& state) { RegexPrecompilationBench(true, state); @@ -333,10 +374,13 @@ void BM_RegexPrecompilationEnabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationEnabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_StringConcat(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); auto size = state.range(1); std::string source = "'1234567890' + 
'1234567890'"; @@ -377,7 +421,17 @@ BENCHMARK(BM_StringConcat) ->Args({BenchmarkParam::kFoldConstants, 4}) ->Args({BenchmarkParam::kFoldConstants, 8}) ->Args({BenchmarkParam::kFoldConstants, 16}) - ->Args({BenchmarkParam::kFoldConstants, 32}); + ->Args({BenchmarkParam::kFoldConstants, 32}) + ->Args({BenchmarkParam::kRecursivePlanning, 2}) + ->Args({BenchmarkParam::kRecursivePlanning, 4}) + ->Args({BenchmarkParam::kRecursivePlanning, 8}) + ->Args({BenchmarkParam::kRecursivePlanning, 16}) + ->Args({BenchmarkParam::kRecursivePlanning, 32}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 2}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 4}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 8}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 16}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 32}); void BM_StringConcat32Concurrent(benchmark::State& state) { std::string source = "'1234567890' + '1234567890'"; diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 1e18fef95..bf489f24e 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -152,6 +152,13 @@ struct RuntimeOptions { // -1 means unbounded. int max_recursion_depth = 0; + // If true, the planner will switch to the heap-based stack machine for any + // program that exceeds the `max_recursion_depth`. + // + // If false, the planner will return an error if the maximum recursion depth + // is exceeded. + bool enable_recursive_planning_fail_over = false; + // Enable tracing support for recursively planned programs. // // Unlike the stack machine implementation, supporting tracing can affect