Skip to content

Commit 2f87494

Browse files
committed
Fix const fold
1 parent cf7db21 commit 2f87494

3 files changed

Lines changed: 77 additions & 20 deletions

File tree

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import static com.google.common.base.Preconditions.checkNotNull;
1717
import static com.google.common.collect.ImmutableList.toImmutableList;
18-
import static com.google.common.collect.MoreCollectors.onlyElement;
1918
import static dev.cel.checker.CelStandardDeclarations.StandardFunction.DURATION;
2019
import static dev.cel.checker.CelStandardDeclarations.StandardFunction.TIMESTAMP;
2120

@@ -183,9 +182,9 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
183182
if (functionName.equals(Operator.EQUALS.getFunction())
184183
|| functionName.equals(Operator.NOT_EQUALS.getFunction())) {
185184
if (mutableCall.args().stream()
186-
.anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE))
185+
.anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE))
187186
|| mutableCall.args().stream()
188-
.allMatch(node -> node.getKind().equals(Kind.CONSTANT))) {
187+
.allMatch(node -> node.getKind().equals(Kind.CONSTANT))) {
189188
return true;
190189
}
191190
}
@@ -196,12 +195,27 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
196195

197196
// Default case: all call arguments must be constants. If the argument is a container (ex:
198197
// list, map), then its arguments must be a constant.
199-
return areChildrenArgConstant(navigableExpr);
198+
199+
CelMutableExpr target = mutableCall.target().orElse(null);
200+
if (target != null && !isConstantExpr(target)) {
201+
return false;
202+
}
203+
return mutableCall.args().stream().allMatch(this::isConstantExpr);
200204
case SELECT:
201-
CelNavigableMutableExpr operand = navigableExpr.children().collect(onlyElement());
202-
return areChildrenArgConstant(operand);
205+
return isConstantExpr(navigableExpr.expr().select().operand());
203206
case COMPREHENSION:
204-
return !isNestedComprehension(navigableExpr);
207+
if (isNestedComprehension(navigableExpr)) {
208+
return false;
209+
}
210+
CelMutableComprehension comprehension = navigableExpr.expr().comprehension();
211+
212+
if (!isConstantExpr(comprehension.iterRange())
213+
|| !isConstantExpr(comprehension.accuInit())) {
214+
return false;
215+
}
216+
217+
return isFoldableComprehension(comprehension.loopStep())
218+
&& isFoldableComprehension(comprehension.loopCondition());
205219
default:
206220
return false;
207221
}
@@ -248,22 +262,50 @@ private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr)
248262
return true;
249263
}
250264

251-
private static boolean areChildrenArgConstant(CelNavigableMutableExpr expr) {
252-
if (expr.getKind().equals(Kind.CONSTANT)) {
253-
return true;
254-
}
265+
private boolean isConstantExpr(CelMutableExpr expr) {
266+
switch (expr.getKind()) {
267+
case CONSTANT:
268+
return true;
269+
case CALL:
270+
CelMutableCall call = expr.call();
271+
if (foldableFunctions.contains(call.function())) {
272+
CelMutableExpr target = call.target().orElse(null);
273+
if (target != null && !isConstantExpr(target)) {
274+
return false;
275+
}
255276

256-
if (expr.getKind().equals(Kind.CALL)
257-
|| expr.getKind().equals(Kind.LIST)
258-
|| expr.getKind().equals(Kind.MAP)
259-
|| expr.getKind().equals(Kind.SELECT)
260-
|| expr.getKind().equals(Kind.STRUCT)) {
261-
return expr.children().allMatch(ConstantFoldingOptimizer::areChildrenArgConstant);
277+
return call.args().stream().allMatch(this::isConstantExpr);
278+
}
279+
return false;
280+
case LIST:
281+
return expr.list().elements().stream().allMatch(this::isConstantExpr);
282+
case MAP:
283+
return expr.map().entries().stream()
284+
.allMatch(
285+
e ->
286+
isConstantExpr(e.key())
287+
&& isConstantExpr(e.value()));
288+
case STRUCT:
289+
return expr.struct().entries().stream()
290+
.allMatch(e -> isConstantExpr(e.value()));
291+
case SELECT:
292+
return isConstantExpr(expr.select().operand());
293+
default:
294+
return false;
262295
}
296+
}
263297

264-
return false;
298+
private boolean isFoldableComprehension(CelMutableExpr expr) {
299+
return CelNavigableMutableExpr.fromExpr(expr)
300+
.allNodes()
301+
.filter(node -> node.getKind().equals(Kind.CALL))
302+
.map(node -> node.expr().call())
303+
.allMatch(
304+
call ->
305+
foldableFunctions.contains(call.function()));
265306
}
266307

308+
267309
private static boolean isNestedComprehension(CelNavigableMutableExpr expr) {
268310
Optional<CelNavigableMutableExpr> maybeParent = expr.parent();
269311
while (maybeParent.isPresent()) {

optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
public class ConstantFoldingOptimizerTest {
4949
private static final CelOptions CEL_OPTIONS =
5050
CelOptions.current()
51+
.populateMacroCalls(true)
5152
.enableTimestampEpoch(true)
5253
.build();
5354
private static final Cel CEL =
@@ -56,12 +57,23 @@ public class ConstantFoldingOptimizerTest {
5657
.addVar("y", SimpleType.DYN)
5758
.addVar("list_var", ListType.create(SimpleType.STRING))
5859
.addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING))
60+
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
5961
.addFunctionDeclarations(
6062
CelFunctionDecl.newFunctionDeclaration(
6163
"get_true",
62-
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
64+
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)),
65+
CelFunctionDecl.newFunctionDeclaration(
66+
"get_list",
67+
CelOverloadDecl.newGlobalOverload(
68+
"get_list_overload",
69+
ListType.create(SimpleType.INT),
70+
ListType.create(SimpleType.INT)))
71+
)
6372
.addFunctionBindings(
64-
CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true))
73+
CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true),
74+
CelFunctionBinding.from(
75+
"get_list_overload", ImmutableList.class, arg -> arg)
76+
)
6577
.addMessageTypes(TestAllTypes.getDescriptor())
6678
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
6779
.setOptions(CEL_OPTIONS)
@@ -371,6 +383,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E
371383
@TestParameters("{source: 'x == 42'}")
372384
@TestParameters("{source: 'timestamp(100)'}")
373385
@TestParameters("{source: 'duration(\"1h\")'}")
386+
@TestParameters("{source: '[true].exists(x, x == get_true())'}")
387+
@TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}")
374388
public void constantFold_noOp(String source) throws Exception {
375389
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
376390

tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ private void runTests(Cel cel, CelAbstractSyntaxTree ast, PolicyTestSuite testSu
283283
? (List<AgentMessage>) inputMap.get("_test_history")
284284
: ImmutableList.of();
285285

286+
@SuppressWarnings("Immutable")
286287
CelLateFunctionBindings bindings = CelLateFunctionBindings.from(
287288
CelFunctionBinding.from(
288289
"agent_history",

0 commit comments

Comments
 (0)