Skip to content

Commit 7c965de

Browse files
committed
Expand Aliases
1 parent f8d3f3e commit 7c965de

File tree

5 files changed

+187
-10
lines changed

5 files changed

+187
-10
lines changed

liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ public void processSubtyping(Predicate expectedType, List<GhostState> list, CtEl
5858
}
5959
SMTResult result = verifySMTSubtype(expected, premises, element.getPosition());
6060
if (result.isError()) {
61-
throw new RefinementError(element.getPosition(), expectedType.simplify(), premisesBeforeChange.simplify(),
62-
map, result.getCounterexample(), customMessage);
61+
throw new RefinementError(element.getPosition(), expectedType.simplify(context),
62+
premisesBeforeChange.simplify(context), map, result.getCounterexample(), customMessage);
6363
}
6464
}
6565

@@ -277,7 +277,7 @@ protected void throwRefinementError(SourcePosition position, Predicate expected,
277277
gatherVariables(found, lrv, mainVars);
278278
TranslationTable map = new TranslationTable();
279279
Predicate premises = joinPredicates(expected, mainVars, lrv, map).toConjunctions();
280-
throw new RefinementError(position, expected.simplify(), premises.simplify(), map, counterexample,
280+
throw new RefinementError(position, expected.simplify(context), premises.simplify(context), map, counterexample,
281281
customMessage);
282282
}
283283

@@ -288,7 +288,7 @@ protected void throwStateRefinementError(SourcePosition position, Predicate foun
288288
TranslationTable map = new TranslationTable();
289289
VCImplication foundState = joinPredicates(found, mainVars, lrv, map);
290290
throw new StateRefinementError(position, expected.getExpression(),
291-
foundState.toConjunctions().simplify().getValue(), map, customMessage);
291+
foundState.toConjunctions().simplify(context).getValue(), map, customMessage);
292292
}
293293

294294
protected void throwStateConflictError(SourcePosition position, Predicate expected) throws StateConflictError {

liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,14 @@ public Expression getExpression() {
187187
return exp;
188188
}
189189

190-
public ValDerivationNode simplify() {
191-
return ExpressionSimplifier.simplify(exp.clone());
190+
public ValDerivationNode simplify(Context context) {
191+
// collect aliases from context
192+
Map<String, AliasDTO> aliases = new HashMap<>();
193+
for (AliasWrapper aw : context.getAliases()) {
194+
aliases.put(aw.getName(), aw.createAliasDTO());
195+
}
196+
// simplify expression
197+
return ExpressionSimplifier.simplify(exp.clone(), aliases);
192198
}
193199

194200
private static boolean isBooleanLiteral(Expression expr, boolean value) {
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package liquidjava.rj_language.opt;
2+
3+
import java.util.Map;
4+
5+
import liquidjava.processor.facade.AliasDTO;
6+
import liquidjava.rj_language.ast.AliasInvocation;
7+
import liquidjava.rj_language.ast.BinaryExpression;
8+
import liquidjava.rj_language.ast.Expression;
9+
import liquidjava.rj_language.ast.GroupExpression;
10+
import liquidjava.rj_language.ast.UnaryExpression;
11+
import liquidjava.rj_language.ast.Var;
12+
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
13+
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
14+
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
15+
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
16+
17+
public class AliasExpansion {
18+
19+
/**
20+
* Expands alias invocations in a derivation node to their definitions, storing the expanded body as the origin of
21+
* each alias invocation node.
22+
*/
23+
public static ValDerivationNode expand(ValDerivationNode node, Map<String, AliasDTO> aliases) {
24+
return expandRecursive(node, aliases);
25+
}
26+
27+
private static ValDerivationNode expandRecursive(ValDerivationNode node, Map<String, AliasDTO> aliases) {
28+
Expression exp = node.getValue();
29+
30+
// expand alias invocation
31+
if (exp instanceof AliasInvocation ai) {
32+
return expandAlias(ai, aliases);
33+
}
34+
35+
// recurse into binary expressions
36+
if (exp instanceof BinaryExpression binary) {
37+
ValDerivationNode leftNode;
38+
ValDerivationNode rightNode;
39+
if (node.getOrigin()instanceof BinaryDerivationNode binOrigin) {
40+
leftNode = expandRecursive(binOrigin.getLeft(), aliases);
41+
rightNode = expandRecursive(binOrigin.getRight(), aliases);
42+
} else {
43+
leftNode = expandRecursive(new ValDerivationNode(binary.getFirstOperand(), null), aliases);
44+
rightNode = expandRecursive(new ValDerivationNode(binary.getSecondOperand(), null), aliases);
45+
}
46+
boolean hasExpansion = leftNode.getOrigin() != null || rightNode.getOrigin() != null;
47+
DerivationNode origin = hasExpansion ? new BinaryDerivationNode(leftNode, rightNode, binary.getOperator())
48+
: node.getOrigin();
49+
return new ValDerivationNode(exp, origin);
50+
}
51+
52+
// recurse into unary expressions
53+
if (exp instanceof UnaryExpression unary) {
54+
ValDerivationNode operandNode;
55+
if (node.getOrigin()instanceof UnaryDerivationNode unaryOrigin) {
56+
operandNode = expandRecursive(unaryOrigin.getOperand(), aliases);
57+
} else {
58+
operandNode = expandRecursive(new ValDerivationNode(unary.getChildren().get(0), null), aliases);
59+
}
60+
DerivationNode origin = operandNode.getOrigin() != null
61+
? new UnaryDerivationNode(operandNode, unary.getOp()) : node.getOrigin();
62+
return new ValDerivationNode(exp, origin);
63+
}
64+
65+
// recurse into group expressions
66+
if (exp instanceof GroupExpression group && group.getChildren().size() == 1) {
67+
return expandRecursive(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()), aliases);
68+
}
69+
70+
return node;
71+
}
72+
73+
private static ValDerivationNode expandAlias(AliasInvocation ai, Map<String, AliasDTO> aliases) {
74+
AliasDTO dto = aliases.get(ai.getName());
75+
76+
// no alias found
77+
if (dto == null || dto.getExpression() == null) {
78+
return new ValDerivationNode(ai, null);
79+
}
80+
81+
// substitute parameters in the body with the invocation arguments
82+
Expression body = dto.getExpression().clone();
83+
for (int i = 0; i < ai.getArgs().size() && i < dto.getVarNames().size(); i++) {
84+
body = body.substitute(new Var(dto.getVarNames().get(i)), ai.getArgs().get(i));
85+
}
86+
87+
// recursively expand the body
88+
ValDerivationNode expandedBody = expandRecursive(new ValDerivationNode(body, null), aliases);
89+
return new ValDerivationNode(ai, expandedBody);
90+
}
91+
}

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package liquidjava.rj_language.opt;
22

3+
import java.util.Map;
4+
5+
import liquidjava.processor.facade.AliasDTO;
36
import liquidjava.rj_language.ast.BinaryExpression;
47
import liquidjava.rj_language.ast.Expression;
58
import liquidjava.rj_language.ast.LiteralBoolean;
@@ -10,12 +13,18 @@
1013
public class ExpressionSimplifier {
1114

1215
/**
13-
* Simplifies an expression by applying constant propagation, constant folding and removing redundant conjuncts
14-
* Returns a derivation node representing the tree of simplifications applied
16+
* Simplifies an expression by applying constant propagation, constant folding, removing redundant conjuncts and
17+
* expanding aliases Returns a derivation node representing the tree of simplifications applied
1518
*/
19+
public static ValDerivationNode simplify(Expression exp, Map<String, AliasDTO> aliases) {
20+
ValDerivationNode node = new ValDerivationNode(exp, null);
21+
ValDerivationNode fixedPoint = simplifyToFixedPoint(node, exp);
22+
ValDerivationNode simplified = simplifyValDerivationNode(fixedPoint);
23+
return AliasExpansion.expand(simplified, aliases);
24+
}
25+
1626
public static ValDerivationNode simplify(Expression exp) {
17-
ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp);
18-
return simplifyValDerivationNode(fixedPoint);
27+
return simplify(exp, Map.of());
1928
}
2029

2130
/**

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
import static org.junit.jupiter.api.Assertions.*;
44

5+
import java.util.List;
6+
import java.util.Map;
7+
8+
import liquidjava.processor.facade.AliasDTO;
9+
import liquidjava.rj_language.ast.AliasInvocation;
510
import liquidjava.rj_language.ast.BinaryExpression;
611
import liquidjava.rj_language.ast.Expression;
712
import liquidjava.rj_language.ast.LiteralBoolean;
@@ -550,6 +555,72 @@ void testTransitive() {
550555
assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1");
551556
}
552557

558+
@Test
559+
void testByteAliasExpansion() {
560+
// Given: Byte(b) with alias Byte(int b) { b >= -128 && b <= 127 }
561+
AliasDTO byteAlias = new AliasDTO("Byte", List.of("int"), List.of("b"), "b >= -128 && b <= 127");
562+
byteAlias.parse("");
563+
Map<String, AliasDTO> aliases = Map.of("Byte", byteAlias);
564+
Expression exp = new AliasInvocation("Byte", List.of(new Var("b")));
565+
566+
// When
567+
ValDerivationNode result = ExpressionSimplifier.simplify(exp, aliases);
568+
569+
// Then
570+
assertEquals("Byte(b)", result.getValue().toString());
571+
assertNotNull(result.getOrigin(), "Origin should contain the expanded body");
572+
ValDerivationNode origin = (ValDerivationNode) result.getOrigin();
573+
assertEquals("b >= -128 && b <= 127", origin.getValue().toString());
574+
}
575+
576+
@Test
577+
void testPositiveAliasExpansion() {
578+
// Given: Positive(x) with alias Positive(int v) { v > 0 }
579+
AliasDTO positiveAlias = new AliasDTO("Positive", List.of("int"), List.of("v"), "v > 0");
580+
positiveAlias.parse("");
581+
Map<String, AliasDTO> aliases = Map.of("Positive", positiveAlias);
582+
Expression exp = new AliasInvocation("Positive", List.of(new Var("x")));
583+
584+
// When
585+
ValDerivationNode result = ExpressionSimplifier.simplify(exp, aliases);
586+
587+
// Then
588+
assertEquals("Positive(x)", result.getValue().toString());
589+
assertNotNull(result.getOrigin(), "Origin should contain the expanded body");
590+
ValDerivationNode origin = (ValDerivationNode) result.getOrigin();
591+
assertEquals("x > 0", origin.getValue().toString());
592+
}
593+
594+
@Test
595+
void testTwoArgAliasWithNormalExpression() {
596+
// Given: Bounded(v, 100) && v > 50 with alias Bounded(int x, int n) { x > 0 && x < n }
597+
AliasDTO boundedAlias = new AliasDTO("Bounded", List.of("int", "int"), List.of("x", "n"), "x > 0 && x < n");
598+
boundedAlias.parse("");
599+
Map<String, AliasDTO> aliases = Map.of("Bounded", boundedAlias);
600+
601+
Expression varV = new Var("v");
602+
Expression bounded = new AliasInvocation("Bounded", List.of(varV, new LiteralInt(100)));
603+
Expression vGt50 = new BinaryExpression(varV, ">", new LiteralInt(50));
604+
Expression fullExpression = new BinaryExpression(bounded, "&&", vGt50);
605+
606+
// When
607+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression, aliases);
608+
609+
// Then
610+
assertEquals("Bounded(v, 100) && v > 50", result.getValue().toString());
611+
assertInstanceOf(BinaryDerivationNode.class, result.getOrigin());
612+
BinaryDerivationNode binOrigin = (BinaryDerivationNode) result.getOrigin();
613+
assertEquals("&&", binOrigin.getOp());
614+
ValDerivationNode leftNode = binOrigin.getLeft();
615+
assertEquals("Bounded(v, 100)", leftNode.getValue().toString());
616+
assertNotNull(leftNode.getOrigin(), "Alias invocation should have expanded body as origin");
617+
ValDerivationNode expandedBody = (ValDerivationNode) leftNode.getOrigin();
618+
assertEquals("v > 0 && v < 100", expandedBody.getValue().toString());
619+
ValDerivationNode rightNode = binOrigin.getRight();
620+
assertEquals("v > 50", rightNode.getValue().toString());
621+
assertNull(rightNode.getOrigin());
622+
}
623+
553624
/**
554625
* Helper method to compare two derivation nodes recursively
555626
*/

0 commit comments

Comments
 (0)