diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index e82b0ce13..6d462bd12 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -33,6 +33,7 @@ cc_library( "//base:data", "//common:expr", "//common:native_type", + "//common:navigable_ast", "//common:value", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", @@ -104,6 +105,7 @@ cc_library( "//common:ast", "//common:ast_traverse", "//common:ast_visitor", + "//common:ast_visitor_base", "//common:constant", "//common:expr", "//common:kind", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a0fd427bd..5b01fe4da 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -54,6 +54,7 @@ #include "common/ast.h" #include "common/ast_traverse.h" #include "common/ast_visitor.h" +#include "common/ast_visitor_base.h" #include "common/constant.h" #include "common/expr.h" #include "common/kind.h" @@ -109,6 +110,13 @@ constexpr absl::string_view kBlock = "cel.@block"; // Forward declare to resolve circular dependency for short_circuiting visitors. class FlatExprVisitor; +// Error code for failed recursive program building. Generally indicates an +// optimization doesn't support recursive programs. +absl::Status FailedRecursivePlanning() { + return absl::InternalError( + "failed to build recursive program. check for unsupported optimizations"); +} + // Helper for bookkeeping variables mapped to indexes. class IndexManager { public: @@ -577,6 +585,10 @@ class FlatExprVisitor : public cel::AstVisitor { } } + void SetPlanRecursiveProgram() { plan_recursive_program_ = true; } + + bool PlanRecursiveProgram() const { return plan_recursive_program_; } + void PreVisitExpr(const cel::Expr& expr) override { ValidateOrError(!absl::holds_alternative(expr.kind()), "Invalid empty expression"); @@ -947,8 +959,7 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { SetProgressStatusError(absl::InternalError( @@ -1064,21 +1075,13 @@ class FlatExprVisitor : public cel::AstVisitor { } } + // Returns the maximum recursion depth of the current program if it is + // eligible for recursion, or nullopt if it is not. absl::optional RecursionEligible() { - if (program_builder_.current() == nullptr) { + if (!plan_recursive_program_ || program_builder_.current() == nullptr) { return absl::nullopt; } - absl::optional depth = - program_builder_.current()->RecursiveDependencyDepth(); - if (!depth.has_value()) { - // one or more of the dependencies isn't eligible. - return depth; - } - if (options_.max_recursion_depth < 0 || - *depth < options_.max_recursion_depth) { - return depth; - } - return absl::nullopt; + return program_builder_.current()->RecursiveDependencyDepth(); } std::vector> @@ -1089,10 +1092,7 @@ class FlatExprVisitor : public cel::AstVisitor { return program_builder_.current()->ExtractRecursiveDependencies(); } - void MaybeMakeTernaryRecursive(const cel::Expr* expr) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); @@ -1107,26 +1107,16 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { + if (condition_plan == nullptr || !condition_plan->IsRecursive() || + left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, condition_plan->recursive_program().depth); - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, condition_plan->recursive_program().depth, + left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep( CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, @@ -1136,10 +1126,7 @@ class FlatExprVisitor : public cel::AstVisitor { max_depth + 1); } - void MaybeMakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { if (expr->call_expr().args().size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); @@ -1151,21 +1138,14 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); if (is_or) { SetRecursiveStep( @@ -1182,11 +1162,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } - void MaybeMakeOptionalShortcircuitRecursive(const cel::Expr* expr, - bool is_or_value) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( @@ -1199,21 +1175,13 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep(CreateDirectOptionalOrStep( expr->id(), left_plan->ExtractRecursiveProgram().step, @@ -1225,7 +1193,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeBindRecursive(const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!plan_recursive_program_) { return; } @@ -1233,16 +1201,12 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } int result_depth = result_plan->recursive_program().depth; - if (options_.max_recursion_depth > 0 && - result_depth >= options_.max_recursion_depth) { - return; - } - auto program = result_plan->ExtractRecursiveProgram(); SetRecursiveStep( CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), @@ -1252,42 +1216,26 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeComprehensionRecursive( const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t iter_slot, size_t iter2_slot, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!plan_recursive_program_) { return; } auto* accu_plan = program_builder_.GetSubexpression(&comprehension->accu_init()); - - if (accu_plan == nullptr || !accu_plan->IsRecursive()) { - return; - } - auto* range_plan = program_builder_.GetSubexpression(&comprehension->iter_range()); - - if (range_plan == nullptr || !range_plan->IsRecursive()) { - return; - } - auto* loop_plan = program_builder_.GetSubexpression(&comprehension->loop_step()); - - if (loop_plan == nullptr || !loop_plan->IsRecursive()) { - return; - } - auto* condition_plan = program_builder_.GetSubexpression(&comprehension->loop_condition()); - - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { - return; - } - auto* result_plan = program_builder_.GetSubexpression(&comprehension->result()); - - if (result_plan == nullptr || !result_plan->IsRecursive()) { + if (accu_plan == nullptr || !accu_plan->IsRecursive() || + range_plan == nullptr || !range_plan->IsRecursive() || + loop_plan == nullptr || !loop_plan->IsRecursive() || + condition_plan == nullptr || !condition_plan->IsRecursive() || + result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } @@ -1298,11 +1246,6 @@ class FlatExprVisitor : public cel::AstVisitor { max_depth = std::max(max_depth, condition_plan->recursive_program().depth); max_depth = std::max(max_depth, result_plan->recursive_program().depth); - if (options_.max_recursion_depth > 0 && - max_depth >= options_.max_recursion_depth) { - return; - } - auto step = CreateDirectComprehensionStep( iter_slot, iter2_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, @@ -1566,7 +1509,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_list_append) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (plan_recursive_program_) { SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); return; } @@ -1579,8 +1522,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } } - absl::optional depth = RecursionEligible(); - if (depth.has_value()) { + if (absl::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { SetProgressStatusError(absl::InternalError( @@ -1614,8 +1556,7 @@ class FlatExprVisitor : public cel::AstVisitor { std::vector fields = std::move(status_or_resolved_fields.value().second); - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { SetProgressStatusError(absl::InternalError( @@ -1646,7 +1587,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_map_insert) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (plan_recursive_program_) { SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); return; } @@ -1656,8 +1597,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { SetProgressStatusError(absl::InternalError( @@ -1696,8 +1636,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto lazy_overloads = resolver_.FindLazyOverloads( function, call_expr->has_target(), num_args, expr->id()); if (!lazy_overloads.empty()) { - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep(CreateDirectLazyFunctionStep( expr->id(), *call_expr, std::move(args), @@ -1727,8 +1666,9 @@ class FlatExprVisitor : public cel::AstVisitor { return; } } - auto recursion_depth = RecursionEligible(); - if (recursion_depth.has_value()) { + + if (auto recursion_depth = RecursionEligible(); + recursion_depth.has_value()) { // Nonnull while active -- nullptr indicates logic error elsewhere in the // builder. ABSL_DCHECK(program_builder_.current() != nullptr); @@ -1980,17 +1920,17 @@ class FlatExprVisitor : public cel::AstVisitor { IssueCollector& issue_collector_; ProgramBuilder& program_builder_; - PlannerContext extension_context_; + PlannerContext& extension_context_; IndexManager index_manager_; bool enable_optional_types_; absl::optional block_; + bool plan_recursive_program_ = false; }; FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( const cel::Expr& expr, const cel::CallExpr& call_expr) { ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex); - auto depth = RecursionEligible(); if (!ValidateOrError( (call_expr.args().size() == 2 && !call_expr.has_target()) || // TODO(uncreated-issue/79): A few clients use the index operator with a @@ -2000,7 +1940,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2027,9 +1967,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2046,15 +1984,13 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( const cel::Expr& expr, const cel::CallExpr& call_expr) { - auto depth = RecursionEligible(); - if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), "unexpected number of args for builtin " "not_strictly_false operator")) { return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { SetProgressStatusError( @@ -2155,9 +2091,8 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( "unexpected number of args for builtin equality operator")) { return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2182,8 +2117,7 @@ FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2221,6 +2155,9 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && arg_num == 0 && (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { // If first branch evaluation result is enough to determine output, @@ -2248,6 +2185,9 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && (cond_ == BinaryCond::kOptionalOr || cond_ == BinaryCond::kOptionalOrValue)) { // If first branch evaluation result is enough to determine output, @@ -2275,6 +2215,28 @@ void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/false); + break; + case BinaryCond::kOr: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/true); + break; + case BinaryCond::kOptionalOr: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/false); + break; + case BinaryCond::kOptionalOrValue: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/true); + break; + default: + ABSL_UNREACHABLE(); + } + return; + } + switch (cond_) { case BinaryCond::kAnd: visitor_->AddStep(CreateAndStep(expr->id())); @@ -2298,26 +2260,6 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { visitor_->SetProgressStatusError( jump_step_.set_target(visitor_->GetCurrentIndex())); } - // Handle maybe replacing the subprogram with a recursive version. This needs - // to happen after the jump step is updated (though it may get overwritten). - switch (cond_) { - case BinaryCond::kAnd: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/false); - break; - case BinaryCond::kOr: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/true); - break; - case BinaryCond::kOptionalOr: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/false); - break; - case BinaryCond::kOptionalOrValue: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/true); - break; - default: - ABSL_UNREACHABLE(); - } } void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2327,6 +2269,9 @@ void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } // Ternary operator "_?_:_" requires a special handing. // In contrary to regular function call, its execution affects the control // flow of the overall CEL expression. @@ -2380,6 +2325,10 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } // Determine and set jump offset in jump instruction. if (visitor_->ValidateOrError( error_jump_.exists(), @@ -2393,7 +2342,6 @@ void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { visitor_->SetProgressStatusError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } - visitor_->MaybeMakeTernaryRecursive(expr); } void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2403,8 +2351,11 @@ void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void ExhaustiveTernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } visitor_->AddStep(CreateTernaryStep(expr->id())); - visitor_->MaybeMakeTernaryRecursive(expr); } void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { @@ -2417,6 +2368,9 @@ void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { absl::Status ComprehensionVisitor::PostVisitArgDefault( cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return absl::OkStatus(); + } switch (arg_num) { case cel::ITER_RANGE: { init_step_pos_ = visitor_->GetCurrentIndex(); @@ -2491,6 +2445,9 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( void ComprehensionVisitor::PostVisitArgTrivial(cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } switch (arg_num) { case cel::ITER_RANGE: { break; @@ -2548,6 +2505,24 @@ std::vector FlattenExpressionTable( return subexpression_indexes; } +int EstimateMaxRecursionDepth(const Ast& ast, int max_recursion_depth) { + class Visitor : public cel::AstVisitorBase { + public: + void PreVisitExpr(const cel::Expr& expr) override { current_depth_++; } + void PostVisitExpr(const cel::Expr& expr) override { current_depth_--; } + int current_depth_ = 0; + } visitor; + int max_depth = 0; + auto traversal = cel::AstTraversal::Create(ast.root_expr()); + while (traversal.Step(visitor)) { + max_depth = std::max(max_depth, visitor.current_depth_); + if (max_depth > max_recursion_depth) { + return max_depth; + } + } + return max_depth; +} + } // namespace absl::StatusOr FlatExprBuilder::CreateExpressionImpl( @@ -2590,6 +2565,27 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( issue_collector, program_builder, extension_context, enable_optional_types_); + if (options_.max_recursion_depth == -1 || options_.max_recursion_depth > 0) { + if (!ast) { + return absl::InternalError("AST is null after transforms"); + } + + // Avoid scanning the AST if we don't need to with a max_recursion_depth + // of -1. + if (options_.max_recursion_depth < 0) { + visitor.SetPlanRecursiveProgram(); + } else if (int max_depth = EstimateMaxRecursionDepth( + *ast, options_.max_recursion_depth); + max_depth > options_.max_recursion_depth) { + if (!options_.enable_recursive_planning_fail_over) { + return absl::InvalidArgumentError( + "expression exceeds maximum recursion depth."); + } + } else { + visitor.SetPlanRecursiveProgram(); + } + } + cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; AstTraverse(ast->root_expr(), visitor, opts); diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index 21e37b2a8..15b005659 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -40,6 +40,7 @@ #include "base/type_provider.h" #include "common/expr.h" #include "common/native_type.h" +#include "common/navigable_ast.h" #include "common/type_reflector.h" #include "eval/compiler/resolver.h" #include "eval/eval/direct_expression_step.h" @@ -419,6 +420,26 @@ class PlannerContext { : environment_->MutableMessageFactory(); } + void set_ast(const cel::Ast* ast) { ast_ = ast; } + + const cel::Ast* ast() const { return ast_; } + + // Returns a navigable AST for the expression. + // + // This is only valid while the AST is being planned and should not be + // persisted or accessed after the expression is built. + // + // During the AST transforms, the AST structure may change and require + // an explicit refresh via clear_navigable_ast(). + const cel::NavigableAst& navigable_ast() { + if (!navigable_ast_ && ast_ != nullptr) { + navigable_ast_ = cel::NavigableAst::Build(ast_->root_expr()); + } + return navigable_ast_; + } + + void clear_navigable_ast() { navigable_ast_ = cel::NavigableAst(); } + private: const std::shared_ptr environment_; const Resolver& resolver_; @@ -429,6 +450,8 @@ class PlannerContext { std::shared_ptr& arena_; const bool explicit_arena_; const std::shared_ptr message_factory_; + const cel::Ast* ast_ = nullptr; + cel::NavigableAst navigable_ast_; }; // Interface for Ast Transforms. diff --git a/eval/public/cel_options.cc b/eval/public/cel_options.cc index 938b5e96f..a0a5d8dab 100644 --- a/eval/public/cel_options.cc +++ b/eval/public/cel_options.cc @@ -40,6 +40,7 @@ cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options) { options.enable_empty_wrapper_null_unboxing, options.enable_lazy_bind_initialization, options.max_recursion_depth, + options.enable_recursive_planning_fail_over, options.enable_recursive_tracing, options.enable_fast_builtins}; } diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 779839583..d340458ad 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -179,11 +179,23 @@ struct InterpreterOptions { // expression. // // This does not account for re-entrant evaluation in a client's extension - // function. + // function (i.e. a CEL function that calls Evaluate on another CEL program) + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. // // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. int max_recursion_depth = 0; + // If true, the planner will switch to the heap-based stack machine for any + // program that exceeds the `max_recursion_depth`. + // + // If false, the planner will return an error if the maximum recursion depth + // is exceeded. + bool enable_recursive_planning_fail_over = false; + // Enable tracing support for recursively planned programs. // // Unlike the stack machine implementation, supporting tracing can affect diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index c26a7cd5c..f1b84e287 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -50,8 +50,24 @@ using google::api::expr::parser::Parse; enum BenchmarkParam : int { kDefault = 0, kFoldConstants = 1, + kRecursivePlanning = 2, + kRecursivePlanningWithConstantFolding = 3, }; +absl::string_view LabelForParam(BenchmarkParam param) { + switch (param) { + case BenchmarkParam::kDefault: + return "default"; + case BenchmarkParam::kFoldConstants: + return "fold_constants"; + case BenchmarkParam::kRecursivePlanning: + return "recursive_planning"; + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + return "recursive_planning_with_constant_folding"; + } + return "unknown"; +} + void BM_RegisterBuiltins(benchmark::State& state) { for (auto _ : state) { auto builder = CreateCelExpressionBuilder(); @@ -64,21 +80,34 @@ BENCHMARK(BM_RegisterBuiltins); InterpreterOptions OptionsForParam(BenchmarkParam param, google::protobuf::Arena& arena) { InterpreterOptions options; - switch (param) { case BenchmarkParam::kFoldConstants: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: options.constant_arena = &arena; options.constant_folding = true; break; case BenchmarkParam::kDefault: + case BenchmarkParam::kRecursivePlanning: options.constant_folding = false; break; } + switch (param) { + case BenchmarkParam::kRecursivePlanning: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + options.enable_recursive_planning_fail_over = true; + options.max_recursion_depth = 48; + break; + case BenchmarkParam::kDefault: + case BenchmarkParam::kFoldConstants: + options.max_recursion_depth = 0; + break; + } return options; } void BM_SymbolicPolicy(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && @@ -105,7 +134,9 @@ void BM_SymbolicPolicy(benchmark::State& state) { BENCHMARK(BM_SymbolicPolicy) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); absl::StatusOr> MakeBuilderForEnums( absl::string_view container, absl::string_view enum_type, @@ -209,6 +240,7 @@ BENCHMARK(BM_EnumResolution256Candidate)->ThreadRange(1, 32); void BM_NestedComprehension(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( [4, 5, 6].all(x, [1, 2, 3].all(y, x > y) && [7, 8, 9].all(z, x < z)) @@ -231,10 +263,13 @@ void BM_NestedComprehension(benchmark::State& state) { BENCHMARK(BM_NestedComprehension) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_Comparisons(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( v11 < v12 && v12 < v13 @@ -260,7 +295,9 @@ void BM_Comparisons(benchmark::State& state) { BENCHMARK(BM_Comparisons) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_ComparisonsConcurrent(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( @@ -290,6 +327,8 @@ BENCHMARK(BM_ComparisonsConcurrent)->ThreadRange(1, 32); void RegexPrecompilationBench(bool enabled, benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(absl::StrCat(LabelForParam(param), "_", + enabled ? "enabled" : "disabled")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( input_str.matches(r'192\.168\.' + '[0-9]{1,3}' + r'\.' + '[0-9]{1,3}') || @@ -325,7 +364,9 @@ void BM_RegexPrecompilationDisabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationDisabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_RegexPrecompilationEnabled(benchmark::State& state) { RegexPrecompilationBench(true, state); @@ -333,10 +374,13 @@ void BM_RegexPrecompilationEnabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationEnabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_StringConcat(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); auto size = state.range(1); std::string source = "'1234567890' + '1234567890'"; @@ -377,7 +421,17 @@ BENCHMARK(BM_StringConcat) ->Args({BenchmarkParam::kFoldConstants, 4}) ->Args({BenchmarkParam::kFoldConstants, 8}) ->Args({BenchmarkParam::kFoldConstants, 16}) - ->Args({BenchmarkParam::kFoldConstants, 32}); + ->Args({BenchmarkParam::kFoldConstants, 32}) + ->Args({BenchmarkParam::kRecursivePlanning, 2}) + ->Args({BenchmarkParam::kRecursivePlanning, 4}) + ->Args({BenchmarkParam::kRecursivePlanning, 8}) + ->Args({BenchmarkParam::kRecursivePlanning, 16}) + ->Args({BenchmarkParam::kRecursivePlanning, 32}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 2}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 4}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 8}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 16}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 32}); void BM_StringConcat32Concurrent(benchmark::State& state) { std::string source = "'1234567890' + '1234567890'"; diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 1e18fef95..bf489f24e 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -152,6 +152,13 @@ struct RuntimeOptions { // -1 means unbounded. int max_recursion_depth = 0; + // If true, the planner will switch to the heap-based stack machine for any + // program that exceeds the `max_recursion_depth`. + // + // If false, the planner will return an error if the maximum recursion depth + // is exceeded. + bool enable_recursive_planning_fail_over = false; + // Enable tracing support for recursively planned programs. // // Unlike the stack machine implementation, supporting tracing can affect