Skip to content

Commit 2f06d90

Browse files
jnthntatumcopybara-github
authored andcommitted
Add checker support for block.
This is needed for re-checking expressions that were produced as a part of policy compilation. PiperOrigin-RevId: 910179322
1 parent 1cf21ee commit 2f06d90

11 files changed

Lines changed: 446 additions & 123 deletions

File tree

checker/internal/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ cc_library(
155155
"@com_google_absl//absl/cleanup",
156156
"@com_google_absl//absl/container:flat_hash_map",
157157
"@com_google_absl//absl/container:flat_hash_set",
158+
"@com_google_absl//absl/log:absl_check",
158159
"@com_google_absl//absl/log:absl_log",
159160
"@com_google_absl//absl/status",
160161
"@com_google_absl//absl/status:statusor",
@@ -179,6 +180,7 @@ cc_test(
179180
"//checker:type_checker_builder",
180181
"//checker:validation_result",
181182
"//common:ast",
183+
"//common:ast_proto",
182184
"//common:container",
183185
"//common:decl",
184186
"//common:expr",
@@ -187,13 +189,17 @@ cc_test(
187189
"//internal:status_macros",
188190
"//internal:testing",
189191
"//internal:testing_descriptor_pool",
192+
"//parser",
193+
"//parser:macro_registry",
190194
"//testutil:baseline_tests",
195+
"//testutil:test_macros",
191196
"@com_google_absl//absl/base:no_destructor",
192197
"@com_google_absl//absl/base:nullability",
193198
"@com_google_absl//absl/container:flat_hash_set",
194199
"@com_google_absl//absl/log:absl_check",
195200
"@com_google_absl//absl/status",
196201
"@com_google_absl//absl/status:status_matchers",
202+
"@com_google_absl//absl/status:statusor",
197203
"@com_google_absl//absl/strings",
198204
"@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto",
199205
"@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",

checker/internal/type_checker_impl.cc

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "absl/base/nullability.h"
2626
#include "absl/container/flat_hash_map.h"
2727
#include "absl/container/flat_hash_set.h"
28+
#include "absl/log/absl_check.h"
2829
#include "absl/status/status.h"
2930
#include "absl/status/statusor.h"
3031
#include "absl/strings/match.h"
@@ -59,6 +60,15 @@
5960
namespace cel::checker_internal {
6061
namespace {
6162

63+
bool MatchesBlock(const Expr& expr) {
64+
if (!expr.has_call_expr()) {
65+
return false;
66+
}
67+
const auto& call = expr.call_expr();
68+
return call.function() == "cel.@block" && call.args().size() == 2 &&
69+
call.args()[0].has_list_expr();
70+
}
71+
6272
using AstType = cel::TypeSpec;
6373
using Severity = TypeCheckIssue::Severity;
6474

@@ -204,13 +214,23 @@ class ResolveVisitor : public AstVisitorBase {
204214
arena_(arena),
205215
current_scope_(&root_scope_) {}
206216

207-
void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); }
217+
void PreVisitExpr(const Expr& expr) override {
218+
expr_stack_.push_back(&expr);
219+
if (expr_stack_.size() == 1 && MatchesBlock(expr)) {
220+
ABSL_DCHECK_EQ(expr.call_expr().args().size(), 2);
221+
ABSL_DCHECK(block_init_list_ == nullptr);
222+
block_init_list_ = &expr.call_expr().args()[0];
223+
}
224+
}
208225

209226
void PostVisitExpr(const Expr& expr) override {
210227
if (expr_stack_.empty()) {
211228
return;
212229
}
213230
expr_stack_.pop_back();
231+
if (expr_stack_.size() == 2 && expr_stack_.back() == block_init_list_) {
232+
HandleBlockIndex(&expr);
233+
}
214234
}
215235

216236
void PostVisitConst(const Expr& expr, const Constant& constant) override;
@@ -389,6 +409,7 @@ class ResolveVisitor : public AstVisitorBase {
389409
absl::string_view field_name);
390410

391411
void HandleOptSelect(const Expr& expr);
412+
void HandleBlockIndex(const Expr* expr);
392413

393414
// Get the assigned type of the given subexpression. Should only be called if
394415
// the given subexpression is expected to have already been checked.
@@ -421,6 +442,7 @@ class ResolveVisitor : public AstVisitorBase {
421442
std::vector<const Expr*> expr_stack_;
422443
absl::flat_hash_map<const Expr*, std::vector<std::string>>
423444
maybe_namespaced_functions_;
445+
const Expr* block_init_list_ = nullptr;
424446
// Select operations that need to be resolved outside of the traversal.
425447
// These are handled separately to disambiguate between namespaces and field
426448
// accesses
@@ -609,8 +631,15 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) {
609631
}
610632

611633
void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) {
612-
// Follows list type inferencing behavior in Go (see map comments above).
634+
if (&expr == block_init_list_) {
635+
// Don't try to coalesce list type here because it can influence the
636+
// resolved type of the list elements. cel.@block is always list<dyn> and
637+
// the elements are treated independently at runtime.
638+
types_[&expr] = ListType();
639+
return;
640+
}
613641

642+
// Follows list type inferencing behavior in Go (see map comments above).
614643
Type overall_elem_type =
615644
inference_context_->InstantiateTypeParams(TypeParamType("E"));
616645
auto assignability_context = inference_context_->CreateAssignabilityContext();
@@ -1172,6 +1201,44 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) {
11721201
}
11731202
}
11741203

1204+
void ResolveVisitor::HandleBlockIndex(const Expr* expr) {
1205+
ABSL_DCHECK(block_init_list_ != nullptr);
1206+
ABSL_DCHECK(block_init_list_->has_list_expr());
1207+
const auto& elements = block_init_list_->list_expr().elements();
1208+
int index = -1;
1209+
for (size_t i = 0; i < elements.size(); ++i) {
1210+
if (&elements[i].expr() == expr) {
1211+
index = i;
1212+
break;
1213+
}
1214+
}
1215+
if (index < 0) {
1216+
status_.Update(absl::InternalError(
1217+
"could not resolve expression as a cel.@block subexpression"));
1218+
return;
1219+
}
1220+
std::string var_name = absl::StrCat("@index", index);
1221+
1222+
// Block is typically manually assembled from logically separate
1223+
// expressions so fix the type instead of inferring any remaining free type
1224+
// params as for normal subexpressions.
1225+
auto type = inference_context_->FinalizeType(GetDeducedType(expr));
1226+
1227+
VariableDecl decl = MakeVariableDecl(var_name, std::move(type));
1228+
1229+
// The C++ runtime requires that the indexes are topologically ordered.
1230+
// They just come into scope in order as we walk the AST so we don't need
1231+
// to do any additional work to check references to other initializers in
1232+
// an init expr.
1233+
//
1234+
// TODO(uncreated-issue/90): This is slightly inconsistent with the java
1235+
// runtime implementation which just requires the references to be acyclic.
1236+
auto* scope =
1237+
comprehension_vars_.emplace_back(current_scope_->MakeNestedScope()).get();
1238+
scope->InsertVariableIfAbsent(std::move(decl));
1239+
current_scope_ = scope;
1240+
}
1241+
11751242
class ResolveRewriter : public AstRewriterBase {
11761243
public:
11771244
explicit ResolveRewriter(const ResolveVisitor& visitor,
@@ -1230,15 +1297,15 @@ class ResolveRewriter : public AstRewriterBase {
12301297

12311298
if (auto iter = visitor_.types().find(&expr);
12321299
iter != visitor_.types().end()) {
1233-
auto flattened_type =
1234-
FlattenType(inference_context_.FinalizeType(iter->second));
1300+
cel::Type finalized_type = inference_context_.FinalizeType(iter->second);
1301+
auto flattened_type = FlattenType(finalized_type);
12351302

12361303
if (!flattened_type.ok()) {
12371304
status_.Update(flattened_type.status());
12381305
return rewritten;
12391306
}
12401307
type_map_[expr.id()] = *std::move(flattened_type);
1241-
resolved_types_[expr.id()] = iter->second;
1308+
resolved_types_[expr.id()] = finalized_type;
12421309
rewritten = true;
12431310
}
12441311

checker/internal/type_checker_impl_test.cc

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "absl/log/absl_check.h"
2727
#include "absl/status/status.h"
2828
#include "absl/status/status_matchers.h"
29+
#include "absl/status/statusor.h"
2930
#include "absl/strings/match.h"
3031
#include "absl/strings/str_cat.h"
3132
#include "absl/strings/str_join.h"
@@ -36,6 +37,7 @@
3637
#include "checker/type_checker_builder.h"
3738
#include "checker/validation_result.h"
3839
#include "common/ast.h"
40+
#include "common/ast_proto.h"
3941
#include "common/container.h"
4042
#include "common/decl.h"
4143
#include "common/expr.h"
@@ -45,7 +47,10 @@
4547
#include "internal/status_macros.h"
4648
#include "internal/testing.h"
4749
#include "internal/testing_descriptor_pool.h"
50+
#include "parser/macro_registry.h"
51+
#include "parser/parser.h"
4852
#include "testutil/baseline_tests.h"
53+
#include "testutil/test_macros.h"
4954
#include "cel/expr/conformance/proto2/test_all_types.pb.h"
5055
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
5156
#include "google/protobuf/arena.h"
@@ -108,6 +113,17 @@ google::protobuf::Arena* absl_nonnull TestTypeArena() {
108113
return &(*kArena);
109114
}
110115

116+
absl::StatusOr<std::unique_ptr<Ast>> MakeTestParsedAstWithMacros(
117+
absl::string_view expression, const cel::MacroRegistry& registry) {
118+
CEL_ASSIGN_OR_RETURN(
119+
auto source,
120+
cel::NewSource(expression, /*description=*/std::string(expression)));
121+
CEL_ASSIGN_OR_RETURN(auto parsed_expr, google::api::expr::parser::Parse(
122+
*source, registry,
123+
{.enable_optional_syntax = true}));
124+
return cel::CreateAstFromParsedExpr(parsed_expr);
125+
}
126+
111127
FunctionDecl MakeIdentFunction() {
112128
auto decl = MakeFunctionDecl(
113129
"identity",
@@ -272,6 +288,12 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena
272288
/*return_type=*/TypeType(arena, TypeParamType("A")),
273289
TypeParamType("A"))));
274290

291+
Type kParam(TypeParamType("T"));
292+
CEL_ASSIGN_OR_RETURN(
293+
auto block_decl,
294+
MakeFunctionDecl("cel.@block", MakeOverloadDecl("cel_block_list", kParam,
295+
ListType(), kParam)));
296+
275297
env.InsertFunctionIfAbsent(std::move(not_op));
276298
env.InsertFunctionIfAbsent(std::move(not_strictly_false));
277299
env.InsertFunctionIfAbsent(std::move(add_op));
@@ -289,6 +311,7 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena
289311
env.InsertFunctionIfAbsent(std::move(to_type));
290312
env.InsertFunctionIfAbsent(std::move(to_duration));
291313
env.InsertFunctionIfAbsent(std::move(to_timestamp));
314+
env.InsertFunctionIfAbsent(std::move(block_decl));
292315

293316
return absl::OkStatus();
294317
}
@@ -308,6 +331,78 @@ TEST(TypeCheckerImplTest, SmokeTest) {
308331
EXPECT_THAT(result.GetIssues(), IsEmpty());
309332
}
310333

334+
TEST(TypeCheckerImplTest, BlockMacroSupport) {
335+
TypeCheckEnv env(GetSharedTestingDescriptorPool());
336+
337+
google::protobuf::Arena arena;
338+
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
339+
340+
MacroRegistry registry;
341+
ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk());
342+
343+
TypeCheckerImpl impl(std::move(env));
344+
ASSERT_OK_AND_ASSIGN(
345+
auto ast,
346+
MakeTestParsedAstWithMacros(
347+
"cel.block([1, 2], cel.index(0) + cel.index(1))", registry));
348+
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
349+
350+
EXPECT_TRUE(result.IsValid());
351+
EXPECT_THAT(result.GetIssues(), IsEmpty());
352+
353+
// Overall type should be int.
354+
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
355+
auto root_id = checked_ast->root_expr().id();
356+
EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(),
357+
PrimitiveType::kInt64);
358+
}
359+
360+
TEST(TypeCheckerImplTest, BlockMacroSupportMixedTypes) {
361+
TypeCheckEnv env(GetSharedTestingDescriptorPool());
362+
363+
google::protobuf::Arena arena;
364+
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
365+
366+
MacroRegistry registry;
367+
ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk());
368+
369+
TypeCheckerImpl impl(std::move(env));
370+
ASSERT_OK_AND_ASSIGN(
371+
auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(1))",
372+
registry));
373+
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
374+
375+
EXPECT_TRUE(result.IsValid());
376+
EXPECT_THAT(result.GetIssues(), IsEmpty());
377+
378+
// cel.index(1) refers to 'a' which is string.
379+
// So overall type should be string.
380+
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
381+
auto root_id = checked_ast->root_expr().id();
382+
EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(),
383+
PrimitiveType::kString);
384+
}
385+
386+
TEST(TypeCheckerImplTest, BadIndex) {
387+
TypeCheckEnv env(GetSharedTestingDescriptorPool());
388+
389+
google::protobuf::Arena arena;
390+
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
391+
392+
MacroRegistry registry;
393+
ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk());
394+
395+
TypeCheckerImpl impl(std::move(env));
396+
ASSERT_OK_AND_ASSIGN(
397+
auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(2))",
398+
registry));
399+
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
400+
401+
EXPECT_FALSE(result.IsValid());
402+
EXPECT_THAT(result.FormatError(),
403+
HasSubstr("undeclared reference to '@index2' (in container"));
404+
}
405+
311406
TEST(TypeCheckerImplTest, SimpleIdentsResolved) {
312407
TypeCheckEnv env(GetSharedTestingDescriptorPool());
313408

conformance/BUILD

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ cc_library(
6969
"//runtime:reference_resolver",
7070
"//runtime:runtime_options",
7171
"//runtime:standard_runtime_builder_factory",
72+
"//testutil:test_macros",
7273
"@com_google_absl//absl/log:absl_check",
7374
"@com_google_absl//absl/memory",
7475
"@com_google_absl//absl/status",
@@ -221,7 +222,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [
221222
"proto3/set_null/list_value",
222223
"proto3/set_null/single_struct",
223224

224-
# cel.@block
225+
# no optional support for legacy types
225226
"block_ext/basic/optional_list",
226227
"block_ext/basic/optional_map",
227228
"block_ext/basic/optional_map_chained",
@@ -231,7 +232,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [
231232
_TESTS_TO_SKIP_CHECKED = [
232233
# block is a post-check optimization that inserts internal variables. The C++ type checker
233234
# needs support for a proper optimizer for this to work.
234-
"block_ext",
235+
# "block_ext",
235236
]
236237

237238
_TESTS_TO_SKIP_LEGACY_DASHBOARD = [

0 commit comments

Comments
 (0)