Skip to content

Commit d88a482

Browse files
kyessenovCEL Dev Team
andauthored
Export of internal changes (#55)
-- 306225043 by kuat <kuat@google.com>: BEGIN_PUBLIC Export cel/cpp/tools. END_PUBLIC I accidentally forgot to include tools/ in the last OSS export. Updating copybara to include tools/flatbuffers_backed_impl targets and excluding decompiler and conversion tools. OSS bazel is more strict than blaze so patched the build definitions to make them work. -- 305962063 by CEL Dev Team <cel-dev@google.com>: BEGIN_PUBLIC Move CelValueToValue() function from Explainer to utility class END_PUBLIC -- 305731098 by CEL Dev Team <cel-dev@google.com>: BEGIN_PUBLIC Add btree set implementation for unknown function result sets. Improves performance for large numbers of unknown function results. END_PUBLIC PiperOrigin-RevId: 306225043 Co-authored-by: CEL Dev Team <cel-dev@google.com>
1 parent 53928a8 commit d88a482

9 files changed

Lines changed: 1090 additions & 112 deletions

eval/public/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,8 @@ cc_library(
573573
":cel_function",
574574
":cel_options",
575575
":cel_value",
576+
"//eval/eval:set_util",
577+
"@com_google_absl//absl/container:btree",
576578
"@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto",
577579
],
578580
)

eval/public/unknown_function_result_set.cc

Lines changed: 45 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <type_traits>
44

5+
#include "absl/container/btree_set.h"
6+
#include "eval/eval/set_util.h"
57
#include "eval/public/cel_function.h"
68
#include "eval/public/cel_options.h"
79
#include "eval/public/cel_value.h"
@@ -12,147 +14,83 @@ namespace expr {
1214
namespace runtime {
1315
namespace {
1416

15-
// Forward declare.
16-
bool CelValueEqual(const CelValue lhs, const CelValue rhs);
17-
18-
// Default to operator==
19-
template <typename T>
20-
bool CelValueEqualImpl(T lhs, T rhs) {
21-
return lhs == rhs;
22-
}
23-
24-
// List equality specialization. Test that the lists are in-order elementwise
25-
// equal.
26-
template <>
27-
bool CelValueEqualImpl<const CelList*>(const CelList* lhs, const CelList* rhs) {
28-
if (lhs->size() != rhs->size()) {
29-
return false;
17+
// Tests that lhs descriptor is less than (name, receiver call style,
18+
// arg types).
19+
// Argument type Any is not treated specially. For example:
20+
// {"f", false, {kAny}} > {"f", false, {kInt64}}
21+
bool DescriptorLessThan(const CelFunctionDescriptor& lhs,
22+
const CelFunctionDescriptor& rhs) {
23+
if (lhs.name() < rhs.name()) {
24+
return true;
3025
}
31-
for (int i = 0; i < rhs->size(); i++) {
32-
if (!CelValueEqual(lhs->operator[](i), rhs->operator[](i))) {
33-
return false;
34-
}
26+
if (lhs.name() > rhs.name()) {
27+
return false;
3528
}
36-
return true;
37-
}
3829

39-
// Map equality specialization. Compare that two maps have exactly the same
40-
// key/value pairs.
41-
template <>
42-
bool CelValueEqualImpl<const CelMap*>(const CelMap* lhs, const CelMap* rhs) {
43-
if (lhs->size() != rhs->size()) {
30+
if (lhs.receiver_style() < rhs.receiver_style()) {
31+
return true;
32+
}
33+
if (lhs.receiver_style() > rhs.receiver_style()) {
4434
return false;
4535
}
46-
const CelList* key_set = rhs->ListKeys();
47-
for (int i = 0; i < key_set->size(); i++) {
48-
CelValue key = key_set->operator[](i);
49-
CelValue rhs_value = rhs->operator[](key).value();
50-
auto maybe_lhs_value = lhs->operator[](key);
51-
if (!maybe_lhs_value.has_value()) {
52-
return false;
53-
}
54-
if (!CelValueEqual(maybe_lhs_value.value(), rhs_value)) {
55-
return false;
56-
}
36+
37+
if (lhs.types() >= rhs.types()) {
38+
return false;
5739
}
40+
5841
return true;
5942
}
6043

61-
// Visitor for implementing comparing the underlying value that two CelValues
62-
// are wrapping. The visitor unwraps the lhs then tries to get the rhs
63-
// underlying value if it is the same type as the lhs.
64-
struct LhsCompareVisitor {
65-
CelValue rhs;
66-
67-
LhsCompareVisitor(CelValue rhs) : rhs(rhs) {}
68-
69-
template <typename T>
70-
bool operator()(T lhs_value) {
71-
T rhs_value;
72-
bool is_same_type = rhs.GetValue(&rhs_value);
73-
if (!is_same_type) {
74-
return false;
75-
}
76-
return CelValueEqualImpl<T>(lhs_value, rhs_value);
44+
bool UnknownFunctionResultLessThan(const UnknownFunctionResult& lhs,
45+
const UnknownFunctionResult& rhs) {
46+
if (DescriptorLessThan(lhs.descriptor(), rhs.descriptor())) {
47+
return true;
7748
}
78-
};
79-
80-
// This is a slightly different implementation than provided for the cel
81-
// evaluator. Differences are:
82-
//
83-
// - this implementation doesn't need to support error forwarding in the same
84-
// way -- this should only be used for situations when we can invoke the
85-
// function. i.e. the function must specify that it consumes errors and/or
86-
// unknown sets for them to appear in the arg list.
87-
// - this implementation defines equality between messages based on ptr identity
88-
bool CelValueEqual(const CelValue lhs, const CelValue rhs) {
89-
if (lhs.type() != rhs.type()) {
49+
if (DescriptorLessThan(rhs.descriptor(), lhs.descriptor())) {
9050
return false;
9151
}
92-
return lhs.Visit<bool>(LhsCompareVisitor(rhs));
93-
}
9452

95-
// Tests that two descriptors are equal (name, receiver call style, arg types).
96-
//
97-
// Argument type Any is not treated specially. For example:
98-
// {"f", false, {kAny}} != {"f", false, {kInt64}}
99-
bool DescriptorEqual(const CelFunctionDescriptor& lhs,
100-
const CelFunctionDescriptor& rhs) {
101-
if (lhs.name() != rhs.name()) {
102-
return false;
53+
if (lhs.arguments().size() < rhs.arguments().size()) {
54+
return true;
10355
}
10456

105-
if (lhs.receiver_style() != rhs.receiver_style()) {
57+
if (lhs.arguments().size() > rhs.arguments().size()) {
10658
return false;
10759
}
10860

109-
if (lhs.types() != rhs.types()) {
110-
return false;
61+
for (size_t i = 0; i < lhs.arguments().size(); i++) {
62+
if (CelValueLessThan(lhs.arguments()[i], rhs.arguments()[i])) {
63+
return true;
64+
}
65+
if (CelValueLessThan(rhs.arguments()[i], lhs.arguments()[i])) {
66+
return false;
67+
}
11168
}
11269

113-
return true;
70+
// equal
71+
return false;
11472
}
11573

11674
} // namespace
11775

76+
bool UnknownFunctionComparator::operator()(
77+
const UnknownFunctionResult* lhs, const UnknownFunctionResult* rhs) const {
78+
return UnknownFunctionResultLessThan(*lhs, *rhs);
79+
}
80+
11881
bool UnknownFunctionResult::IsEqualTo(
11982
const UnknownFunctionResult& other) const {
120-
if (!DescriptorEqual(descriptor_, other.descriptor())) {
121-
return false;
122-
}
123-
124-
if (arguments_.size() != other.arguments().size()) {
125-
return false;
126-
}
127-
128-
for (size_t i = 0; i < arguments_.size(); i++) {
129-
if (!CelValueEqual(arguments_[i], other.arguments()[i])) {
130-
return false;
131-
}
132-
}
133-
134-
return true;
83+
return !(UnknownFunctionResultLessThan(*this, other) ||
84+
UnknownFunctionResultLessThan(other, *this));
13585
}
13686

13787
// Implementation for merge constructor.
13888
UnknownFunctionResultSet::UnknownFunctionResultSet(
13989
const UnknownFunctionResultSet& lhs, const UnknownFunctionResultSet& rhs)
14090
: unknown_function_results_(lhs.unknown_function_results()) {
141-
unknown_function_results_.reserve(lhs.unknown_function_results().size() +
142-
rhs.unknown_function_results().size());
14391
for (const UnknownFunctionResult* call : rhs.unknown_function_results()) {
144-
Add(call);
145-
}
146-
}
147-
148-
void UnknownFunctionResultSet::Add(const UnknownFunctionResult* result) {
149-
for (const UnknownFunctionResult* existing_result :
150-
unknown_function_results()) {
151-
if (result->IsEqualTo(*existing_result)) {
152-
return;
153-
}
92+
unknown_function_results_.insert(call);
15493
}
155-
unknown_function_results_.push_back(result);
15694
}
15795

15896
} // namespace runtime

eval/public/unknown_function_result_set.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <vector>
55

66
#include "google/api/expr/v1alpha1/syntax.pb.h"
7+
#include "absl/container/btree_map.h"
8+
#include "absl/container/btree_set.h"
79
#include "eval/public/cel_function.h"
810

911
namespace google {
@@ -30,7 +32,8 @@ class UnknownFunctionResult {
3032
// The arguments of the function call that generated the unknown.
3133
const std::vector<CelValue>& arguments() const { return arguments_; }
3234

33-
// Equality operator provided for set semantics.
35+
// Equality operator provided for testing. Compatible with set less-than
36+
// comparator.
3437
// Compares descriptor then arguments elementwise.
3538
bool IsEqualTo(const UnknownFunctionResult& other) const;
3639

@@ -40,6 +43,12 @@ class UnknownFunctionResult {
4043
std::vector<CelValue> arguments_;
4144
};
4245

46+
// Comparator for set semantics.
47+
struct UnknownFunctionComparator {
48+
bool operator()(const UnknownFunctionResult*,
49+
const UnknownFunctionResult*) const;
50+
};
51+
4352
// Represents a collection of unknown function results at a particular point in
4453
// execution. Execution should advance further if this set of unknowns are
4554
// provided. It may not advance if only a subset are provided.
@@ -57,14 +66,15 @@ class UnknownFunctionResultSet {
5766
UnknownFunctionResultSet(const UnknownFunctionResult* initial)
5867
: unknown_function_results_{initial} {}
5968

60-
const std::vector<const UnknownFunctionResult*>& unknown_function_results()
61-
const {
69+
using Container =
70+
absl::btree_set<const UnknownFunctionResult*, UnknownFunctionComparator>;
71+
72+
const Container& unknown_function_results() const {
6273
return unknown_function_results_;
6374
}
6475

6576
private:
66-
std::vector<const UnknownFunctionResult*> unknown_function_results_;
67-
void Add(const UnknownFunctionResult* result);
77+
Container unknown_function_results_;
6878
};
6979

7080
} // namespace runtime

eval/public/unknown_function_result_set_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ CelFunctionDescriptor kTwoInt("TwoInt", false,
3535

3636
CelFunctionDescriptor kOneInt("OneInt", false, {CelValue::Type::kInt64});
3737

38+
// Helper to confirm the set comparator works.
39+
bool IsLessThan(const UnknownFunctionResult& lhs,
40+
const UnknownFunctionResult& rhs) {
41+
return UnknownFunctionComparator()(&lhs, &rhs);
42+
}
43+
3844
TEST(UnknownFunctionResult, ArgumentCapture) {
3945
UnknownFunctionResult call1(
4046
kTwoInt, /*expr_id=*/0,
@@ -54,6 +60,8 @@ TEST(UnknownFunctionResult, Equals) {
5460
{CelValue::CreateInt64(1), CelValue::CreateInt64(2)});
5561

5662
EXPECT_TRUE(call1.IsEqualTo(call2));
63+
EXPECT_FALSE(IsLessThan(call1, call2));
64+
EXPECT_FALSE(IsLessThan(call2, call1));
5765

5866
UnknownFunctionResult call3(kOneInt, /*expr_id=*/0,
5967
{CelValue::CreateInt64(1)});
@@ -73,6 +81,7 @@ TEST(UnknownFunctionResult, InequalDescriptor) {
7381
{CelValue::CreateInt64(1)});
7482

7583
EXPECT_FALSE(call1.IsEqualTo(call2));
84+
EXPECT_TRUE(IsLessThan(call2, call1));
7685

7786
CelFunctionDescriptor one_uint("OneInt", false, {CelValue::Type::kUint64});
7887

@@ -83,6 +92,7 @@ TEST(UnknownFunctionResult, InequalDescriptor) {
8392
{CelValue::CreateUint64(1)});
8493

8594
EXPECT_FALSE(call3.IsEqualTo(call4));
95+
EXPECT_TRUE(IsLessThan(call3, call4));
8696
}
8797

8898
TEST(UnknownFunctionResult, InequalArgs) {
@@ -95,6 +105,7 @@ TEST(UnknownFunctionResult, InequalArgs) {
95105
{CelValue::CreateInt64(1), CelValue::CreateInt64(3)});
96106

97107
EXPECT_FALSE(call1.IsEqualTo(call2));
108+
EXPECT_TRUE(IsLessThan(call1, call2));
98109

99110
UnknownFunctionResult call3(
100111
kTwoInt, /*expr_id=*/0,
@@ -104,6 +115,7 @@ TEST(UnknownFunctionResult, InequalArgs) {
104115
{CelValue::CreateInt64(1)});
105116

106117
EXPECT_FALSE(call3.IsEqualTo(call4));
118+
EXPECT_TRUE(IsLessThan(call4, call3));
107119
}
108120

109121
TEST(UnknownFunctionResult, ListsEqual) {
@@ -143,6 +155,7 @@ TEST(UnknownFunctionResult, ListsDifferentSizes) {
143155

144156
// [1, 2] == [1, 2, 3]
145157
EXPECT_FALSE(call1.IsEqualTo(call2));
158+
EXPECT_TRUE(IsLessThan(call1, call2));
146159
}
147160

148161
TEST(UnknownFunctionResult, ListsDifferentMembers) {
@@ -161,6 +174,7 @@ TEST(UnknownFunctionResult, ListsDifferentMembers) {
161174

162175
// [1, 2] == [2, 2]
163176
EXPECT_FALSE(call1.IsEqualTo(call2));
177+
EXPECT_TRUE(IsLessThan(call1, call2));
164178
}
165179

166180
TEST(UnknownFunctionResult, MapsEqual) {
@@ -205,6 +219,7 @@ TEST(UnknownFunctionResult, MapsDifferentSizes) {
205219

206220
// {1: 2, 2: 4} == {1: 2, 2: 4, 3: 6}
207221
EXPECT_FALSE(call1.IsEqualTo(call2));
222+
EXPECT_TRUE(IsLessThan(call1, call2));
208223
}
209224

210225
TEST(UnknownFunctionResult, MapsDifferentElements) {
@@ -240,8 +255,10 @@ TEST(UnknownFunctionResult, MapsDifferentElements) {
240255

241256
// {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 4: 8}
242257
EXPECT_FALSE(call1.IsEqualTo(call2));
258+
EXPECT_TRUE(IsLessThan(call1, call2));
243259
// {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 3: 5}
244260
EXPECT_FALSE(call1.IsEqualTo(call3));
261+
EXPECT_TRUE(IsLessThan(call3, call1));
245262
}
246263

247264
TEST(UnknownFunctionResult, Messages) {
@@ -279,7 +296,9 @@ TEST(UnknownFunctionResult, AnyDescriptor) {
279296
{CelValue::CreateUint64(2)});
280297

281298
EXPECT_FALSE(callAnyInt1.IsEqualTo(callInt));
299+
EXPECT_TRUE(IsLessThan(callAnyInt1, callInt));
282300
EXPECT_FALSE(callAnyInt1.IsEqualTo(callAnyUint));
301+
EXPECT_TRUE(IsLessThan(callAnyInt1, callAnyUint));
283302
EXPECT_TRUE(callAnyInt1.IsEqualTo(callAnyInt2));
284303
}
285304

0 commit comments

Comments
 (0)