diff --git a/source_.jar b/source_.jar index 8869b751..f9adaff9 100644 Binary files a/source_.jar and b/source_.jar differ diff --git a/src/main/cod/demo/src/bin/test.array/__StaticModule__.codb b/src/main/cod/demo/src/bin/test.array/__StaticModule__.codb index b7fe5913..a69efd2b 100644 Binary files a/src/main/cod/demo/src/bin/test.array/__StaticModule__.codb and b/src/main/cod/demo/src/bin/test.array/__StaticModule__.codb differ diff --git a/src/main/cod/demo/src/bin/test.broadcast/Broadcaster.codb b/src/main/cod/demo/src/bin/test.broadcast/Broadcaster.codb new file mode 100644 index 00000000..4104b1da Binary files /dev/null and b/src/main/cod/demo/src/bin/test.broadcast/Broadcaster.codb differ diff --git a/src/main/cod/demo/src/bin/test.lambda/LambdaSuite.codb b/src/main/cod/demo/src/bin/test.lambda/LambdaSuite.codb new file mode 100644 index 00000000..16c60bd2 Binary files /dev/null and b/src/main/cod/demo/src/bin/test.lambda/LambdaSuite.codb differ diff --git a/src/main/cod/demo/src/main/test/broadcast/Broadcast.cod b/src/main/cod/demo/src/main/test/broadcast/Broadcast.cod index ac70fc3b..e6586c97 100644 --- a/src/main/cod/demo/src/main/test/broadcast/Broadcast.cod +++ b/src/main/cod/demo/src/main/test/broadcast/Broadcast.cod @@ -40,6 +40,24 @@ Broadcaster { out("reduced = " + reduced + " (expect 10)") out("reducedNamed = " + reducedNamed + " (expect 10)") out("reduceResult = " + reduceResult + " (expect 10)") + + scanned := [4, 5, 6].scan() + out("scanned[1] = " + scanned[1] + " (expect 5)") + + zipped := [1, 2, 3].zip([10, 20, 30]) + out("zipped[0][0] = " + zipped[0][0] + " (expect 1)") + out("zipped[0][1] = " + zipped[0][1] + " (expect 10)") + out("zipped[2][0] = " + zipped[2][0] + " (expect 3)") + out("zipped[2][1] = " + zipped[2][1] + " (expect 30)") + + zippedCombined := [1, 2, 3].zip([10, 20, 30], \(left, right) left + right) + out("zippedCombined[0] = " + zippedCombined[0] + " (expect 11)") + out("zippedCombined[2] = " + zippedCombined[2] + " (expect 33)") + + emptyArr := [] + nonEmptyArr := [1] + out("emptyArr.isEmpty() = " + emptyArr.isEmpty() + " (expect true)") + out("nonEmptyArr.isEmpty() = " + nonEmptyArr.isEmpty() + " (expect false)") } } diff --git a/src/main/cod/demo/src/main/test/multiarray/ConditionalFormulaOptimization.cod b/src/main/cod/demo/src/main/test/multiarray/ConditionalFormulaOptimization.cod new file mode 100644 index 00000000..abb82ef1 --- /dev/null +++ b/src/main/cod/demo/src/main/test/multiarray/ConditionalFormulaOptimization.cod @@ -0,0 +1,68 @@ +unit test.multiarray + +local isUnderThreshold(value: int, threshold: int) :: ok: bool { + ~> (ok: value < threshold) +} + +share main() { + out("=== Conditional formula optimization ===") + + optimized := [0 to 60] + for i of [0 to 60] { + if i % 2 == 0 { + base := i * 3 + optimized[i] = base + 1 + } elif i % 5 == 0 { + shifted := i + 7 + optimized[i] = shifted * 2 + } else { + optimized[i] = i - 1 + } + } + + out("optimized[10]=" + optimized[10] + " expected=31") + out("optimized[25]=" + optimized[25] + " expected=64") + out("optimized[7]=" + optimized[7] + " expected=6") + + fallback := [0 to 40] + for i of [0 to 40] { + if isUnderThreshold(i, 12) { + fallback[i] = i + 100 + } else { + fallback[i] = i - 100 + } + } + + out("fallback[5]=" + fallback[5] + " expected=105") + out("fallback[20]=" + fallback[20] + " expected=-80") + + constantTrue := [0 to 12] + for i of [0 to 12] { + if true { + constantTrue[i] = i + 9 + } else { + constantTrue[i] = i - 9 + } + } + out("constantTrue[3]=" + constantTrue[3] + " expected=12") + + constantFalse := [0 to 12] + for i of [0 to 12] { + if false { + constantFalse[i] = i + 9 + } else { + constantFalse[i] = i - 9 + } + } + out("constantFalse[3]=" + constantFalse[3] + " expected=-6") + + fusedMapped := optimized.map(\(x) x + 3).map(\(x) x * 2) + fusedFiltered := fusedMapped.filter(">=", 40).filter("<=", 150) + out("fusedMapped[0]=" + fusedMapped[0] + " expected=8") + out("fusedMapped[1]=" + fusedMapped[1] + " expected=6") + out("fusedFiltered[0]=" + fusedFiltered[0] + " expected=54") + + zippedFused := optimized.map(\(x) x + 1).zip(optimized.map(\(x) x + 2), \(left, right) left + right) + out("zippedFused[0]=" + zippedFused[0] + " expected=5") + out("zippedFused[1]=" + zippedFused[1] + " expected=3") +} diff --git a/src/main/java/cod/interpreter/registry/LiteralRegistry.java b/src/main/java/cod/interpreter/registry/LiteralRegistry.java index c5d2abb4..87a92848 100644 --- a/src/main/java/cod/interpreter/registry/LiteralRegistry.java +++ b/src/main/java/cod/interpreter/registry/LiteralRegistry.java @@ -110,6 +110,26 @@ public Object handle(Object literal, List arguments, ExecutionContext ct NaturalArray.class, List.class ); + define("scan", + new MethodHandler() { + @Override + public Object handle(Object literal, List arguments, ExecutionContext ctx) { + return handleArrayScan(literal, arguments); + } + }, + NaturalArray.class, List.class + ); + + define("zip", + new MethodHandler() { + @Override + public Object handle(Object literal, List arguments, ExecutionContext ctx) { + return handleArrayZip(literal, arguments, ctx); + } + }, + NaturalArray.class, List.class + ); + define("has", new MethodHandler() { @Override @@ -194,10 +214,10 @@ public Object handle(Object literal, List arguments, ExecutionContext ct new MethodHandler() { @Override public Object handle(Object literal, List arguments, ExecutionContext ctx) { - return handleStringIsEmpty(literal, arguments); + return handleIsEmpty(literal, arguments); } }, - String.class + String.class, NaturalArray.class, List.class ); define("isBlank", @@ -524,6 +544,22 @@ private Object handleStringIsEmpty(Object literal, List arguments) { return target.isEmpty(); } + @SuppressWarnings("unchecked") + private Object handleIsEmpty(Object literal, List arguments) { + requireArgCount("isEmpty", arguments, 0); + if (literal instanceof String) { + return ((String) literal).isEmpty(); + } + NaturalArray naturalArray = asNaturalArray(literal); + if (naturalArray != null) { + return naturalArray.size() == 0L; + } + if (literal instanceof List) { + return ((List) literal).isEmpty(); + } + throw new ProgramError("isEmpty is not supported on " + literal.getClass().getSimpleName()); + } + private Object handleStringIsBlank(Object literal, List arguments) { requireArgCount("isBlank", arguments, 0); String target = requireStringTarget(literal, "isBlank"); @@ -623,6 +659,9 @@ private Object handleArrayMap(Object literal, List arguments, ExecutionC if (arguments == null || arguments.isEmpty()) { throw new ProgramError("map expects a callback or (operator, operand)"); } + final LazyNaturalArrayMapView sourceMapView = literal instanceof LazyNaturalArrayMapView + ? (LazyNaturalArrayMapView) literal + : null; final NaturalArray naturalArray = asNaturalArray(literal); List source = naturalArray == null ? asConcreteList(literal) : null; @@ -630,6 +669,14 @@ private Object handleArrayMap(Object literal, List arguments, ExecutionC final String op = String.valueOf(arguments.get(0)); final Object operand = arguments.get(1); final TypeHandler typeHandler = ctx.getTypeHandler(); + if (sourceMapView != null) { + return sourceMapView.compose(new NaturalArrayMapper() { + @Override + public Object map(long index, Object value) { + return applyOperator(typeHandler, value, op, operand); + } + }); + } if (naturalArray != null) { return new LazyNaturalArrayMapView(naturalArray, new NaturalArrayMapper() { @Override @@ -649,6 +696,14 @@ public Object map(long index, Object value) { throw new ProgramError("map callback mode expects exactly one argument"); } final Object callback = arguments.get(0); + if (sourceMapView != null) { + return sourceMapView.compose(new NaturalArrayMapper() { + @Override + public Object map(long index, Object value) { + return invokeArrayCallback(callback, "map", ctx, value, Integer.valueOf((int) index)); + } + }); + } if (naturalArray != null) { return new LazyNaturalArrayMapView(naturalArray, new NaturalArrayMapper() { @Override @@ -668,6 +723,9 @@ private Object handleArrayFilter(Object literal, List arguments, Executi if (arguments == null || arguments.isEmpty()) { throw new ProgramError("filter expects a callback or (operator, operand)"); } + final LazyNaturalArrayFilterView sourceFilterView = literal instanceof LazyNaturalArrayFilterView + ? (LazyNaturalArrayFilterView) literal + : null; final NaturalArray naturalArray = asNaturalArray(literal); List source = naturalArray == null ? asConcreteList(literal) : null; @@ -675,6 +733,15 @@ private Object handleArrayFilter(Object literal, List arguments, Executi final String op = String.valueOf(arguments.get(0)); final Object operand = arguments.get(1); final TypeHandler typeHandler = ctx.getTypeHandler(); + if (sourceFilterView != null) { + return sourceFilterView.compose(new NaturalArrayPredicate() { + @Override + public boolean test(long index, Object value) { + Object comparison = compareWithOperator(typeHandler, value, op, operand); + return isTruthy(comparison); + } + }); + } if (naturalArray != null) { return new LazyNaturalArrayFilterView(naturalArray, new NaturalArrayPredicate() { @Override @@ -773,6 +840,30 @@ private Object handleArrayReduce(Object literal, List arguments, Executi return accumulator; } + private Object handleArrayScan(Object literal, List arguments) { + requireArgCount("scan", arguments, 0); + NaturalArray naturalArray = asNaturalArray(literal); + if (naturalArray != null) { + return new LazyNaturalArrayMapView(naturalArray, new NaturalArrayMapper() { + @Override + public Object map(long index, Object value) { + return value; + } + }); + } + return asConcreteList(literal); + } + + private Object handleArrayZip(Object literal, List arguments, ExecutionContext ctx) { + if (arguments == null || arguments.size() < 1 || arguments.size() > 2) { + throw new ProgramError("zip expects (list or NaturalArray) or (list or NaturalArray, callback)"); + } + ArrayZipSource left = toZipSource(literal); + ArrayZipSource right = toZipSource(arguments.get(0)); + final Object callback = arguments.size() == 2 ? arguments.get(1) : null; + return new LazyArrayZipView(left, right, callback, ctx); + } + private interface NaturalArrayMapper { Object map(long index, Object value); } @@ -781,6 +872,100 @@ private interface NaturalArrayPredicate { boolean test(long index, Object value); } + private interface ArrayZipSource { + long size(); + Object get(long index); + } + + private ArrayZipSource toZipSource(Object obj) { + NaturalArray natural = asNaturalArray(obj); + if (natural != null) { + return new NaturalArrayZipSource(natural); + } + if (obj instanceof LazyNaturalArrayMapView) { + return new LazyMapZipSource((LazyNaturalArrayMapView) obj); + } + if (obj instanceof LazyNaturalArrayFilterView) { + return new LazyFilterZipSource((LazyNaturalArrayFilterView) obj); + } + if (obj instanceof List) { + return new ListZipSource(asConcreteList(obj)); + } + throw new ProgramError("zip expects list or NaturalArray as source"); + } + + private static final class NaturalArrayZipSource implements ArrayZipSource { + private final NaturalArray source; + + private NaturalArrayZipSource(NaturalArray source) { + this.source = source; + } + + @Override + public long size() { + return source.size(); + } + + @Override + public Object get(long index) { + return source.get(index); + } + } + + private static final class ListZipSource implements ArrayZipSource { + private final List source; + + private ListZipSource(List source) { + this.source = source; + } + + @Override + public long size() { + return source.size(); + } + + @Override + public Object get(long index) { + return source.get((int) index); + } + } + + private static final class LazyMapZipSource implements ArrayZipSource { + private final LazyNaturalArrayMapView source; + + private LazyMapZipSource(LazyNaturalArrayMapView source) { + this.source = source; + } + + @Override + public long size() { + return source.size(); + } + + @Override + public Object get(long index) { + return source.get((int) index); + } + } + + private static final class LazyFilterZipSource implements ArrayZipSource { + private final LazyNaturalArrayFilterView source; + + private LazyFilterZipSource(LazyNaturalArrayFilterView source) { + this.source = source; + } + + @Override + public long size() { + return source.size(); + } + + @Override + public Object get(long index) { + return source.get((int) index); + } + } + private static final class LazyNaturalArrayMapView extends AbstractList { private final NaturalArray source; private final NaturalArrayMapper mapper; @@ -798,6 +983,16 @@ private LazyNaturalArrayMapView(NaturalArray source, NaturalArrayMapper mapper) this.size = (int) sourceSize; } + private LazyNaturalArrayMapView compose(final NaturalArrayMapper nextMapper) { + return new LazyNaturalArrayMapView(source, new NaturalArrayMapper() { + @Override + public Object map(long index, Object value) { + Object current = mapper.map(index, value); + return nextMapper.map(index, current); + } + }); + } + @Override public Object get(int index) { if (index < 0 || index >= size) { @@ -813,6 +1008,47 @@ public int size() { } } + private final class LazyArrayZipView extends AbstractList { + private final ArrayZipSource left; + private final ArrayZipSource right; + private final Object callback; + private final ExecutionContext ctx; + private final int size; + + private LazyArrayZipView(ArrayZipSource left, ArrayZipSource right, Object callback, ExecutionContext ctx) { + this.left = left; + this.right = right; + this.callback = callback; + this.ctx = ctx; + long zippedSize = Math.min(left.size(), right.size()); + if (zippedSize > Integer.MAX_VALUE) { + throw new ProgramError( + "Zipped array size " + zippedSize + " exceeds Integer.MAX_VALUE" + ); + } + this.size = (int) zippedSize; + } + + @Override + public Object get(int index) { + if (index < 0 || index >= size) { + throw new ProgramError("Index: " + index + ", Size: " + size); + } + long idx = (long) index; + Object leftValue = left.get(idx); + Object rightValue = right.get(idx); + if (callback != null) { + return invokeArrayCallback(callback, "zip", ctx, leftValue, rightValue, index); + } + return Arrays.asList(leftValue, rightValue); + } + + @Override + public int size() { + return size; + } + } + private static final class LazyNaturalArrayFilterView extends AbstractList { private final NaturalArray source; private final NaturalArrayPredicate predicate; @@ -830,6 +1066,15 @@ private LazyNaturalArrayFilterView(NaturalArray source, NaturalArrayPredicate pr this.fullyScanned = false; } + private LazyNaturalArrayFilterView compose(final NaturalArrayPredicate nextPredicate) { + return new LazyNaturalArrayFilterView(source, new NaturalArrayPredicate() { + @Override + public boolean test(long index, Object value) { + return predicate.test(index, value) && nextPredicate.test(index, value); + } + }); + } + @Override public Object get(int index) { if (index < 0) { diff --git a/src/main/java/cod/range/NaturalArray.java b/src/main/java/cod/range/NaturalArray.java index 9de0903b..2ef9b447 100644 --- a/src/main/java/cod/range/NaturalArray.java +++ b/src/main/java/cod/range/NaturalArray.java @@ -1407,7 +1407,14 @@ public void addConditionalFormula(ConditionalFormula formula) { ArrayTracker.recordFormulaApplication(this); } - conditionalFormulas.add(formula); + if (conditionalFormulas.isEmpty()) { + conditionalFormulas.add(formula); + } else { + int lastIndex = conditionalFormulas.size() - 1; + ConditionalFormula current = conditionalFormulas.get(lastIndex); + ConditionalFormula merged = ConditionalFormula.compose(formula, current); + conditionalFormulas.set(lastIndex, merged); + } clearCache(); } diff --git a/src/main/java/cod/range/formula/ConditionalFormula.java b/src/main/java/cod/range/formula/ConditionalFormula.java index cccc443e..1fd4c284 100644 --- a/src/main/java/cod/range/formula/ConditionalFormula.java +++ b/src/main/java/cod/range/formula/ConditionalFormula.java @@ -1,7 +1,8 @@ -// In ConditionalFormula.java package cod.range.formula; +import cod.ast.ASTFactory; import cod.ast.node.*; +import cod.error.ProgramError; import cod.interpreter.Evaluator; import cod.interpreter.context.ExecutionContext; import cod.interpreter.handler.TypeHandler; @@ -11,10 +12,17 @@ public class ConditionalFormula { public final long start; public final long end; public final String indexVar; - public final List conditions; - public final List> branchStatements; - public final List elseStatements; - + + private final Expr unifiedExpression; + private final boolean usesUnifiedExpression; + + private final List conditions; + private final List> branchStatements; + private final List elseStatements; + + private final ConditionalFormula newerFormula; + private final ConditionalFormula olderFormula; + public ConditionalFormula(long start, long end, String indexVar, List conditions, List> branchStatements, @@ -25,44 +33,739 @@ public ConditionalFormula(long start, long end, String indexVar, this.conditions = conditions != null ? conditions : new ArrayList(); this.branchStatements = branchStatements != null ? branchStatements : new ArrayList>(); this.elseStatements = elseStatements != null ? elseStatements : new ArrayList(); + this.newerFormula = null; + this.olderFormula = null; + + Expr built = tryBuildUnifiedExpression(); + this.unifiedExpression = built; + this.usesUnifiedExpression = built != null; } - + + private ConditionalFormula(long start, long end, String indexVar, + ConditionalFormula newerFormula, + ConditionalFormula olderFormula) { + this.start = start; + this.end = end; + this.indexVar = indexVar; + this.newerFormula = newerFormula; + this.olderFormula = olderFormula; + + this.conditions = new ArrayList(); + this.branchStatements = new ArrayList>(); + this.elseStatements = new ArrayList(); + this.unifiedExpression = null; + this.usesUnifiedExpression = false; + } + + public static ConditionalFormula compose(ConditionalFormula newerFormula, ConditionalFormula olderFormula) { + if (newerFormula == null) return olderFormula; + if (olderFormula == null) return newerFormula; + long mergedStart = Math.min(newerFormula.start, olderFormula.start); + long mergedEnd = Math.max(newerFormula.end, olderFormula.end); + String mergedIndexVar = newerFormula.indexVar != null ? newerFormula.indexVar : olderFormula.indexVar; + return new ConditionalFormula(mergedStart, mergedEnd, mergedIndexVar, newerFormula, olderFormula); + } + public boolean contains(long index) { + if (isComposite()) { + return (newerFormula != null && newerFormula.contains(index)) + || (olderFormula != null && olderFormula.contains(index)); + } return index >= start && index <= end; } - + public Object evaluate(long index, Evaluator evaluator, ExecutionContext context) { - TypeHandler typeSystem = new TypeHandler(); // or get from somewhere - - // Create base context with index variable + if (isComposite()) { + return evaluateComposite(index, evaluator, context); + } + ExecutionContext evalCtx = context.copyWithVariable(indexVar, index, null); - - try { - // Try each branch in order - for (int i = 0; i < conditions.size(); i++) { - Object condResult = evaluator.evaluate(conditions.get(i), evalCtx); - if (typeSystem.isTruthy(typeSystem.unwrap(condResult))) { - // Execute branch statements in sequence - return executeStatementSequence(branchStatements.get(i), evaluator, evalCtx); + if (usesUnifiedExpression) { + try { + return evaluator.evaluate(unifiedExpression, evalCtx); + } catch (ProgramError e) { + return evaluateLegacy(evalCtx, evaluator); + } catch (Exception e) { + return evaluateLegacy(evalCtx, evaluator); + } + } + return evaluateLegacy(evalCtx, evaluator); + } + + private boolean isComposite() { + return newerFormula != null || olderFormula != null; + } + + private Object evaluateComposite(long index, Evaluator evaluator, ExecutionContext context) { + if (newerFormula != null && newerFormula.contains(index)) { + return newerFormula.evaluate(index, evaluator, context); + } + if (olderFormula != null && olderFormula.contains(index)) { + return olderFormula.evaluate(index, evaluator, context); + } + return null; + } + + private Object evaluateLegacy(ExecutionContext evalCtx, Evaluator evaluator) { + TypeHandler typeSystem = new TypeHandler(); + for (int i = 0; i < conditions.size(); i++) { + Object condResult = evaluator.evaluate(conditions.get(i), evalCtx); + if (typeSystem.isTruthy(typeSystem.unwrap(condResult))) { + return executeStatementSequence(branchStatements.get(i), evaluator, evalCtx); + } + } + return executeStatementSequence(elseStatements, evaluator, evalCtx); + } + + private Expr tryBuildUnifiedExpression() { + if (conditions.isEmpty() || branchStatements.isEmpty() || conditions.size() != branchStatements.size()) { + return null; + } + if (indexVar == null || indexVar.isEmpty()) { + return null; + } + + List indicatorExpressions = new ArrayList(conditions.size()); + List branchExpressions = new ArrayList(conditions.size()); + + for (int i = 0; i < conditions.size(); i++) { + Expr condition = conditions.get(i); + if (condition == null || !isPureExpression(condition)) { + return null; + } + Expr indicator = buildNumericIndicator(condition); + if (indicator == null) { + return null; + } + indicatorExpressions.add(indicator); + + Expr branchExpr = extractPureBranchExpression(branchStatements.get(i)); + if (branchExpr == null || !isPureExpression(branchExpr)) { + return null; + } + branchExpressions.add(branchExpr); + } + + Expr elseExpr = extractPureBranchExpression(elseStatements); + if (elseExpr == null || !isPureExpression(elseExpr)) { + return null; + } + + Expr unified = cloneExpr(elseExpr); + for (int i = indicatorExpressions.size() - 1; i >= 0; i--) { + Expr indicator = cloneExpr(indicatorExpressions.get(i)); + Expr branchExpr = cloneExpr(branchExpressions.get(i)); + Expr complementIndicator = simplifyExpr(ASTFactory.createBinaryOp(one(), "-", indicator, null)); + Expr leftTerm = simplifyExpr(ASTFactory.createBinaryOp(indicator, "*", branchExpr, null)); + Expr rightTerm = simplifyExpr(ASTFactory.createBinaryOp(complementIndicator, "*", unified, null)); + unified = simplifyExpr(ASTFactory.createBinaryOp(leftTerm, "+", rightTerm, null)); + } + return simplifyExpr(unified); + } + + private Expr extractPureBranchExpression(List statements) { + if (statements == null) return null; + Map tempExpressions = new HashMap(); + Expr finalExpression = null; + + for (Stmt stmt : statements) { + if (stmt instanceof Var) { + Var var = (Var) stmt; + if (var.name == null || var.value == null || !isPureExpression(var.value)) { + return null; + } + Expr resolved = substituteIdentifiers(cloneExpr(var.value), tempExpressions); + if (!isPureExpression(resolved)) { + return null; + } + tempExpressions.put(var.name, resolved); + continue; + } + + if (stmt instanceof Assignment) { + Assignment assignment = (Assignment) stmt; + if (assignment.left instanceof Identifier && assignment.isDeclaration) { + Identifier id = (Identifier) assignment.left; + if (id.name == null || assignment.right == null || !isPureExpression(assignment.right)) { + return null; + } + Expr resolved = substituteIdentifiers(cloneExpr(assignment.right), tempExpressions); + if (!isPureExpression(resolved)) { + return null; + } + tempExpressions.put(id.name, resolved); + continue; + } + + if (assignment.left instanceof IndexAccess) { + IndexAccess indexAccess = (IndexAccess) assignment.left; + if (!(indexAccess.index instanceof Identifier)) { + return null; + } + Identifier idx = (Identifier) indexAccess.index; + if (idx.name == null || !idx.name.equals(indexVar) || assignment.right == null) { + return null; + } + Expr resolved = substituteIdentifiers(cloneExpr(assignment.right), tempExpressions); + if (!isPureExpression(resolved)) { + return null; + } + finalExpression = resolved; + continue; + } + } + + return null; + } + + return finalExpression; + } + + private Expr substituteIdentifiers(Expr expr, Map replacements) { + if (expr == null) return null; + + if (expr instanceof Identifier) { + Identifier id = (Identifier) expr; + Expr replacement = replacements.get(id.name); + return replacement != null ? cloneExpr(replacement) : expr; + } + + if (expr instanceof BinaryOp) { + BinaryOp op = (BinaryOp) expr; + return ASTFactory.createBinaryOp( + substituteIdentifiers(op.left, replacements), + op.op, + substituteIdentifiers(op.right, replacements), + null + ); + } + + if (expr instanceof Unary) { + Unary unary = (Unary) expr; + return ASTFactory.createUnaryOp(unary.op, substituteIdentifiers(unary.operand, replacements), null); + } + + if (expr instanceof TypeCast) { + TypeCast cast = (TypeCast) expr; + return ASTFactory.createTypeCast(cast.targetType, substituteIdentifiers(cast.expression, replacements), null); + } + + if (expr instanceof ExprIf) { + ExprIf exprIf = (ExprIf) expr; + return new ExprIf( + substituteIdentifiers(exprIf.condition, replacements), + substituteIdentifiers(exprIf.thenExpr, replacements), + substituteIdentifiers(exprIf.elseExpr, replacements) + ); + } + + if (expr instanceof EqualityChain) { + EqualityChain chain = (EqualityChain) expr; + List args = new ArrayList(); + if (chain.chainArguments != null) { + for (Expr arg : chain.chainArguments) { + args.add(substituteIdentifiers(arg, replacements)); + } + } + return ASTFactory.createEqualityChain( + substituteIdentifiers(chain.left, replacements), + chain.operator, + chain.isAllChain, + args, + null, + null, + null + ); + } + + if (expr instanceof ChainedComparison) { + ChainedComparison cmp = (ChainedComparison) expr; + List copiedExpressions = new ArrayList(); + if (cmp.expressions != null) { + for (Expr item : cmp.expressions) { + copiedExpressions.add(substituteIdentifiers(item, replacements)); + } + } + List copiedOperators = cmp.operators != null ? new ArrayList(cmp.operators) : new ArrayList(); + return new ChainedComparison(copiedExpressions, copiedOperators); + } + + if (expr instanceof BooleanChain) { + BooleanChain chain = (BooleanChain) expr; + List expressions = new ArrayList(); + if (chain.expressions != null) { + for (Expr item : chain.expressions) { + expressions.add(substituteIdentifiers(item, replacements)); + } + } + return ASTFactory.createBooleanChain(chain.isAll, expressions, null); + } + + return expr; + } + + private Expr buildNumericIndicator(Expr condition) { + return simplifyExpr(convertBooleanToNumeric(cloneExpr(condition))); + } + + private Expr convertBooleanToNumeric(Expr condition) { + if (condition == null) return null; + + Boolean constant = evaluateConstantBoolean(condition); + if (constant != null) { + return constant.booleanValue() ? one() : zero(); + } + + if (condition instanceof BoolLiteral) { + return ((BoolLiteral) condition).value ? one() : zero(); + } + + if (condition instanceof Unary) { + Unary unary = (Unary) condition; + if ("!".equals(unary.op)) { + Expr inner = convertBooleanToNumeric(unary.operand); + if (inner == null) return null; + return ASTFactory.createBinaryOp(one(), "-", inner, null); + } + } + + if (condition instanceof BinaryOp) { + BinaryOp binary = (BinaryOp) condition; + if ("&&".equals(binary.op) || "and".equals(binary.op)) { + Expr left = convertBooleanToNumeric(binary.left); + Expr right = convertBooleanToNumeric(binary.right); + if (left == null || right == null) return null; + return ASTFactory.createBinaryOp(left, "*", right, null); + } + if ("||".equals(binary.op) || "or".equals(binary.op)) { + Expr left = convertBooleanToNumeric(binary.left); + Expr right = convertBooleanToNumeric(binary.right); + if (left == null || right == null) return null; + Expr leftCloneA = cloneExpr(left); + Expr rightCloneA = cloneExpr(right); + Expr leftCloneB = cloneExpr(left); + Expr rightCloneB = cloneExpr(right); + Expr sum = ASTFactory.createBinaryOp(leftCloneA, "+", rightCloneA, null); + Expr prod = ASTFactory.createBinaryOp(leftCloneB, "*", rightCloneB, null); + return ASTFactory.createBinaryOp(sum, "-", prod, null); + } + } + + return new ExprIf(condition, one(), zero()); + } + + private boolean isPureExpression(Expr expr) { + if (expr == null) return false; + + if (expr instanceof Identifier + || expr instanceof IntLiteral + || expr instanceof FloatLiteral + || expr instanceof BoolLiteral + || expr instanceof TextLiteral + || expr instanceof NoneLiteral + || expr instanceof ValueExpr) { + return true; + } + + if (expr instanceof BinaryOp) { + BinaryOp op = (BinaryOp) expr; + return isPureExpression(op.left) && isPureExpression(op.right); + } + + if (expr instanceof Unary) { + Unary unary = (Unary) expr; + return isPureExpression(unary.operand); + } + + if (expr instanceof TypeCast) { + TypeCast cast = (TypeCast) expr; + return isPureExpression(cast.expression); + } + + if (expr instanceof ExprIf) { + ExprIf exprIf = (ExprIf) expr; + return isPureExpression(exprIf.condition) + && isPureExpression(exprIf.thenExpr) + && isPureExpression(exprIf.elseExpr); + } + + if (expr instanceof EqualityChain) { + EqualityChain chain = (EqualityChain) expr; + if (!isPureExpression(chain.left)) return false; + if (chain.chainArguments != null) { + for (Expr arg : chain.chainArguments) { + if (!isPureExpression(arg)) return false; + } + } + return true; + } + + if (expr instanceof ChainedComparison) { + ChainedComparison chain = (ChainedComparison) expr; + if (chain.expressions != null) { + for (Expr item : chain.expressions) { + if (!isPureExpression(item)) return false; } } - - // No branch matched - execute else statements - return executeStatementSequence(elseStatements, evaluator, evalCtx); - - } catch (Exception e) { - throw new RuntimeException("Error evaluating conditional formula at index " + index, e); + return true; + } + + if (expr instanceof BooleanChain) { + BooleanChain chain = (BooleanChain) expr; + if (chain.expressions != null) { + for (Expr item : chain.expressions) { + if (!isPureExpression(item)) return false; + } + } + return true; + } + + return false; + } + + private Expr simplifyExpr(Expr expr) { + if (!(expr instanceof BinaryOp)) return expr; + + BinaryOp op = (BinaryOp) expr; + Expr left = simplifyExpr(op.left); + Expr right = simplifyExpr(op.right); + + Expr folded = foldNumericConstants(op.op, left, right); + if (folded != null) { + return folded; + } + + if ("*".equals(op.op)) { + if (isZero(left) || isZero(right)) return zero(); + if (isOne(left)) return right; + if (isOne(right)) return left; + } else if ("+".equals(op.op)) { + if (isZero(left)) return right; + if (isZero(right)) return left; + if (structurallyEqual(left, right)) { + return ASTFactory.createBinaryOp(ASTFactory.createIntLiteral(2, null), "*", left, null); + } + Expr factored = tryFactorCommonTerm(left, right); + if (factored != null) { + return simplifyExpr(factored); + } + Expr branchCollapse = tryCollapseEquivalentBranchMix(left, right); + if (branchCollapse != null) { + return branchCollapse; + } + } else if ("-".equals(op.op)) { + if (isZero(right)) return left; + if (sameLiteral(left, right) || structurallyEqual(left, right)) return zero(); + } + + return ASTFactory.createBinaryOp(left, op.op, right, null); + } + + private Expr foldNumericConstants(String op, Expr left, Expr right) { + if (!(left instanceof IntLiteral) || !(right instanceof IntLiteral)) { + return null; + } + long leftValue = ((IntLiteral) left).value.longValue(); + long rightValue = ((IntLiteral) right).value.longValue(); + if ("+".equals(op)) return ASTFactory.createLongLiteral(leftValue + rightValue, null); + if ("-".equals(op)) return ASTFactory.createLongLiteral(leftValue - rightValue, null); + if ("*".equals(op)) return ASTFactory.createLongLiteral(leftValue * rightValue, null); + if ("/".equals(op)) { + if (rightValue == 0L) return null; + if (leftValue % rightValue != 0L) return null; + return ASTFactory.createLongLiteral(leftValue / rightValue, null); + } + return null; + } + + private Expr tryFactorCommonTerm(Expr left, Expr right) { + if (!(left instanceof BinaryOp) || !(right instanceof BinaryOp)) { + return null; + } + BinaryOp leftBin = (BinaryOp) left; + BinaryOp rightBin = (BinaryOp) right; + if (!"*".equals(leftBin.op) || !"*".equals(rightBin.op)) { + return null; + } + + if (structurallyEqual(leftBin.left, rightBin.left)) { + Expr sum = simplifyExpr(ASTFactory.createBinaryOp(leftBin.right, "+", rightBin.right, null)); + return ASTFactory.createBinaryOp(leftBin.left, "*", sum, null); + } + if (structurallyEqual(leftBin.left, rightBin.right)) { + Expr sum = simplifyExpr(ASTFactory.createBinaryOp(leftBin.right, "+", rightBin.left, null)); + return ASTFactory.createBinaryOp(leftBin.left, "*", sum, null); + } + if (structurallyEqual(leftBin.right, rightBin.left)) { + Expr sum = simplifyExpr(ASTFactory.createBinaryOp(leftBin.left, "+", rightBin.right, null)); + return ASTFactory.createBinaryOp(leftBin.right, "*", sum, null); + } + if (structurallyEqual(leftBin.right, rightBin.right)) { + Expr sum = simplifyExpr(ASTFactory.createBinaryOp(leftBin.left, "+", rightBin.left, null)); + return ASTFactory.createBinaryOp(leftBin.right, "*", sum, null); + } + return null; + } + + private Expr tryCollapseEquivalentBranchMix(Expr left, Expr right) { + Expr[] leftParts = splitProduct(left); + Expr[] rightParts = splitProduct(right); + if (leftParts == null || rightParts == null) { + return null; + } + + Expr leftCoefficient = leftParts[0]; + Expr leftValue = leftParts[1]; + Expr rightCoefficient = rightParts[0]; + Expr rightValue = rightParts[1]; + + if (!structurallyEqual(leftValue, rightValue)) { + return null; + } + if (isComplementIndicator(leftCoefficient, rightCoefficient)) { + return leftValue; + } + return null; + } + + private Expr[] splitProduct(Expr expr) { + if (!(expr instanceof BinaryOp)) return null; + BinaryOp op = (BinaryOp) expr; + if (!"*".equals(op.op)) return null; + return new Expr[] { op.left, op.right }; + } + + private boolean isComplementIndicator(Expr a, Expr b) { + if (!(b instanceof BinaryOp)) return false; + BinaryOp bOp = (BinaryOp) b; + return "-".equals(bOp.op) && isOne(bOp.left) && structurallyEqual(a, bOp.right); + } + + private boolean structurallyEqual(Expr a, Expr b) { + if (a == b) return true; + if (a == null || b == null) return false; + if (!a.getClass().equals(b.getClass())) return false; + + if (a instanceof Identifier) { + String leftName = ((Identifier) a).name; + String rightName = ((Identifier) b).name; + return leftName != null ? leftName.equals(rightName) : rightName == null; + } + if (a instanceof IntLiteral) { + return ((IntLiteral) a).value.equals(((IntLiteral) b).value); + } + if (a instanceof FloatLiteral) { + return ((FloatLiteral) a).value.equals(((FloatLiteral) b).value); + } + if (a instanceof BoolLiteral) { + return ((BoolLiteral) a).value == ((BoolLiteral) b).value; + } + if (a instanceof TextLiteral) { + return Objects.equals(((TextLiteral) a).value, ((TextLiteral) b).value); + } + if (a instanceof NoneLiteral) { + return true; + } + if (a instanceof BinaryOp) { + BinaryOp x = (BinaryOp) a; + BinaryOp y = (BinaryOp) b; + return Objects.equals(x.op, y.op) + && structurallyEqual(x.left, y.left) + && structurallyEqual(x.right, y.right); + } + if (a instanceof Unary) { + Unary x = (Unary) a; + Unary y = (Unary) b; + return Objects.equals(x.op, y.op) && structurallyEqual(x.operand, y.operand); + } + if (a instanceof ExprIf) { + ExprIf x = (ExprIf) a; + ExprIf y = (ExprIf) b; + return structurallyEqual(x.condition, y.condition) + && structurallyEqual(x.thenExpr, y.thenExpr) + && structurallyEqual(x.elseExpr, y.elseExpr); } + return false; } - - private Object executeStatementSequence( - List statements, Evaluator evaluator, ExecutionContext ctx) { - + + private Boolean evaluateConstantBoolean(Expr expr) { + if (expr == null) return null; + + if (expr instanceof BoolLiteral) { + return Boolean.valueOf(((BoolLiteral) expr).value); + } + if (expr instanceof Unary) { + Unary unary = (Unary) expr; + if ("!".equals(unary.op)) { + Boolean inner = evaluateConstantBoolean(unary.operand); + return inner != null ? Boolean.valueOf(!inner.booleanValue()) : null; + } + } + if (expr instanceof BinaryOp) { + BinaryOp op = (BinaryOp) expr; + if ("&&".equals(op.op) || "and".equals(op.op)) { + Boolean left = evaluateConstantBoolean(op.left); + Boolean right = evaluateConstantBoolean(op.right); + if (left != null && right != null) { + return Boolean.valueOf(left.booleanValue() && right.booleanValue()); + } + return null; + } + if ("||".equals(op.op) || "or".equals(op.op)) { + Boolean left = evaluateConstantBoolean(op.left); + Boolean right = evaluateConstantBoolean(op.right); + if (left != null && right != null) { + return Boolean.valueOf(left.booleanValue() || right.booleanValue()); + } + return null; + } + + Object leftConst = constantValue(op.left); + Object rightConst = constantValue(op.right); + if (leftConst == null || rightConst == null) { + return null; + } + + if ("==".equals(op.op)) return Boolean.valueOf(leftConst.equals(rightConst)); + if ("!=".equals(op.op)) return Boolean.valueOf(!leftConst.equals(rightConst)); + + Integer cmp = compareConstants(leftConst, rightConst); + if (cmp == null) return null; + if (">".equals(op.op)) return Boolean.valueOf(cmp.intValue() > 0); + if ("<".equals(op.op)) return Boolean.valueOf(cmp.intValue() < 0); + if (">=".equals(op.op)) return Boolean.valueOf(cmp.intValue() >= 0); + if ("<=".equals(op.op)) return Boolean.valueOf(cmp.intValue() <= 0); + } + return null; + } + + private Object constantValue(Expr expr) { + if (expr instanceof IntLiteral) { + return ((IntLiteral) expr).value; + } + if (expr instanceof FloatLiteral) { + return ((FloatLiteral) expr).value; + } + if (expr instanceof BoolLiteral) { + return Boolean.valueOf(((BoolLiteral) expr).value); + } + if (expr instanceof TextLiteral) { + return ((TextLiteral) expr).value; + } + return null; + } + + private Integer compareConstants(Object left, Object right) { + if (left instanceof cod.math.AutoStackingNumber && right instanceof cod.math.AutoStackingNumber) { + return Integer.valueOf(((cod.math.AutoStackingNumber) left).compareTo((cod.math.AutoStackingNumber) right)); + } + if (left instanceof Boolean && right instanceof Boolean) { + boolean leftBool = ((Boolean) left).booleanValue(); + boolean rightBool = ((Boolean) right).booleanValue(); + return Integer.valueOf(leftBool == rightBool ? 0 : (leftBool ? 1 : -1)); + } + if (left instanceof String && right instanceof String) { + return Integer.valueOf(((String) left).compareTo((String) right)); + } + return null; + } + + private boolean isZero(Expr expr) { + if (!(expr instanceof IntLiteral)) return false; + return ((IntLiteral) expr).value.isZero(); + } + + private boolean isOne(Expr expr) { + if (!(expr instanceof IntLiteral)) return false; + return ((IntLiteral) expr).value.longValue() == 1L; + } + + private boolean sameLiteral(Expr a, Expr b) { + if (a instanceof IntLiteral && b instanceof IntLiteral) { + return ((IntLiteral) a).value.equals(((IntLiteral) b).value); + } + if (a instanceof BoolLiteral && b instanceof BoolLiteral) { + return ((BoolLiteral) a).value == ((BoolLiteral) b).value; + } + return false; + } + + private Expr cloneExpr(Expr expr) { + if (expr == null) return null; + if (expr instanceof Identifier) return ASTFactory.createIdentifier(((Identifier) expr).name, null); + if (expr instanceof IntLiteral) { + long value = ((IntLiteral) expr).value.longValue(); + if (value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE) { + return ASTFactory.createIntLiteral((int) value, null); + } + return ASTFactory.createLongLiteral(value, null); + } + if (expr instanceof FloatLiteral) return ASTFactory.createFloatLiteral(((FloatLiteral) expr).value, null); + if (expr instanceof BoolLiteral) return ASTFactory.createBoolLiteral(((BoolLiteral) expr).value, null); + if (expr instanceof TextLiteral) return ASTFactory.createTextLiteral(((TextLiteral) expr).value, null); + if (expr instanceof NoneLiteral) return ASTFactory.createNoneLiteral(null); + if (expr instanceof ValueExpr) return new ValueExpr(((ValueExpr) expr).getValue()); + if (expr instanceof BinaryOp) { + BinaryOp op = (BinaryOp) expr; + return ASTFactory.createBinaryOp(cloneExpr(op.left), op.op, cloneExpr(op.right), null); + } + if (expr instanceof Unary) { + Unary unary = (Unary) expr; + return ASTFactory.createUnaryOp(unary.op, cloneExpr(unary.operand), null); + } + if (expr instanceof TypeCast) { + TypeCast cast = (TypeCast) expr; + return ASTFactory.createTypeCast(cast.targetType, cloneExpr(cast.expression), null); + } + if (expr instanceof ExprIf) { + ExprIf exprIf = (ExprIf) expr; + return new ExprIf(cloneExpr(exprIf.condition), cloneExpr(exprIf.thenExpr), cloneExpr(exprIf.elseExpr)); + } + if (expr instanceof EqualityChain) { + EqualityChain chain = (EqualityChain) expr; + List args = new ArrayList(); + if (chain.chainArguments != null) { + for (Expr arg : chain.chainArguments) { + args.add(cloneExpr(arg)); + } + } + return ASTFactory.createEqualityChain(cloneExpr(chain.left), chain.operator, chain.isAllChain, args, null, null, null); + } + if (expr instanceof ChainedComparison) { + ChainedComparison source = (ChainedComparison) expr; + List copiedExpressions = new ArrayList(); + if (source.expressions != null) { + for (Expr item : source.expressions) { + copiedExpressions.add(cloneExpr(item)); + } + } + List copiedOperators = source.operators != null ? new ArrayList(source.operators) : new ArrayList(); + return new ChainedComparison(copiedExpressions, copiedOperators); + } + if (expr instanceof BooleanChain) { + BooleanChain source = (BooleanChain) expr; + List items = new ArrayList(); + if (source.expressions != null) { + for (Expr item : source.expressions) { + items.add(cloneExpr(item)); + } + } + return ASTFactory.createBooleanChain(source.isAll, items, null); + } + return expr; + } + + private Expr zero() { + return ASTFactory.createIntLiteral(0, null); + } + + private Expr one() { + return ASTFactory.createIntLiteral(1, null); + } + + private Object executeStatementSequence(List statements, Evaluator evaluator, ExecutionContext ctx) { Object lastResult = null; - - // Create a new scope for temporary variables ctx.pushScope(); - try { for (Stmt stmt : statements) { lastResult = evaluator.evaluate(stmt, ctx); @@ -70,7 +773,6 @@ private Object executeStatementSequence( } finally { ctx.popScope(); } - return lastResult; } -} \ No newline at end of file +}