22
33#include < type_traits>
44
5+ #include " absl/container/btree_set.h"
6+ #include " eval/eval/set_util.h"
57#include " eval/public/cel_function.h"
68#include " eval/public/cel_options.h"
79#include " eval/public/cel_value.h"
@@ -12,147 +14,83 @@ namespace expr {
1214namespace runtime {
1315namespace {
1416
15- // Forward declare.
16- bool CelValueEqual (const CelValue lhs, const CelValue rhs);
17-
18- // Default to operator==
19- template <typename T>
20- bool CelValueEqualImpl (T lhs, T rhs) {
21- return lhs == rhs;
22- }
23-
24- // List equality specialization. Test that the lists are in-order elementwise
25- // equal.
26- template <>
27- bool CelValueEqualImpl<const CelList*>(const CelList* lhs, const CelList* rhs) {
28- if (lhs->size () != rhs->size ()) {
29- return false ;
17+ // Tests that lhs descriptor is less than (name, receiver call style,
18+ // arg types).
19+ // Argument type Any is not treated specially. For example:
20+ // {"f", false, {kAny}} > {"f", false, {kInt64}}
21+ bool DescriptorLessThan (const CelFunctionDescriptor& lhs,
22+ const CelFunctionDescriptor& rhs) {
23+ if (lhs.name () < rhs.name ()) {
24+ return true ;
3025 }
31- for (int i = 0 ; i < rhs->size (); i++) {
32- if (!CelValueEqual (lhs->operator [](i), rhs->operator [](i))) {
33- return false ;
34- }
26+ if (lhs.name () > rhs.name ()) {
27+ return false ;
3528 }
36- return true ;
37- }
3829
39- // Map equality specialization. Compare that two maps have exactly the same
40- // key/value pairs.
41- template <>
42- bool CelValueEqualImpl<const CelMap*>(const CelMap* lhs, const CelMap* rhs) {
43- if (lhs->size () != rhs->size ()) {
30+ if (lhs.receiver_style () < rhs.receiver_style ()) {
31+ return true ;
32+ }
33+ if (lhs.receiver_style () > rhs.receiver_style ()) {
4434 return false ;
4535 }
46- const CelList* key_set = rhs->ListKeys ();
47- for (int i = 0 ; i < key_set->size (); i++) {
48- CelValue key = key_set->operator [](i);
49- CelValue rhs_value = rhs->operator [](key).value ();
50- auto maybe_lhs_value = lhs->operator [](key);
51- if (!maybe_lhs_value.has_value ()) {
52- return false ;
53- }
54- if (!CelValueEqual (maybe_lhs_value.value (), rhs_value)) {
55- return false ;
56- }
36+
37+ if (lhs.types () >= rhs.types ()) {
38+ return false ;
5739 }
40+
5841 return true ;
5942}
6043
61- // Visitor for implementing comparing the underlying value that two CelValues
62- // are wrapping. The visitor unwraps the lhs then tries to get the rhs
63- // underlying value if it is the same type as the lhs.
64- struct LhsCompareVisitor {
65- CelValue rhs;
66-
67- LhsCompareVisitor (CelValue rhs) : rhs(rhs) {}
68-
69- template <typename T>
70- bool operator ()(T lhs_value) {
71- T rhs_value;
72- bool is_same_type = rhs.GetValue (&rhs_value);
73- if (!is_same_type) {
74- return false ;
75- }
76- return CelValueEqualImpl<T>(lhs_value, rhs_value);
44+ bool UnknownFunctionResultLessThan (const UnknownFunctionResult& lhs,
45+ const UnknownFunctionResult& rhs) {
46+ if (DescriptorLessThan (lhs.descriptor (), rhs.descriptor ())) {
47+ return true ;
7748 }
78- };
79-
80- // This is a slightly different implementation than provided for the cel
81- // evaluator. Differences are:
82- //
83- // - this implementation doesn't need to support error forwarding in the same
84- // way -- this should only be used for situations when we can invoke the
85- // function. i.e. the function must specify that it consumes errors and/or
86- // unknown sets for them to appear in the arg list.
87- // - this implementation defines equality between messages based on ptr identity
88- bool CelValueEqual (const CelValue lhs, const CelValue rhs) {
89- if (lhs.type () != rhs.type ()) {
49+ if (DescriptorLessThan (rhs.descriptor (), lhs.descriptor ())) {
9050 return false ;
9151 }
92- return lhs.Visit <bool >(LhsCompareVisitor (rhs));
93- }
9452
95- // Tests that two descriptors are equal (name, receiver call style, arg types).
96- //
97- // Argument type Any is not treated specially. For example:
98- // {"f", false, {kAny}} != {"f", false, {kInt64}}
99- bool DescriptorEqual (const CelFunctionDescriptor& lhs,
100- const CelFunctionDescriptor& rhs) {
101- if (lhs.name () != rhs.name ()) {
102- return false ;
53+ if (lhs.arguments ().size () < rhs.arguments ().size ()) {
54+ return true ;
10355 }
10456
105- if (lhs.receiver_style () != rhs.receiver_style ()) {
57+ if (lhs.arguments (). size () > rhs.arguments (). size ()) {
10658 return false ;
10759 }
10860
109- if (lhs.types () != rhs.types ()) {
110- return false ;
61+ for (size_t i = 0 ; i < lhs.arguments ().size (); i++) {
62+ if (CelValueLessThan (lhs.arguments ()[i], rhs.arguments ()[i])) {
63+ return true ;
64+ }
65+ if (CelValueLessThan (rhs.arguments ()[i], lhs.arguments ()[i])) {
66+ return false ;
67+ }
11168 }
11269
113- return true ;
70+ // equal
71+ return false ;
11472}
11573
11674} // namespace
11775
76+ bool UnknownFunctionComparator::operator ()(
77+ const UnknownFunctionResult* lhs, const UnknownFunctionResult* rhs) const {
78+ return UnknownFunctionResultLessThan (*lhs, *rhs);
79+ }
80+
11881bool UnknownFunctionResult::IsEqualTo (
11982 const UnknownFunctionResult& other) const {
120- if (!DescriptorEqual (descriptor_, other.descriptor ())) {
121- return false ;
122- }
123-
124- if (arguments_.size () != other.arguments ().size ()) {
125- return false ;
126- }
127-
128- for (size_t i = 0 ; i < arguments_.size (); i++) {
129- if (!CelValueEqual (arguments_[i], other.arguments ()[i])) {
130- return false ;
131- }
132- }
133-
134- return true ;
83+ return !(UnknownFunctionResultLessThan (*this , other) ||
84+ UnknownFunctionResultLessThan (other, *this ));
13585}
13686
13787// Implementation for merge constructor.
13888UnknownFunctionResultSet::UnknownFunctionResultSet (
13989 const UnknownFunctionResultSet& lhs, const UnknownFunctionResultSet& rhs)
14090 : unknown_function_results_(lhs.unknown_function_results()) {
141- unknown_function_results_.reserve (lhs.unknown_function_results ().size () +
142- rhs.unknown_function_results ().size ());
14391 for (const UnknownFunctionResult* call : rhs.unknown_function_results ()) {
144- Add (call);
145- }
146- }
147-
148- void UnknownFunctionResultSet::Add (const UnknownFunctionResult* result) {
149- for (const UnknownFunctionResult* existing_result :
150- unknown_function_results ()) {
151- if (result->IsEqualTo (*existing_result)) {
152- return ;
153- }
92+ unknown_function_results_.insert (call);
15493 }
155- unknown_function_results_.push_back (result);
15694}
15795
15896} // namespace runtime
0 commit comments