Skip to content

Commit ac1ef3e

Browse files
committed
Merge branch 'expression-oversimplification' into misleading-expansions
2 parents 0d5b026 + 77a3fad commit ac1ef3e

File tree

2 files changed

+193
-22
lines changed

2 files changed

+193
-22
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import liquidjava.rj_language.ast.BinaryExpression;
44
import liquidjava.rj_language.ast.Expression;
55
import liquidjava.rj_language.ast.LiteralBoolean;
6+
import liquidjava.rj_language.ast.UnaryExpression;
67
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
78
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
9+
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
810
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
911

1012
public class ExpressionSimplifier {
@@ -15,7 +17,8 @@ public class ExpressionSimplifier {
1517
*/
1618
public static ValDerivationNode simplify(Expression exp) {
1719
ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp);
18-
return simplifyValDerivationNode(fixedPoint);
20+
ValDerivationNode simplified = simplifyValDerivationNode(fixedPoint);
21+
return unwrapDerivedBooleans(simplified);
1922
}
2023

2124
/**
@@ -123,4 +126,61 @@ private static boolean isRedundant(Expression exp) {
123126
}
124127
return false;
125128
}
129+
130+
/**
131+
* Recursively traverses the derivation tree and replaces boolean literals with the expressions that produced them,
132+
* but only when at least one operand in the derivation is non-boolean. e.g. "x == true" where true came from "1 >
133+
* 0" becomes "x == 1 > 0"
134+
*/
135+
static ValDerivationNode unwrapDerivedBooleans(ValDerivationNode node) {
136+
Expression value = node.getValue();
137+
DerivationNode origin = node.getOrigin();
138+
139+
if (origin == null)
140+
return node;
141+
142+
// unwrap binary expressions
143+
if (value instanceof BinaryExpression binExp && origin instanceof BinaryDerivationNode binOrigin) {
144+
ValDerivationNode left = unwrapDerivedBooleans(binOrigin.getLeft());
145+
ValDerivationNode right = unwrapDerivedBooleans(binOrigin.getRight());
146+
if (left != binOrigin.getLeft() || right != binOrigin.getRight()) {
147+
Expression newValue = new BinaryExpression(left.getValue(), binExp.getOperator(), right.getValue());
148+
return new ValDerivationNode(newValue, new BinaryDerivationNode(left, right, binOrigin.getOp()));
149+
}
150+
return node;
151+
}
152+
153+
// unwrap unary expressions
154+
if (value instanceof UnaryExpression unaryExp && origin instanceof UnaryDerivationNode unaryOrigin) {
155+
ValDerivationNode operand = unwrapDerivedBooleans(unaryOrigin.getOperand());
156+
if (operand != unaryOrigin.getOperand()) {
157+
Expression newValue = new UnaryExpression(unaryExp.getOp(), operand.getValue());
158+
return new ValDerivationNode(newValue, new UnaryDerivationNode(operand, unaryOrigin.getOp()));
159+
}
160+
return node;
161+
}
162+
163+
// boolean literal with binary origin: unwrap if at least one child is non-boolean
164+
if (value instanceof LiteralBoolean && origin instanceof BinaryDerivationNode binOrigin) {
165+
ValDerivationNode left = unwrapDerivedBooleans(binOrigin.getLeft());
166+
ValDerivationNode right = unwrapDerivedBooleans(binOrigin.getRight());
167+
if (!(left.getValue() instanceof LiteralBoolean) || !(right.getValue() instanceof LiteralBoolean)) {
168+
Expression newValue = new BinaryExpression(left.getValue(), binOrigin.getOp(), right.getValue());
169+
return new ValDerivationNode(newValue, new BinaryDerivationNode(left, right, binOrigin.getOp()));
170+
}
171+
return node;
172+
}
173+
174+
// boolean literal with unary origin: unwrap if operand is non-boolean
175+
if (value instanceof LiteralBoolean && origin instanceof UnaryDerivationNode unaryOrigin) {
176+
ValDerivationNode operand = unwrapDerivedBooleans(unaryOrigin.getOperand());
177+
if (!(operand.getValue() instanceof LiteralBoolean)) {
178+
Expression newValue = new UnaryExpression(unaryOrigin.getOp(), operand.getValue());
179+
return new ValDerivationNode(newValue, new UnaryDerivationNode(operand, unaryOrigin.getOp()));
180+
}
181+
return node;
182+
}
183+
184+
return node;
185+
}
126186
}

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 132 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,13 @@ void testComplexArithmeticWithMultipleOperations() {
239239
// When
240240
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
241241

242-
// Then
242+
// Then: boolean literals are unwrapped to show the verified conditions
243243
assertNotNull(result, "Result should not be null");
244244
assertNotNull(result.getValue(), "Result value should not be null");
245-
assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should be a boolean literal");
246-
assertTrue(result.getValue().isBooleanTrue(), "Expected result to be true");
245+
assertEquals("14 == 14 && 5 == 5 && 7 == 7 && 14 == 14", result.getValue().toString(),
246+
"All verified conditions should be visible instead of collapsed to true");
247247

248-
// 5 * 2 + 7 - 3
248+
// 5 * 2 + 7 - 3 = 14
249249
ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
250250
ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null);
251251
BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*");
@@ -262,39 +262,45 @@ void testComplexArithmeticWithMultipleOperations() {
262262
// 14 from variable c
263263
ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
264264

265-
// 14 == 14
265+
// 14 == 14 (unwrapped from true)
266266
BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "==");
267-
ValDerivationNode trueFromComparison = new ValDerivationNode(new LiteralBoolean(true), compare14);
267+
Expression expr14Eq14 = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14));
268+
ValDerivationNode compare14Node = new ValDerivationNode(expr14Eq14, compare14);
268269

269-
// a == 5 => true
270+
// a == 5 => 5 == 5 (unwrapped from true)
270271
ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
271272
ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null);
272273
BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "==");
273-
ValDerivationNode trueFromA = new ValDerivationNode(new LiteralBoolean(true), compareA5);
274+
Expression expr5Eq5 = new BinaryExpression(new LiteralInt(5), "==", new LiteralInt(5));
275+
ValDerivationNode compare5Node = new ValDerivationNode(expr5Eq5, compareA5);
274276

275-
// b == 7 => true
277+
// b == 7 => 7 == 7 (unwrapped from true)
276278
ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b"));
277279
ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null);
278280
BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "==");
279-
ValDerivationNode trueFromB = new ValDerivationNode(new LiteralBoolean(true), compareB7);
281+
Expression expr7Eq7 = new BinaryExpression(new LiteralInt(7), "==", new LiteralInt(7));
282+
ValDerivationNode compare7Node = new ValDerivationNode(expr7Eq7, compareB7);
280283

281-
// (a == 5) && (b == 7) => true
282-
BinaryDerivationNode andAB = new BinaryDerivationNode(trueFromA, trueFromB, "&&");
283-
ValDerivationNode trueFromAB = new ValDerivationNode(new LiteralBoolean(true), andAB);
284+
// (5 == 5) && (7 == 7) (unwrapped from true)
285+
BinaryDerivationNode andAB = new BinaryDerivationNode(compare5Node, compare7Node, "&&");
286+
Expression expr5And7 = new BinaryExpression(expr5Eq5, "&&", expr7Eq7);
287+
ValDerivationNode and5And7Node = new ValDerivationNode(expr5And7, andAB);
284288

285-
// c == 14 => true
289+
// c == 14 => 14 == 14 (unwrapped from true)
286290
ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
287291
ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null);
288292
BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "==");
289-
ValDerivationNode trueFromC = new ValDerivationNode(new LiteralBoolean(true), compareC14);
293+
Expression expr14Eq14b = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14));
294+
ValDerivationNode compare14bNode = new ValDerivationNode(expr14Eq14b, compareC14);
290295

291-
// ((a == 5) && (b == 7)) && (c == 14) => true
292-
BinaryDerivationNode andABC = new BinaryDerivationNode(trueFromAB, trueFromC, "&&");
293-
ValDerivationNode trueFromAllConditions = new ValDerivationNode(new LiteralBoolean(true), andABC);
296+
// ((5 == 5) && (7 == 7)) && (14 == 14) (unwrapped from true)
297+
BinaryDerivationNode andABC = new BinaryDerivationNode(and5And7Node, compare14bNode, "&&");
298+
Expression exprConditions = new BinaryExpression(expr5And7, "&&", expr14Eq14b);
299+
ValDerivationNode conditionsNode = new ValDerivationNode(exprConditions, andABC);
294300

295-
// 14 == 14 => true
296-
BinaryDerivationNode finalAnd = new BinaryDerivationNode(trueFromComparison, trueFromAllConditions, "&&");
297-
ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(true), finalAnd);
301+
// (14 == 14) && ((5 == 5 && 7 == 7) && 14 == 14)
302+
BinaryDerivationNode finalAnd = new BinaryDerivationNode(compare14Node, conditionsNode, "&&");
303+
ValDerivationNode expected = new ValDerivationNode(result.getValue(), finalAnd);
298304

299305
// Compare the derivation trees
300306
assertDerivationEquals(expected, result, "");
@@ -599,6 +605,111 @@ void testNoSimplificationHasNoOrigin() {
599605
assertNull(result.getOrigin(), "No origin should be present when nothing was simplified");
600606
}
601607

608+
@Test
609+
void testShouldUnwrapBooleanInEquality() {
610+
// Given: x == (1 > 0)
611+
// Without unwrapping: x == true (unhelpful - hides what "true" came from)
612+
// Expected: x == 1 > 0 (unwrapped to show the original comparison)
613+
614+
Expression varX = new Var("x");
615+
Expression one = new LiteralInt(1);
616+
Expression zero = new LiteralInt(0);
617+
Expression oneGreaterZero = new BinaryExpression(one, ">", zero);
618+
Expression fullExpression = new BinaryExpression(varX, "==", oneGreaterZero);
619+
620+
// When
621+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
622+
623+
// Then
624+
assertNotNull(result, "Result should not be null");
625+
assertEquals("x == 1 > 0", result.getValue().toString(),
626+
"Boolean in equality should be unwrapped to show the original comparison");
627+
}
628+
629+
@Test
630+
void testShouldUnwrapBooleanInEqualityWithPropagation() {
631+
// Given: x == (a > b) && a == 3 && b == 1
632+
// Without unwrapping: x == true (unhelpful)
633+
// Expected: x == 3 > 1 (unwrapped and propagated)
634+
635+
Expression varX = new Var("x");
636+
Expression varA = new Var("a");
637+
Expression varB = new Var("b");
638+
Expression aGreaterB = new BinaryExpression(varA, ">", varB);
639+
Expression xEqualsComp = new BinaryExpression(varX, "==", aGreaterB);
640+
641+
Expression three = new LiteralInt(3);
642+
Expression aEquals3 = new BinaryExpression(varA, "==", three);
643+
Expression one = new LiteralInt(1);
644+
Expression bEquals1 = new BinaryExpression(varB, "==", one);
645+
646+
Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals1);
647+
Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions);
648+
649+
// When
650+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
651+
652+
// Then
653+
assertNotNull(result, "Result should not be null");
654+
assertEquals("x == 3 > 1", result.getValue().toString(),
655+
"Boolean in equality should be unwrapped after propagation");
656+
}
657+
658+
@Test
659+
void testShouldNotUnwrapBooleanWithBooleanChildren() {
660+
// Given: (y || true) && !true && y == false
661+
// Expected: false (both children of the fold are boolean, so no unwrapping needed)
662+
663+
Expression varY = new Var("y");
664+
Expression trueExp = new LiteralBoolean(true);
665+
Expression yOrTrue = new BinaryExpression(varY, "||", trueExp);
666+
Expression notTrue = new UnaryExpression("!", trueExp);
667+
Expression falseExp = new LiteralBoolean(false);
668+
Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp);
669+
670+
Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue);
671+
Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse);
672+
673+
// When
674+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
675+
676+
// Then: false stays as false since both sides in the derivation are booleans
677+
assertNotNull(result, "Result should not be null");
678+
assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should remain a boolean");
679+
assertFalse(result.getValue().isBooleanTrue(), "Expected result to be false");
680+
}
681+
682+
@Test
683+
void testShouldUnwrapNestedBooleanInEquality() {
684+
// Given: x == (a + b > 10) && a == 3 && b == 5
685+
// Without unwrapping: x == true (unhelpful)
686+
// Expected: x == 8 > 10 (shows the actual comparison that produced the boolean)
687+
688+
Expression varX = new Var("x");
689+
Expression varA = new Var("a");
690+
Expression varB = new Var("b");
691+
Expression aPlusB = new BinaryExpression(varA, "+", varB);
692+
Expression ten = new LiteralInt(10);
693+
Expression comparison = new BinaryExpression(aPlusB, ">", ten);
694+
Expression xEqualsComp = new BinaryExpression(varX, "==", comparison);
695+
696+
Expression three = new LiteralInt(3);
697+
Expression aEquals3 = new BinaryExpression(varA, "==", three);
698+
Expression five = new LiteralInt(5);
699+
Expression bEquals5 = new BinaryExpression(varB, "==", five);
700+
701+
Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5);
702+
Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions);
703+
704+
// When
705+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
706+
707+
// Then
708+
assertNotNull(result, "Result should not be null");
709+
assertEquals("x == 8 > 10", result.getValue().toString(),
710+
"Boolean in equality should be unwrapped to show the computed comparison");
711+
}
712+
602713
/**
603714
* Helper method to compare two derivation nodes recursively
604715
*/

0 commit comments

Comments
 (0)