diff --git a/source_.jar b/source_.jar index 745351b7..d45e76f3 100644 Binary files a/source_.jar and b/source_.jar differ diff --git a/src/main/java/cod/interpreter/InterpreterVisitor.java b/src/main/java/cod/interpreter/InterpreterVisitor.java index f38ce8c0..65043697 100644 --- a/src/main/java/cod/interpreter/InterpreterVisitor.java +++ b/src/main/java/cod/interpreter/InterpreterVisitor.java @@ -24,26 +24,6 @@ public class InterpreterVisitor extends ASTVisitor implements Evaluator private static final String SELF_CALL_LAMBDA_OWNER = "self-call lambda"; private static final double SELF_CALL_LEVEL_FLOAT_EPSILON = 1e-12d; - private static final class TailCallSignal extends RuntimeException { - public final String methodName; - public final LambdaClosure lambdaClosure; - public final List arguments; - - private TailCallSignal(String methodName, LambdaClosure lambdaClosure, List arguments) { - this.methodName = methodName; - this.lambdaClosure = lambdaClosure; - this.arguments = - arguments != null ? new ArrayList(arguments) : Collections.emptyList(); - } - - static TailCallSignal forMethod(String methodName, List arguments) { - return new TailCallSignal(methodName, null, arguments); - } - - static TailCallSignal forLambda(LambdaClosure lambdaClosure, List arguments) { - return new TailCallSignal(null, lambdaClosure, arguments); - } - } // Tail-call trampolining intentionally uses this internal signal to unwind Java frames // without allocating wrapper result objects through every visitor return path. // This favors lower allocation overhead over exception cost in non-tail paths. @@ -102,11 +82,12 @@ private static class LinearRecurrencePattern { private final Stack contextStack = new Stack(); private final ExpressionHandler expressionHandler; private final AssignmentHandler assignmentHandler; - // Lazily resolved internal.range type references used for runtime range index objects. - private Type internalRangeSpecType; - private Type internalMultiRangeSpecType; - private final LiteralRegistry literalRegistry; + private final ContextHandler contextHandler; + private final LambdaInvokingHandler lambdaInvokingHandler; + private final ArrayOperationHandler arrayOperationHandler; + private final PatternHandler patternHandler; + private final LoopOptimizationHandler loopOptimizationHandler; // ========== SIMPLE LOOP OPTIMIZATION CONSTANTS ========== private static final int LAZY_THRESHOLD = 10; // From your data: 10+ iterations = worth it @@ -127,8 +108,16 @@ public InterpreterVisitor(Interpreter interpreter, TypeHandler typeSystem, this.interpreter = interpreter; this.typeSystem = typeSystem; this.literalRegistry = literalRegistry; + this.contextHandler = new ContextHandler(interpreter); this.expressionHandler = new ExpressionHandler(typeSystem, this); this.assignmentHandler = new AssignmentHandler(typeSystem, interpreter, expressionHandler, this); + this.arrayOperationHandler = + new ArrayOperationHandler(this, interpreter, typeSystem, expressionHandler, contextHandler); + this.patternHandler = + new PatternHandler(this, typeSystem, expressionHandler, arrayOperationHandler); + this.loopOptimizationHandler = + new LoopOptimizationHandler(this, typeSystem, expressionHandler, arrayOperationHandler, patternHandler); + this.lambdaInvokingHandler = new LambdaInvokingHandler(typeSystem, this); } // Implement Evaluator interface @@ -171,7 +160,7 @@ public Object invokeLambda(Object callback, List arguments, ExecutionCon if (ctx == null) { throw new InternalError("invokeLambda called with null context"); } - return invokeLambdaCallback(callback, arguments, ctx, ownerMethod); + return lambdaInvokingHandler.invokeLambdaCallback(callback, arguments, ctx, ownerMethod); } public void pushContext(ExecutionContext context) { @@ -205,8 +194,12 @@ public boolean isContextStackEmpty() { return contextStack.isEmpty(); } + public boolean shouldReturnEarly(Map slotValues, Set slotsInCurrentPath) { + return interpreter.shouldReturnEarly(slotValues, slotsInCurrentPath); + } + private Object createNoneValue() { - return new NoneLiteral(); + return contextHandler.createNoneValue(); } @Override @@ -350,7 +343,7 @@ public Object visit(Var node) { if (node.value == null) { throw new ProgramError("Constant '" + node.name + "' must have an initial value"); } - if (isVariableDeclaredInAnyScope(ctx, node.name)) { + if (contextHandler.isVariableDeclaredInAnyScope(ctx, node.name)) { throw new ProgramError("Cannot reassign constant '" + node.name + "'"); } } @@ -368,7 +361,7 @@ public Object visit(Var node) { // If expected is [text] but actual is not text, create a converting wrapper if (expectedElementType.equals("text") && !actualElementType.equals("text")) { // Create a new NaturalArray with conversion enabled - Range range = getRangeFromArray(arr); + Range range = contextHandler.getRangeFromArray(arr); if (range != null) { val = new NaturalArray(range, this, ctx, node.explicitType); } @@ -428,62 +421,19 @@ public Object visit(Var node) { // Helper method to extract Range from NaturalArray private Range getRangeFromArray(NaturalArray arr) { - try { - java.lang.reflect.Field rangeField = NaturalArray.class.getDeclaredField("baseRange"); - rangeField.setAccessible(true); - return (Range) rangeField.get(arr); - } catch (Exception e) { - return null; - } + return contextHandler.getRangeFromArray(arr); } private Type resolveInternalRangeSpecType() { - if (internalRangeSpecType != null) { - return internalRangeSpecType; - } - try { - Type type = interpreter.getImportResolver().resolveImport("internal.range.RangeSpec"); - if (type == null) { - throw new ProgramError("Unable to load internal.range.RangeSpec"); - } - internalRangeSpecType = type; - return type; - } catch (ProgramError e) { - throw e; - } catch (Exception e) { - throw new InternalError("Failed loading internal.range.RangeSpec", e); - } + return contextHandler.resolveInternalRangeSpecType(); } private Type resolveInternalMultiRangeSpecType() { - if (internalMultiRangeSpecType != null) { - return internalMultiRangeSpecType; - } - try { - Type type = interpreter.getImportResolver().resolveImport("internal.range.MultiRangeSpec"); - if (type == null) { - throw new ProgramError("Unable to load internal.range.MultiRangeSpec"); - } - internalMultiRangeSpecType = type; - return type; - } catch (ProgramError e) { - throw e; - } catch (Exception e) { - throw new InternalError("Failed loading internal.range.MultiRangeSpec", e); - } + return contextHandler.resolveInternalMultiRangeSpecType(); } private boolean isVariableDeclaredInAnyScope(ExecutionContext ctx, String name) { - if (ctx == null || name == null) return false; - List> localsStack = ctx.getLocalsStack(); - if (localsStack == null) return false; - for (int i = localsStack.size() - 1; i >= 0; i--) { - Map scope = localsStack.get(i); - if (scope != null && scope.containsKey(name)) { - return true; - } - } - return false; + return contextHandler.isVariableDeclaredInAnyScope(ctx, name); } @Override @@ -550,71 +500,7 @@ public Object visit(ExprIf node) { // ========== UPDATED FOR NODE WITH SIMPLE LOOP DECISION ========== @Override public Object visit(For node) { - if (node == null) { - throw new InternalError("visit(For) called with null node"); - } - - ExecutionContext ctx = getCurrentContext(); - int originalDepth = ctx.getScopeDepth(); - - // Estimate loop size before execution - long loopSize = estimateLoopSize(node, ctx); - boolean hasSideEffects = hasSideEffects(node.body); - - // Simple decision: should we try lazy execution? - boolean useLazyExecution = shouldUseLazyExecution(loopSize, hasSideEffects); - - // Start tracking this loop - int loopId = ArrayTracker.beginLoop(node); - - // Store estimated size in tracker - ArrayTracker.setLoopSize(loopId, loopSize); - ArrayTracker.setSideEffects(loopId, hasSideEffects); - - try { - ctx.pushScope(); - - // Try lazy execution if beneficial - if (useLazyExecution) { - Object result = tryOptimizedExecution(node, loopId); - if (result != null) { - return result; - } - } else { - DebugSystem.debug("LOOP", - String.format("Skipping optimization: size=%d, sideEffects=%s", - loopSize, hasSideEffects)); - } - - // Normal eager execution - ArrayTracker.incrementIteration(); - - if (node.range != null) { - return executeRangeLoop(ctx, node, node.iterator); - } else if (node.arraySource != null) { - Object arrayObj = dispatch(node.arraySource); - arrayObj = typeSystem.unwrap(arrayObj); - return executeArrayLoop(ctx, node, node.iterator, arrayObj); - } - throw new ProgramError("Invalid for loop: neither range nor array source specified"); - - } catch (ProgramError e) { - throw e; - } catch (TailCallSignal e) { - throw e; - } catch (Exception e) { - throw new InternalError("For loop execution failed", e); - } finally { - // End tracking - simple log - ArrayTracker.LoopStats stats = ArrayTracker.endLoop(); - if (stats != null) { - DebugSystem.debug("LOOP", stats.toString()); - } - - while (ctx.getScopeDepth() > originalDepth) { - ctx.popScope(); - } - } + return loopOptimizationHandler.executeForLoop(node); } // ========== SIMPLE LOOP DECISION METHODS ========== @@ -2400,130 +2286,17 @@ private Object buildDimensionArray(List ranges, int dimension) { @SuppressWarnings("unchecked") @Override public Object visit(IndexAccess node) { - if (node == null) { - throw new InternalError("visit(IndexAccess) called with null node"); - } - - try { - Object arrayObj = dispatch(node.array); - arrayObj = typeSystem.unwrap(arrayObj); - - // Force materialization BEFORE index access - if (arrayObj instanceof NaturalArray) { - NaturalArray natural = (NaturalArray) arrayObj; - if (natural.hasPendingUpdates()) { - natural.commitUpdates(); - } - } - - Object indexObj = dispatch(node.index); - indexObj = typeSystem.unwrap(indexObj); - - if (indexObj instanceof List) { - return applyTupleIndices(arrayObj, (List) indexObj); - } - - if (RangeObjects.isRangeSpec(indexObj)) { - if (arrayObj instanceof String) { - return applyStringRangeIndex((String) arrayObj, indexObj); - } - return applyRangeIndex(arrayObj, indexObj); - } - - if (RangeObjects.isMultiRangeSpec(indexObj)) { - return applyMultiRangeIndex(arrayObj, indexObj); - } - - if (arrayObj instanceof String) { - String text = (String) arrayObj; - int index = expressionHandler.toIntIndex(indexObj); - index = normalizeTextIndex(index, text.length()); - if (index < 0 || index >= text.length()) { - throw new ProgramError( - "Index out of bounds: " + index + " for text of length " + text.length()); - } - return String.valueOf(text.charAt(index)); - } - - if (arrayObj instanceof NaturalArray) { - NaturalArray natural = (NaturalArray) arrayObj; - long index = expressionHandler.toLongIndex(indexObj); - - if (natural.needsConversion()) { - return natural.get(index, true); - } - return natural.get(index); - } - - if (arrayObj instanceof List) { - List list = (List) arrayObj; - if (indexObj instanceof AutoStackingNumber) { - int index = (int) ((AutoStackingNumber) indexObj).longValue(); - if (index < 0 || index >= list.size()) { - throw new ProgramError( - "Index out of bounds: " + index + " for array of size " + list.size()); - } - return list.get(index); - } else { - int index = expressionHandler.toIntIndex(indexObj); - if (index < 0 || index >= list.size()) { - throw new ProgramError( - "Index out of bounds: " + index + " for array of size " + list.size()); - } - return list.get(index); - } - } - - throw new ProgramError( - "Invalid array access: expected NaturalArray or List, got " - + (arrayObj != null ? arrayObj.getClass().getSimpleName() : "null")); - } catch (ProgramError e) { - throw e; - } catch (Exception e) { - throw new InternalError("Index access failed", e); - } + return arrayOperationHandler.visitIndexAccess(node); } @Override public Object visit(RangeIndex node) { - if (node == null) { - throw new InternalError("visit(RangeIndex) called with null node"); - } - - try { - Object step = node.step != null ? dispatch(node.step) : null; - Object start = dispatch(node.start); - Object end = dispatch(node.end); - - return RangeObjects.createRangeSpec(resolveInternalRangeSpecType(), step, start, end); - } catch (ProgramError e) { - throw e; - } catch (Exception e) { - throw new InternalError("Range index creation failed", e); - } + return arrayOperationHandler.visitRangeIndex(node); } @Override public Object visit(MultiRangeIndex node) { - if (node == null) { - throw new InternalError("visit(MultiRangeIndex) called with null node"); - } - - try { - List ranges = new ArrayList(); - for (RangeIndex rangeNode : node.ranges) { - Object range = visit(rangeNode); - if (!RangeObjects.isRangeSpec(range)) { - throw new InternalError("Multi-range index contains non-range value"); - } - ranges.add(range); - } - return RangeObjects.createMultiRangeSpec(resolveInternalMultiRangeSpecType(), ranges); - } catch (ProgramError e) { - throw e; - } catch (Exception e) { - throw new InternalError("Multi-range index creation failed", e); - } + return arrayOperationHandler.visitMultiRangeIndex(node); } @Override @@ -2578,26 +2351,7 @@ public Object visit(Slot n) { @Override public Object visit(Lambda node) { - if (node == null) { - throw new InternalError("visit(Lambda) called with null node"); - } - - ExecutionContext ctx = getCurrentContext(); - Map captured = new HashMap(); - for (int i = 0; i < ctx.getScopeDepth(); i++) { - Map scope = ctx.getScope(i); - if (scope != null) { - captured.putAll(scope); - } - } - - return new LambdaClosure( - node, - captured, - ctx.objectInstance, - ctx.currentClass, - ctx.currentLambdaClosure, - Collections.emptyList()); + return lambdaInvokingHandler.createLambdaClosure(node, getCurrentContext()); } private Object invokeLambdaCallback( @@ -2605,71 +2359,7 @@ private Object invokeLambdaCallback( List args, ExecutionContext parentCtx, String ownerMethod) { - - Object callback = typeSystem.unwrap(callbackObj); - LambdaClosure closure; - if (callback instanceof LambdaClosure) { - closure = (LambdaClosure) callback; - } else if (callback instanceof Lambda) { - closure = - new LambdaClosure( - (Lambda) callback, - parentCtx.locals(), - parentCtx.objectInstance, - parentCtx.currentClass, - parentCtx.currentLambdaClosure, - Collections.emptyList()); - } else { - String actualType = callback == null ? "null" : callback.getClass().getSimpleName(); - throw new ProgramError(ownerMethod + " expects a lambda callback, got: " + actualType); - } - - LambdaClosure activeClosure = closure; - List activeIncomingValues = args != null ? args : Collections.emptyList(); - - while (true) { - Lambda lambda = activeClosure.lambda; - List params = resolveLambdaParameters(lambda); - List combinedValues = - mergeBoundAndIncomingLambdaArgs(activeClosure.boundArguments, activeIncomingValues); - - if (shouldAutoCurry(params, combinedValues)) { - return createCurriedLambdaClosure(activeClosure, combinedValues); - } - - int parameterBindCount = Math.min(params.size(), combinedValues.size()); - List values = new ArrayList(combinedValues.subList(0, parameterBindCount)); - List leftoverValues = - new ArrayList(combinedValues.subList(parameterBindCount, combinedValues.size())); - - Map lambdaLocals = - bindLambdaArguments(params, values, activeClosure, ownerMethod); - if (lambda.inferParameters && params.isEmpty()) { - bindPositionalInferredPlaceholderAliases(lambdaLocals, values); - } - - Object result; - try { - if (lambda.expressionBody != null) { - result = evaluateLambdaExpressionBody(lambda, activeClosure, lambdaLocals); - } else { - result = evaluateLambdaBlockBody(lambda, activeClosure, lambdaLocals); - } - } catch (TailCallSignal tailCallSignal) { - if (tailCallSignal.lambdaClosure != null - && tailCallSignal.lambdaClosure == activeClosure) { - activeClosure = tailCallSignal.lambdaClosure; - activeIncomingValues = tailCallSignal.arguments; - continue; - } - throw tailCallSignal; - } - - if (!leftoverValues.isEmpty() && (result instanceof LambdaClosure || result instanceof Lambda)) { - return invokeLambdaCallback(result, leftoverValues, parentCtx, ownerMethod); - } - return result; - } + return lambdaInvokingHandler.invokeLambdaCallback(callbackObj, args, parentCtx, ownerMethod); } private List resolveLambdaParameters(Lambda lambda) { diff --git a/src/main/java/cod/interpreter/TailCallSignal.java b/src/main/java/cod/interpreter/TailCallSignal.java new file mode 100644 index 00000000..9c602b1b --- /dev/null +++ b/src/main/java/cod/interpreter/TailCallSignal.java @@ -0,0 +1,27 @@ +package cod.interpreter; + +import cod.interpreter.context.LambdaClosure; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public final class TailCallSignal extends RuntimeException { + public final String methodName; + public final LambdaClosure lambdaClosure; + public final List arguments; + + private TailCallSignal(String methodName, LambdaClosure lambdaClosure, List arguments) { + this.methodName = methodName; + this.lambdaClosure = lambdaClosure; + this.arguments = arguments != null ? new ArrayList(arguments) : Collections.emptyList(); + } + + public static TailCallSignal forMethod(String methodName, List arguments) { + return new TailCallSignal(methodName, null, arguments); + } + + public static TailCallSignal forLambda(LambdaClosure lambdaClosure, List arguments) { + return new TailCallSignal(null, lambdaClosure, arguments); + } +} diff --git a/src/main/java/cod/interpreter/handler/ArrayOperationHandler.java b/src/main/java/cod/interpreter/handler/ArrayOperationHandler.java new file mode 100644 index 00000000..f3422763 --- /dev/null +++ b/src/main/java/cod/interpreter/handler/ArrayOperationHandler.java @@ -0,0 +1,582 @@ +package cod.interpreter.handler; + +import cod.ast.node.*; +import cod.error.InternalError; +import cod.error.ProgramError; +import cod.interpreter.Interpreter; +import cod.interpreter.InterpreterVisitor; +import cod.interpreter.exception.BreakLoopException; +import cod.interpreter.exception.SkipIterationException; +import cod.math.AutoStackingNumber; +import cod.range.NaturalArray; +import cod.range.RangeObjects; + +import java.util.ArrayList; +import java.util.List; + +public class ArrayOperationHandler { + private final InterpreterVisitor dispatcher; + private final Interpreter interpreter; + private final TypeHandler typeSystem; + private final ExpressionHandler expressionHandler; + private final ContextHandler contextHandler; + + public ArrayOperationHandler( + InterpreterVisitor dispatcher, + Interpreter interpreter, + TypeHandler typeSystem, + ExpressionHandler expressionHandler, + ContextHandler contextHandler) { + if (dispatcher == null) throw new InternalError("ArrayOperationHandler dispatcher is null"); + if (interpreter == null) throw new InternalError("ArrayOperationHandler interpreter is null"); + if (typeSystem == null) throw new InternalError("ArrayOperationHandler typeSystem is null"); + if (expressionHandler == null) throw new InternalError("ArrayOperationHandler expressionHandler is null"); + if (contextHandler == null) throw new InternalError("ArrayOperationHandler contextHandler is null"); + this.dispatcher = dispatcher; + this.interpreter = interpreter; + this.typeSystem = typeSystem; + this.expressionHandler = expressionHandler; + this.contextHandler = contextHandler; + } + + public Object executeForLoopNormally(For node) { + if (node == null) { + throw new InternalError("executeForLoopNormally called with null node"); + } + + cod.interpreter.context.ExecutionContext ctx = dispatcher.getCurrentContext(); + String iter = node.iterator; + + try { + if (node.range != null) { + return executeRangeLoop(ctx, node, iter); + } else if (node.arraySource != null) { + Object arrayObj = dispatcher.dispatch(node.arraySource); + arrayObj = typeSystem.unwrap(arrayObj); + return executeArrayLoop(ctx, node, iter, arrayObj); + } + throw new ProgramError("Invalid for loop"); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Normal loop execution failed", e); + } + } + + @SuppressWarnings("unchecked") + public Object executeArrayLoop( + cod.interpreter.context.ExecutionContext ctx, For node, String iter, Object arrayObj) { + try { + if (arrayObj instanceof NaturalArray) { + NaturalArray natural = (NaturalArray) arrayObj; + long size = natural.size(); + for (long i = 0; i < size; i++) { + Object currentValue = natural.get(i); + ctx.setVariable(iter, currentValue); + try { + executeLoopBody(ctx, node); + } catch (BreakLoopException e) { + break; + } + } + } else if (arrayObj instanceof List) { + List list = (List) arrayObj; + for (Object currentValue : list) { + ctx.setVariable(iter, currentValue); + try { + executeLoopBody(ctx, node); + } catch (BreakLoopException e) { + break; + } + } + } else { + throw new ProgramError("Cannot iterate over: " + + (arrayObj != null ? arrayObj.getClass().getSimpleName() : "null")); + } + return null; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Array loop execution failed", e); + } + } + + public Object executeRangeLoop(cod.interpreter.context.ExecutionContext ctx, For node, String iter) { + try { + Object startObj = dispatcher.dispatch(node.range.start); + Object endObj = dispatcher.dispatch(node.range.end); + startObj = typeSystem.unwrap(startObj); + endObj = typeSystem.unwrap(endObj); + + if (node.range.step != null && node.range.step instanceof BinaryOp) { + BinaryOp binOp = (BinaryOp) node.range.step; + if (binOp.left instanceof Identifier + && ((Identifier) binOp.left).name.equals(iter) + && (binOp.op.equals("*") || binOp.op.equals("/"))) { + Object rightObj = dispatcher.dispatch(binOp.right); + rightObj = typeSystem.unwrap(rightObj); + AutoStackingNumber factor = typeSystem.toAutoStackingNumber(rightObj); + validateFactor(factor, binOp.op); + return executeMultiplicativeLoop(ctx, node, startObj, endObj, factor, binOp.op); + } + } + + AutoStackingNumber step; + if (node.range.step != null) { + Object stepObj = dispatcher.dispatch(node.range.step); + step = typeSystem.toAutoStackingNumber(typeSystem.unwrap(stepObj)); + } else { + AutoStackingNumber start = typeSystem.toAutoStackingNumber(startObj); + AutoStackingNumber end = typeSystem.toAutoStackingNumber(endObj); + step = (start.compareTo(end) > 0) ? AutoStackingNumber.minusOne(1) : AutoStackingNumber.one(1); + } + + if (step.isZero()) { + throw new ProgramError("Loop step cannot be zero."); + } + + return executeAdditiveLoop(ctx, node, startObj, endObj, step); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Range loop execution failed", e); + } + } + + public Object executeAdditiveLoop( + cod.interpreter.context.ExecutionContext ctx, For node, Object startObj, Object endObj, AutoStackingNumber step) { + try { + AutoStackingNumber start = typeSystem.toAutoStackingNumber(startObj); + AutoStackingNumber end = typeSystem.toAutoStackingNumber(endObj); + AutoStackingNumber current = start; + boolean increasing = step.isPositive(); + + while (shouldContinueAdditive(current, end, step, increasing)) { + try { + executeIteration(ctx, node, current, startObj); + } catch (BreakLoopException e) { + break; + } + current = current.add(step); + } + return null; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Additive loop execution failed", e); + } + } + + public Object executeMultiplicativeLoop( + cod.interpreter.context.ExecutionContext ctx, + For node, + Object startObj, + Object endObj, + AutoStackingNumber factor, + String operation) { + try { + AutoStackingNumber start = typeSystem.toAutoStackingNumber(startObj); + AutoStackingNumber end = typeSystem.toAutoStackingNumber(endObj); + AutoStackingNumber current = start; + + while (shouldContinueMultiplicative(current, start, end, factor, operation)) { + try { + executeIteration(ctx, node, current, startObj); + } catch (BreakLoopException e) { + break; + } + if (operation.equals("*")) { + current = current.multiply(factor); + } else { + current = current.divide(factor); + } + } + return null; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Multiplicative loop execution failed", e); + } + } + + public void executeIteration( + cod.interpreter.context.ExecutionContext ctx, For node, AutoStackingNumber current, Object startObj) { + try { + String iter = node.iterator; + Object currentValue = convertToAppropriateType(current, startObj); + ctx.setVariable(iter, currentValue); + if (ctx.getVariableType(iter) == null) { + String inferredType = (current.fitsInStacks(1) && + (current.getWords()[0] & 0x7FFFFFFFFFFFFFFFL) < Long.MAX_VALUE) + ? cod.syntax.Keyword.INT.toString() : cod.syntax.Keyword.FLOAT.toString(); + ctx.setVariableType(iter, inferredType); + } + executeLoopBody(ctx, node); + } catch (BreakLoopException e) { + throw e; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Loop iteration failed", e); + } + } + + public void executeLoopBody(cod.interpreter.context.ExecutionContext ctx, For node) { + try { + for (Stmt s : node.body.statements) { + try { + dispatcher.dispatch(s); + } catch (SkipIterationException e) { + break; + } catch (BreakLoopException e) { + throw e; + } + + if (!ctx.slotsInCurrentPath.isEmpty() + && dispatcher.shouldReturnEarly(ctx.getSlotValues(), ctx.slotsInCurrentPath)) return; + } + } catch (BreakLoopException e) { + throw e; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Loop body execution failed", e); + } + } + + @SuppressWarnings("unchecked") + public Object visitIndexAccess(IndexAccess node) { + if (node == null) { + throw new InternalError("visit(IndexAccess) called with null node"); + } + + try { + Object arrayObj = dispatcher.dispatch(node.array); + arrayObj = typeSystem.unwrap(arrayObj); + + if (arrayObj instanceof NaturalArray) { + NaturalArray natural = (NaturalArray) arrayObj; + if (natural.hasPendingUpdates()) { + natural.commitUpdates(); + } + } + + Object indexObj = dispatcher.dispatch(node.index); + indexObj = typeSystem.unwrap(indexObj); + + if (indexObj instanceof List) { + return applyTupleIndices(arrayObj, (List) indexObj); + } + + if (RangeObjects.isRangeSpec(indexObj)) { + if (arrayObj instanceof String) { + return applyStringRangeIndex((String) arrayObj, indexObj); + } + return applyRangeIndex(arrayObj, indexObj); + } + + if (RangeObjects.isMultiRangeSpec(indexObj)) { + return applyMultiRangeIndex(arrayObj, indexObj); + } + + if (arrayObj instanceof String) { + String text = (String) arrayObj; + int index = expressionHandler.toIntIndex(indexObj); + index = normalizeTextIndex(index, text.length()); + if (index < 0 || index >= text.length()) { + throw new ProgramError( + "Index out of bounds: " + index + " for text of length " + text.length()); + } + return String.valueOf(text.charAt(index)); + } + + if (arrayObj instanceof NaturalArray) { + NaturalArray natural = (NaturalArray) arrayObj; + long index = expressionHandler.toLongIndex(indexObj); + + if (natural.needsConversion()) { + return natural.get(index, true); + } + return natural.get(index); + } + + if (arrayObj instanceof List) { + List list = (List) arrayObj; + if (indexObj instanceof AutoStackingNumber) { + int index = (int) ((AutoStackingNumber) indexObj).longValue(); + if (index < 0 || index >= list.size()) { + throw new ProgramError( + "Index out of bounds: " + index + " for array of size " + list.size()); + } + return list.get(index); + } else { + int index = expressionHandler.toIntIndex(indexObj); + if (index < 0 || index >= list.size()) { + throw new ProgramError( + "Index out of bounds: " + index + " for array of size " + list.size()); + } + return list.get(index); + } + } + + throw new ProgramError( + "Invalid array access: expected NaturalArray or List, got " + + (arrayObj != null ? arrayObj.getClass().getSimpleName() : "null")); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Index access failed", e); + } + } + + public Object visitRangeIndex(RangeIndex node) { + if (node == null) { + throw new InternalError("visit(RangeIndex) called with null node"); + } + + try { + Object step = node.step != null ? dispatcher.dispatch(node.step) : null; + Object start = dispatcher.dispatch(node.start); + Object end = dispatcher.dispatch(node.end); + + return RangeObjects.createRangeSpec(contextHandler.resolveInternalRangeSpecType(), step, start, end); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Range index creation failed", e); + } + } + + public Object visitMultiRangeIndex(MultiRangeIndex node) { + if (node == null) { + throw new InternalError("visit(MultiRangeIndex) called with null node"); + } + + try { + List ranges = new ArrayList(); + for (RangeIndex rangeNode : node.ranges) { + Object range = visitRangeIndex(rangeNode); + if (!RangeObjects.isRangeSpec(range)) { + throw new InternalError("Multi-range index contains non-range value"); + } + ranges.add(range); + } + return RangeObjects.createMultiRangeSpec(contextHandler.resolveInternalMultiRangeSpecType(), ranges); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Multi-range index creation failed", e); + } + } + + @SuppressWarnings("unchecked") + public Object applyRangeIndex(Object array, Object range) { + if (array instanceof NaturalArray) { + NaturalArray natural = (NaturalArray) array; + return natural.getRange(range); + } else if (array instanceof List) { + List list = (List) array; + return getListRange(list, range); + } + throw new ProgramError("Cannot apply range index to " + + (array != null ? array.getClass().getSimpleName() : "null")); + } + + @SuppressWarnings("unchecked") + public Object applyMultiRangeIndex(Object array, Object multiRange) { + if (array instanceof NaturalArray) { + NaturalArray natural = (NaturalArray) array; + return natural.getMultiRange(multiRange); + } else if (array instanceof List) { + List list = (List) array; + return getListMultiRange(list, multiRange); + } + throw new ProgramError("Cannot apply multi-range index to " + + (array != null ? array.getClass().getSimpleName() : "null")); + } + + @SuppressWarnings("unchecked") + public Object applyTupleIndices(Object array, List indices) { + Object current = array; + for (Object rawIndex : indices) { + Object indexObj = typeSystem.unwrap(rawIndex); + if (RangeObjects.isRangeSpec(indexObj)) { + current = applyRangeIndex(current, indexObj); + continue; + } + if (RangeObjects.isMultiRangeSpec(indexObj)) { + current = applyMultiRangeIndex(current, indexObj); + continue; + } + if (current instanceof NaturalArray) { + NaturalArray natural = (NaturalArray) current; + long idx = expressionHandler.toLongIndex(indexObj); + current = natural.needsConversion() ? natural.get(idx, true) : natural.get(idx); + continue; + } + if (current instanceof List) { + List list = (List) current; + int idx = expressionHandler.toIntIndex(indexObj); + if (idx < 0 || idx >= list.size()) { + throw new ProgramError("Index out of bounds: " + idx + " for array of size " + list.size()); + } + current = list.get(idx); + continue; + } + throw new ProgramError("Invalid array access during multidimensional indexing: expected NaturalArray or List, got " + + (current != null ? current.getClass().getSimpleName() : "null")); + } + return current; + } + + public List getListRange(List list, Object range) { + try { + long start, end; + + start = expressionHandler.toLongIndex(RangeObjects.getStart(range)); + if (start < 0) start = list.size() + start; + + end = expressionHandler.toLongIndex(RangeObjects.getEnd(range)); + if (end < 0) end = list.size() + end; + + long step = expressionHandler.calculateStep(range); + + List result = new ArrayList(); + if (step > 0) { + for (long i = start; i <= end && i < list.size(); i += step) { + result.add(list.get((int) i)); + } + } else if (step < 0) { + for (long i = start; i >= end && i >= 0; i += step) { + result.add(list.get((int) i)); + } + } else { + throw new InternalError("Step cannot be zero - should have been caught earlier"); + } + return result; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("List range extraction failed", e); + } + } + + public List getListMultiRange(List list, Object multiRange) { + try { + List result = new ArrayList(); + for (Object range : RangeObjects.getRanges(multiRange)) { + result.addAll(getListRange(list, range)); + } + return result; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("List multi-range extraction failed", e); + } + } + + public String applyStringRangeIndex(String text, Object range) { + try { + long start = expressionHandler.toLongIndex(RangeObjects.getStart(range)); + long end = expressionHandler.toLongIndex(RangeObjects.getEnd(range)); + long step = expressionHandler.calculateStep(range); + + int length = text.length(); + start = normalizeTextIndex(start, length); + end = normalizeTextIndex(end, length); + + if (start < 0 || start >= length) { + throw new ProgramError("Range start index out of bounds: " + start + " for text of length " + length); + } + if (end < 0 || end >= length) { + throw new ProgramError("Range end index out of bounds: " + end + " for text of length " + length); + } + if (step == 0) { + throw new ProgramError("Range step cannot be zero"); + } + + StringBuilder result = new StringBuilder(); + if (step > 0) { + for (long i = start; i <= end; i += step) { + result.append(text.charAt((int) i)); + } + } else { + for (long i = start; i >= end; i += step) { + result.append(text.charAt((int) i)); + } + } + return result.toString(); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("String range extraction failed", e); + } + } + + public int normalizeTextIndex(int index, int length) { + return (int) normalizeTextIndex((long) index, length); + } + + public long normalizeTextIndex(long index, int length) { + if (index < 0) { + return length + index; + } + return index; + } + + public long calculateRangeStep(Range range) { + if (range == null) { + return 1L; + } + + if (range.step != null) { + Object stepObj = dispatcher.dispatch(range.step); + return expressionHandler.toLong(stepObj); + } + + Object startObj = dispatcher.dispatch(range.start); + Object endObj = dispatcher.dispatch(range.end); + long start = expressionHandler.toLong(startObj); + long end = expressionHandler.toLong(endObj); + + return (start < end) ? 1L : -1L; + } + + private boolean shouldContinueAdditive( + AutoStackingNumber current, AutoStackingNumber end, AutoStackingNumber step, boolean increasing) { + return increasing ? current.compareTo(end) <= 0 : current.compareTo(end) >= 0; + } + + private void validateFactor(AutoStackingNumber factor, String operation) { + if (factor.compareTo(AutoStackingNumber.zero(1)) <= 0) { + throw new ProgramError("Factor must be positive"); + } + } + + private boolean shouldContinueMultiplicative( + AutoStackingNumber current, AutoStackingNumber start, AutoStackingNumber end, + AutoStackingNumber factor, String operation) { + int startEndComparison = start.compareTo(end); + if (operation.equals("*")) { + return factor.compareTo(AutoStackingNumber.one(1)) > 0 + ? (startEndComparison < 0 ? current.compareTo(end) <= 0 : current.compareTo(end) >= 0) + : (startEndComparison > 0 ? current.compareTo(end) >= 0 : current.compareTo(end) <= 0); + } else { + return factor.compareTo(AutoStackingNumber.one(1)) > 0 + ? (startEndComparison > 0 ? current.compareTo(end) >= 0 : current.compareTo(end) <= 0) + : (startEndComparison < 0 ? current.compareTo(end) <= 0 : current.compareTo(end) >= 0); + } + } + + private Object convertToAppropriateType(AutoStackingNumber value, Object original) { + if ((original instanceof Integer || original instanceof Long || + original instanceof IntLiteral) && value.fitsInStacks(1)) { + try { + return (int) value.longValue(); + } catch (ArithmeticException e) { + return value.longValue(); + } + } + return value; + } +} diff --git a/src/main/java/cod/interpreter/handler/ContextHandler.java b/src/main/java/cod/interpreter/handler/ContextHandler.java new file mode 100644 index 00000000..6b26f96d --- /dev/null +++ b/src/main/java/cod/interpreter/handler/ContextHandler.java @@ -0,0 +1,89 @@ +package cod.interpreter.handler; + +import cod.ast.node.NoneLiteral; +import cod.ast.node.Type; +import cod.error.InternalError; +import cod.error.ProgramError; +import cod.interpreter.Interpreter; +import cod.interpreter.context.ExecutionContext; +import cod.range.NaturalArray; +import cod.ast.node.Range; + +import java.util.List; +import java.util.Map; + +public class ContextHandler { + private final Interpreter interpreter; + private Type internalRangeSpecType; + private Type internalMultiRangeSpecType; + + public ContextHandler(Interpreter interpreter) { + if (interpreter == null) { + throw new InternalError("ContextHandler constructed with null interpreter"); + } + this.interpreter = interpreter; + } + + public Object createNoneValue() { + return new NoneLiteral(); + } + + public Range getRangeFromArray(NaturalArray arr) { + try { + java.lang.reflect.Field rangeField = NaturalArray.class.getDeclaredField("baseRange"); + rangeField.setAccessible(true); + return (Range) rangeField.get(arr); + } catch (Exception e) { + return null; + } + } + + public Type resolveInternalRangeSpecType() { + if (internalRangeSpecType != null) { + return internalRangeSpecType; + } + try { + Type type = interpreter.getImportResolver().resolveImport("internal.range.RangeSpec"); + if (type == null) { + throw new ProgramError("Unable to load internal.range.RangeSpec"); + } + internalRangeSpecType = type; + return type; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Failed loading internal.range.RangeSpec", e); + } + } + + public Type resolveInternalMultiRangeSpecType() { + if (internalMultiRangeSpecType != null) { + return internalMultiRangeSpecType; + } + try { + Type type = interpreter.getImportResolver().resolveImport("internal.range.MultiRangeSpec"); + if (type == null) { + throw new ProgramError("Unable to load internal.range.MultiRangeSpec"); + } + internalMultiRangeSpecType = type; + return type; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Failed loading internal.range.MultiRangeSpec", e); + } + } + + public boolean isVariableDeclaredInAnyScope(ExecutionContext ctx, String name) { + if (ctx == null || name == null) return false; + List> localsStack = ctx.getLocalsStack(); + if (localsStack == null) return false; + for (int i = localsStack.size() - 1; i >= 0; i--) { + Map scope = localsStack.get(i); + if (scope != null && scope.containsKey(name)) { + return true; + } + } + return false; + } +} diff --git a/src/main/java/cod/interpreter/handler/LambdaInvokingHandler.java b/src/main/java/cod/interpreter/handler/LambdaInvokingHandler.java new file mode 100644 index 00000000..15935280 --- /dev/null +++ b/src/main/java/cod/interpreter/handler/LambdaInvokingHandler.java @@ -0,0 +1,506 @@ +package cod.interpreter.handler; + +import cod.ast.node.*; +import cod.error.InternalError; +import cod.error.ProgramError; +import cod.interpreter.InterpreterVisitor; +import cod.interpreter.TailCallSignal; +import cod.interpreter.context.ExecutionContext; +import cod.interpreter.context.LambdaClosure; + +import java.util.*; + +public class LambdaInvokingHandler { + private final TypeHandler typeSystem; + private final InterpreterVisitor dispatcher; + + public LambdaInvokingHandler(TypeHandler typeSystem, InterpreterVisitor dispatcher) { + if (typeSystem == null) { + throw new InternalError("LambdaInvokingHandler constructed with null typeSystem"); + } + if (dispatcher == null) { + throw new InternalError("LambdaInvokingHandler constructed with null dispatcher"); + } + this.typeSystem = typeSystem; + this.dispatcher = dispatcher; + } + + public LambdaClosure createLambdaClosure(Lambda node, ExecutionContext ctx) { + if (node == null) { + throw new InternalError("createLambdaClosure called with null node"); + } + if (ctx == null) { + throw new InternalError("createLambdaClosure called with null context"); + } + + Map captured = new HashMap(); + for (int i = 0; i < ctx.getScopeDepth(); i++) { + Map scope = ctx.getScope(i); + if (scope != null) { + captured.putAll(scope); + } + } + + return new LambdaClosure( + node, + captured, + ctx.objectInstance, + ctx.currentClass, + ctx.currentLambdaClosure, + Collections.emptyList()); + } + + public Object invokeLambdaCallback( + Object callbackObj, + List args, + ExecutionContext parentCtx, + String ownerMethod) { + + Object callback = typeSystem.unwrap(callbackObj); + LambdaClosure closure; + if (callback instanceof LambdaClosure) { + closure = (LambdaClosure) callback; + } else if (callback instanceof Lambda) { + closure = + new LambdaClosure( + (Lambda) callback, + parentCtx.locals(), + parentCtx.objectInstance, + parentCtx.currentClass, + parentCtx.currentLambdaClosure, + Collections.emptyList()); + } else { + String actualType = callback == null ? "null" : callback.getClass().getSimpleName(); + throw new ProgramError(ownerMethod + " expects a lambda callback, got: " + actualType); + } + + LambdaClosure activeClosure = closure; + List activeIncomingValues = args != null ? args : Collections.emptyList(); + + while (true) { + Lambda lambda = activeClosure.lambda; + List params = resolveLambdaParameters(lambda); + List combinedValues = + mergeBoundAndIncomingLambdaArgs(activeClosure.boundArguments, activeIncomingValues); + + if (shouldAutoCurry(params, combinedValues)) { + return createCurriedLambdaClosure(activeClosure, combinedValues); + } + + int parameterBindCount = Math.min(params.size(), combinedValues.size()); + List values = new ArrayList(combinedValues.subList(0, parameterBindCount)); + List leftoverValues = + new ArrayList(combinedValues.subList(parameterBindCount, combinedValues.size())); + + Map lambdaLocals = + bindLambdaArguments(params, values, activeClosure, ownerMethod); + if (lambda.inferParameters && params.isEmpty()) { + bindPositionalInferredPlaceholderAliases(lambdaLocals, values); + } + + Object result; + try { + if (lambda.expressionBody != null) { + result = evaluateLambdaExpressionBody(lambda, activeClosure, lambdaLocals); + } else { + result = evaluateLambdaBlockBody(lambda, activeClosure, lambdaLocals); + } + } catch (TailCallSignal tailCallSignal) { + if (tailCallSignal.lambdaClosure != null + && tailCallSignal.lambdaClosure == activeClosure) { + activeClosure = tailCallSignal.lambdaClosure; + activeIncomingValues = tailCallSignal.arguments; + continue; + } + throw tailCallSignal; + } + + if (!leftoverValues.isEmpty() && (result instanceof LambdaClosure || result instanceof Lambda)) { + return invokeLambdaCallback(result, leftoverValues, parentCtx, ownerMethod); + } + return result; + } + } + + private List resolveLambdaParameters(Lambda lambda) { + if (lambda == null) { + return new ArrayList(); + } + List params = + lambda.parameters != null ? lambda.parameters : new ArrayList(); + if (!params.isEmpty()) { + return params; + } + if (!lambda.inferParameters) { + return params; + } + + List inferred = inferLambdaParamsFromPlaceholders(lambda); + return inferred; + } + + private List mergeBoundAndIncomingLambdaArgs(List boundArgs, List incomingArgs) { + if ((boundArgs == null || boundArgs.isEmpty()) && (incomingArgs == null || incomingArgs.isEmpty())) { + return Collections.emptyList(); + } + List combined = new ArrayList(); + if (boundArgs != null && !boundArgs.isEmpty()) { + combined.addAll(boundArgs); + } + if (incomingArgs != null && !incomingArgs.isEmpty()) { + combined.addAll(incomingArgs); + } + return combined; + } + + private boolean shouldAutoCurry(List params, List values) { + if (params == null || params.isEmpty()) return false; + int requiredCount = 0; + for (Param param : params) { + if (param == null) continue; + if (!param.hasDefaultValue) { + requiredCount++; + } + } + return values.size() < requiredCount; + } + + private LambdaClosure createCurriedLambdaClosure( + LambdaClosure closure, + List boundArgs) { + + return new LambdaClosure( + closure.lambda, + closure.capturedLocals, + closure.objectInstance, + closure.currentClass, + closure.parentClosure, + boundArgs); + } + + private Map bindLambdaArguments( + List params, + List values, + LambdaClosure closure, + String ownerMethod) { + + Map lambdaLocals = new HashMap(closure.capturedLocals); + for (int i = 0; i < params.size(); i++) { + Param param = params.get(i); + if (param == null || param.name == null) continue; + + Object boundValue = resolveLambdaArgumentValue(i, param, values, closure, lambdaLocals, ownerMethod); + validateLambdaArgumentType(param, boundValue); + lambdaLocals.put(param.name, boundValue); + } + return lambdaLocals; + } + + private Object resolveLambdaArgumentValue( + int index, + Param param, + List values, + LambdaClosure closure, + Map lambdaLocals, + String ownerMethod) { + + if (index < values.size()) { + return values.get(index); + } + if (param.hasDefaultValue && param.defaultValue != null) { + return evaluateLambdaDefaultValue(param, closure, lambdaLocals); + } + throw new ProgramError( + "Missing value for lambda parameter '" + param.name + "' in " + ownerMethod + " callback"); + } + + private Object evaluateLambdaDefaultValue( + Param param, + LambdaClosure closure, + Map lambdaLocals) { + + ExecutionContext defaultCtx = + new ExecutionContext(closure.objectInstance, lambdaLocals, null, null, typeSystem); + defaultCtx.currentClass = closure.currentClass; + defaultCtx.currentLambdaClosure = closure; + dispatcher.pushContext(defaultCtx); + try { + return dispatcher.visit((Base) param.defaultValue); + } finally { + dispatcher.popContext(); + } + } + + private void validateLambdaArgumentType(Param param, Object boundValue) { + if (param.type != null && !typeSystem.validateType(param.type, boundValue)) { + throw new ProgramError( + "Lambda parameter type mismatch for '" + param.name + "'. Expected " + + param.type + ", got: " + typeSystem.getConcreteType(boundValue)); + } + } + + private Object evaluateLambdaExpressionBody( + Lambda lambda, + LambdaClosure closure, + Map lambdaLocals) { + + ExecutionContext exprCtx = + new ExecutionContext(closure.objectInstance, lambdaLocals, null, null, typeSystem); + exprCtx.currentClass = closure.currentClass; + exprCtx.currentLambdaClosure = closure; + dispatcher.pushContext(exprCtx); + try { + return dispatcher.dispatch(lambda.expressionBody); + } finally { + dispatcher.popContext(); + } + } + + private void bindPositionalInferredPlaceholderAliases( + Map lambdaLocals, + List values) { + + if (values == null || values.isEmpty()) return; + Object first = values.get(0); + putIfAbsent(lambdaLocals, "$item", first); + putIfAbsent(lambdaLocals, "$left", first); + putIfAbsent(lambdaLocals, "$acc", first); + putIfAbsent(lambdaLocals, "$value", first); + + if (values.size() > 1) { + Object second = values.get(1); + putIfAbsent(lambdaLocals, "$index", second); + putIfAbsent(lambdaLocals, "$right", second); + putIfAbsent(lambdaLocals, "$next", second); + } + + if (values.size() > 2) { + Object third = values.get(2); + putIfAbsent(lambdaLocals, "$index", third); + putIfAbsent(lambdaLocals, "$position", third); + } + } + + private void putIfAbsent(Map lambdaLocals, String name, Object value) { + if (!lambdaLocals.containsKey(name)) { + lambdaLocals.put(name, value); + } + } + + private Object evaluateLambdaBlockBody( + Lambda lambda, + LambdaClosure closure, + Map lambdaLocals) { + + List lambdaSlots = + lambda.returnSlots != null ? lambda.returnSlots : new ArrayList(); + if (lambdaSlots.isEmpty()) { + throw new ProgramError( + "Lambda with explicit body requires a return contract (::). " + + "Use expression body syntax for implicit return values."); + } + + Map slotValues = new LinkedHashMap(); + Map slotTypes = new LinkedHashMap(); + for (Slot slot : lambdaSlots) { + slotValues.put(slot.name, null); + slotTypes.put(slot.name, slot.type); + } + + ExecutionContext lambdaCtx = + new ExecutionContext(closure.objectInstance, lambdaLocals, slotValues, slotTypes, typeSystem); + lambdaCtx.currentClass = closure.currentClass; + lambdaCtx.currentLambdaClosure = closure; + dispatcher.pushContext(lambdaCtx); + try { + if (lambda.body != null) { + dispatcher.visit((Base) lambda.body); + } + } catch (cod.interpreter.exception.EarlyExitException e) { + // normal lambda early exit + } finally { + dispatcher.popContext(); + } + + if (lambdaSlots.size() == 1) { + return slotValues.get(lambdaSlots.get(0).name); + } + return slotValues; + } + + private List inferLambdaParamsFromPlaceholders(Lambda lambda) { + if (lambda == null) { + return new ArrayList(); + } + LinkedHashSet names = new LinkedHashSet(); + + if (lambda.expressionBody != null) { + collectPlaceholderNames(lambda.expressionBody, names); + } else if (lambda.body != null) { + collectPlaceholderNames(lambda.body, names); + } + + List params = new ArrayList(); + for (String name : names) { + Param param = new Param(); + param.name = name; + param.type = null; + param.typeInferred = true; + param.isLambdaParameter = true; + params.add(param); + } + return params; + } + + private void collectPlaceholderNames(Base node, LinkedHashSet names) { + if (node == null) return; + + if (node instanceof Identifier) { + String name = ((Identifier) node).name; + if (name != null && name.startsWith("$") && name.length() > 1) { + names.add(name); + } + return; + } + + if (node instanceof BinaryOp) { + BinaryOp n = (BinaryOp) node; + collectPlaceholderNames(n.left, names); + collectPlaceholderNames(n.right, names); + return; + } + if (node instanceof Unary) { + collectPlaceholderNames(((Unary) node).operand, names); + return; + } + if (node instanceof TypeCast) { + collectPlaceholderNames(((TypeCast) node).expression, names); + return; + } + if (node instanceof MethodCall) { + MethodCall n = (MethodCall) node; + if (n.arguments != null) { + for (Expr arg : n.arguments) { + collectPlaceholderNames(arg, names); + } + } + if (n.target != null) { + collectPlaceholderNames(n.target, names); + } + return; + } + if (node instanceof PropertyAccess) { + PropertyAccess n = (PropertyAccess) node; + collectPlaceholderNames(n.left, names); + collectPlaceholderNames(n.right, names); + return; + } + if (node instanceof IndexAccess) { + IndexAccess n = (IndexAccess) node; + collectPlaceholderNames(n.array, names); + collectPlaceholderNames(n.index, names); + return; + } + if (node instanceof Array) { + Array n = (Array) node; + if (n.elements != null) { + for (Expr elem : n.elements) { + collectPlaceholderNames(elem, names); + } + } + return; + } + if (node instanceof Tuple) { + Tuple n = (Tuple) node; + if (n.elements != null) { + for (Expr elem : n.elements) { + collectPlaceholderNames(elem, names); + } + } + return; + } + if (node instanceof ExprIf) { + ExprIf n = (ExprIf) node; + collectPlaceholderNames(n.condition, names); + collectPlaceholderNames(n.thenExpr, names); + collectPlaceholderNames(n.elseExpr, names); + return; + } + if (node instanceof BooleanChain) { + BooleanChain n = (BooleanChain) node; + if (n.expressions != null) { + for (Expr expr : n.expressions) { + collectPlaceholderNames(expr, names); + } + } + return; + } + if (node instanceof EqualityChain) { + EqualityChain n = (EqualityChain) node; + collectPlaceholderNames(n.left, names); + if (n.chainArguments != null) { + for (Expr expr : n.chainArguments) { + collectPlaceholderNames(expr, names); + } + } + return; + } + if (node instanceof ChainedComparison) { + ChainedComparison n = (ChainedComparison) node; + if (n.expressions != null) { + for (Expr expr : n.expressions) { + collectPlaceholderNames(expr, names); + } + } + return; + } + if (node instanceof ValueExpr) { + Object value = ((ValueExpr) node).getValue(); + if (value instanceof Base) { + collectPlaceholderNames((Base) value, names); + } + return; + } + if (node instanceof Lambda) { + return; + } + + if (node instanceof Block) { + Block n = (Block) node; + if (n.statements != null) { + for (Stmt stmt : n.statements) { + collectPlaceholderNames(stmt, names); + } + } + return; + } + if (node instanceof SlotAssignment) { + collectPlaceholderNames(((SlotAssignment) node).value, names); + return; + } + if (node instanceof MultipleSlotAssignment) { + MultipleSlotAssignment n = (MultipleSlotAssignment) node; + if (n.assignments != null) { + for (SlotAssignment asg : n.assignments) { + collectPlaceholderNames(asg, names); + } + } + return; + } + if (node instanceof Assignment) { + Assignment n = (Assignment) node; + collectPlaceholderNames(n.left, names); + collectPlaceholderNames(n.right, names); + return; + } + if (node instanceof Var) { + collectPlaceholderNames(((Var) node).value, names); + return; + } + if (node instanceof ReturnSlotAssignment) { + ReturnSlotAssignment n = (ReturnSlotAssignment) node; + collectPlaceholderNames(n.methodCall, names); + collectPlaceholderNames(n.lambda, names); + } + } +} diff --git a/src/main/java/cod/interpreter/handler/LoopOptimizationHandler.java b/src/main/java/cod/interpreter/handler/LoopOptimizationHandler.java new file mode 100644 index 00000000..06916039 --- /dev/null +++ b/src/main/java/cod/interpreter/handler/LoopOptimizationHandler.java @@ -0,0 +1,875 @@ +package cod.interpreter.handler; + +import cod.ast.ASTFactory; +import cod.ast.node.*; +import cod.debug.DebugSystem; +import cod.error.InternalError; +import cod.error.ProgramError; +import cod.interpreter.InterpreterVisitor; +import cod.interpreter.TailCallSignal; +import cod.interpreter.context.ExecutionContext; +import cod.math.AutoStackingNumber; +import cod.range.ArrayTracker; +import cod.range.NaturalArray; +import cod.range.pattern.ConditionalPattern; +import cod.range.pattern.OutputAwarePattern; +import cod.range.pattern.SequencePattern; +import cod.range.formula.ConditionalFormula; +import cod.range.formula.SequenceFormula; + +import java.util.*; + +public class LoopOptimizationHandler { + private static final int LAZY_THRESHOLD = 10; + private static final int MAX_SUPPORTED_LAG = 64; + + private final InterpreterVisitor dispatcher; + private final TypeHandler typeSystem; + private final ExpressionHandler expressionHandler; + private final ArrayOperationHandler arrayOperationHandler; + private final PatternHandler patternHandler; + + public LoopOptimizationHandler( + InterpreterVisitor dispatcher, + TypeHandler typeSystem, + ExpressionHandler expressionHandler, + ArrayOperationHandler arrayOperationHandler, + PatternHandler patternHandler) { + if (dispatcher == null) throw new InternalError("LoopOptimizationHandler dispatcher is null"); + if (typeSystem == null) throw new InternalError("LoopOptimizationHandler typeSystem is null"); + if (expressionHandler == null) throw new InternalError("LoopOptimizationHandler expressionHandler is null"); + if (arrayOperationHandler == null) throw new InternalError("LoopOptimizationHandler arrayOperationHandler is null"); + if (patternHandler == null) throw new InternalError("LoopOptimizationHandler patternHandler is null"); + this.dispatcher = dispatcher; + this.typeSystem = typeSystem; + this.expressionHandler = expressionHandler; + this.arrayOperationHandler = arrayOperationHandler; + this.patternHandler = patternHandler; + } + + public Object executeForLoop(For node) { + if (node == null) { + throw new InternalError("visit(For) called with null node"); + } + + ExecutionContext ctx = dispatcher.getCurrentContext(); + int originalDepth = ctx.getScopeDepth(); + + long loopSize = estimateLoopSize(node, ctx); + boolean hasSideEffects = hasSideEffects(node.body); + + boolean useLazyExecution = shouldUseLazyExecution(loopSize, hasSideEffects); + + int loopId = ArrayTracker.beginLoop(node); + + ArrayTracker.setLoopSize(loopId, loopSize); + ArrayTracker.setSideEffects(loopId, hasSideEffects); + + try { + ctx.pushScope(); + + if (useLazyExecution) { + Object result = tryOptimizedExecution(node, loopId); + if (result != null) { + return result; + } + } else { + DebugSystem.debug("LOOP", + String.format("Skipping optimization: size=%d, sideEffects=%s", + loopSize, hasSideEffects)); + } + + ArrayTracker.incrementIteration(); + + if (node.range != null) { + return arrayOperationHandler.executeRangeLoop(ctx, node, node.iterator); + } else if (node.arraySource != null) { + Object arrayObj = dispatcher.dispatch(node.arraySource); + arrayObj = typeSystem.unwrap(arrayObj); + return arrayOperationHandler.executeArrayLoop(ctx, node, node.iterator, arrayObj); + } + throw new ProgramError("Invalid for loop: neither range nor array source specified"); + + } catch (ProgramError e) { + throw e; + } catch (TailCallSignal e) { + throw e; + } catch (Exception e) { + throw new InternalError("For loop execution failed", e); + } finally { + ArrayTracker.LoopStats stats = ArrayTracker.endLoop(); + if (stats != null) { + DebugSystem.debug("LOOP", stats.toString()); + } + + while (ctx.getScopeDepth() > originalDepth) { + ctx.popScope(); + } + } + } + + public boolean shouldUseLazyExecution(long loopSize, boolean hasSideEffects) { + if (loopSize < 0) { + return false; + } + + if (loopSize < LAZY_THRESHOLD) { + return !hasSideEffects; + } + + return true; + } + + public long estimateLoopSize(For node, ExecutionContext ctx) { + try { + if (node.range != null) { + Object startObj = dispatcher.dispatch(node.range.start); + Object endObj = dispatcher.dispatch(node.range.end); + + startObj = typeSystem.unwrap(startObj); + endObj = typeSystem.unwrap(endObj); + + AutoStackingNumber start = typeSystem.toAutoStackingNumber(startObj); + AutoStackingNumber end = typeSystem.toAutoStackingNumber(endObj); + + AutoStackingNumber step; + if (node.range.step != null) { + Object stepObj = dispatcher.dispatch(node.range.step); + step = typeSystem.toAutoStackingNumber(typeSystem.unwrap(stepObj)); + } else { + step = (start.compareTo(end) > 0) ? + AutoStackingNumber.minusOne(1) : AutoStackingNumber.one(1); + } + + if (step.isZero()) return 0; + + AutoStackingNumber diff = end.subtract(start); + AutoStackingNumber steps = diff.divide(step); + AutoStackingNumber size = steps.add(AutoStackingNumber.one(1)); + + return size.longValue(); + + } else if (node.arraySource != null) { + Object arrayObj = dispatcher.dispatch(node.arraySource); + arrayObj = typeSystem.unwrap(arrayObj); + + if (arrayObj instanceof NaturalArray) { + NaturalArray arr = (NaturalArray) arrayObj; + if (arr.hasPendingUpdates()) { + arr.commitUpdates(); + } + return arr.size(); + } else if (arrayObj instanceof List) { + return ((List) arrayObj).size(); + } + } + } catch (Exception e) { + DebugSystem.debug("LOOP", "Failed to estimate size: " + e.getMessage()); + } + + return -1; + } + + public boolean hasSideEffects(Block body) { + if (body == null || body.statements == null) return false; + + for (Stmt stmt : body.statements) { + if (stmt instanceof MethodCall) { + MethodCall call = (MethodCall) stmt; + if ("out".equals(call.name) || "outs".equals(call.name) || "in".equals(call.name)) { + return true; + } + return true; + } + + if (stmt instanceof StmtIf) { + StmtIf ifStmt = (StmtIf) stmt; + if (hasSideEffects(ifStmt.thenBlock) || hasSideEffects(ifStmt.elseBlock)) { + return true; + } + } + + if (stmt instanceof For) { + return true; + } + + if (stmt instanceof Assignment) { + Assignment assign = (Assignment) stmt; + if (assign.left instanceof PropertyAccess) { + return true; + } + } + } + + return false; + } + + public Object tryOptimizedExecution(For node, int loopId) { + OutputAwarePattern.OutputPattern outputPattern = + OutputAwarePattern.extract(node, node.iterator); + + if (outputPattern.isOptimizable) { + try { + Object result = executeOutputAwareLoop(node, outputPattern); + ArrayTracker.markLoopOptimized(loopId); + return result; + } catch (Exception e) { + DebugSystem.debug("OPTIMIZER", "Output pattern failed: " + e.getMessage()); + } + } + + List multiArrayPatterns = extractMultiArraySequencePatterns(node); + if (!multiArrayPatterns.isEmpty()) { + try { + Object result = patternHandler.applyPatterns(node, multiArrayPatterns); + ArrayTracker.markLoopOptimized(loopId); + return result; + } catch (Exception e) { + DebugSystem.debug("OPTIMIZER", "Multi-array pattern failed: " + e.getMessage()); + } + } + + PatternHandler.LinearRecurrencePattern recurrencePattern = extractLinearRecurrencePattern(node); + if (recurrencePattern != null) { + try { + List patterns = new ArrayList(); + patterns.add(new PatternHandler.PatternResult(PatternHandler.PatternType.LINEAR_RECURRENCE, recurrencePattern, recurrencePattern.targetArray)); + Object result = patternHandler.applyPatterns(node, patterns); + ArrayTracker.markLoopOptimized(loopId); + return result; + } catch (Exception e) { + DebugSystem.debug("OPTIMIZER", "Linear recurrence pattern failed: " + e.getMessage()); + } + } + + SequencePattern.Pattern seqPattern = + SequencePattern.extract(node.body.statements, node.iterator); + if (seqPattern != null && seqPattern.isOptimizable()) { + try { + List patterns = new ArrayList(); + patterns.add(new PatternHandler.PatternResult(PatternHandler.PatternType.SEQUENCE, seqPattern, seqPattern.targetArray)); + Object result = patternHandler.applyPatterns(node, patterns); + ArrayTracker.markLoopOptimized(loopId); + return result; + } catch (Exception e) { + DebugSystem.debug("OPTIMIZER", "Sequence pattern failed: " + e.getMessage()); + } + } + + List allPatterns = new ArrayList(); + for (Stmt stmt : node.body.statements) { + if (stmt instanceof StmtIf) { + StmtIf ifStmt = (StmtIf) stmt; + List patterns = extractConditionalPatterns(ifStmt, node.iterator); + for (ConditionalPattern pattern : patterns) { + if (pattern != null && pattern.isOptimizable()) { + allPatterns.add(new PatternHandler.PatternResult(PatternHandler.PatternType.CONDITIONAL, pattern, pattern.array)); + } + } + } + } + + if (!allPatterns.isEmpty()) { + try { + Object result = patternHandler.applyPatterns(node, allPatterns); + ArrayTracker.markLoopOptimized(loopId); + return result; + } catch (Exception e) { + DebugSystem.debug("OPTIMIZER", "Conditional pattern failed: " + e.getMessage()); + } + } + + return null; + } + + public PatternHandler.LinearRecurrencePattern extractLinearRecurrencePattern(For node) { + if (node == null || node.body == null || node.body.statements == null) { + return null; + } + if (node.body.statements.size() != 1) { + return null; + } + if (!(node.body.statements.get(0) instanceof Assignment)) { + return null; + } + Assignment assign = (Assignment) node.body.statements.get(0); + if (!(assign.left instanceof IndexAccess)) { + return null; + } + IndexAccess leftAccess = (IndexAccess) assign.left; + if (!(leftAccess.array instanceof Identifier) || !(leftAccess.index instanceof Identifier)) { + return null; + } + String iter = node.iterator; + Identifier idx = (Identifier) leftAccess.index; + if (!iter.equals(idx.name)) { + return null; + } + + Object resolved = dispatcher.dispatch(leftAccess.array); + resolved = typeSystem.unwrap(resolved); + if (!(resolved instanceof NaturalArray)) { + return null; + } + NaturalArray targetArray = (NaturalArray) resolved; + + Set deps = new HashSet(); + collectIndexedArrayRefs(assign.right, iter, deps); + String targetName = ((Identifier) leftAccess.array).name; + if (!deps.contains(targetName)) { + return null; + } + for (String dep : deps) { + if (!targetName.equals(dep)) { + return null; + } + } + + AutoStackingNumber[] coeff = new AutoStackingNumber[MAX_SUPPORTED_LAG + 1]; + for (int i = 0; i < coeff.length; i++) coeff[i] = AutoStackingNumber.fromLong(0L); + AutoStackingNumber[] constant = new AutoStackingNumber[]{AutoStackingNumber.fromLong(0L)}; + if (!collectLinearTerms(assign.right, targetName, iter, coeff, constant, AutoStackingNumber.fromLong(1L))) { + return null; + } + + int maxLag = 0; + boolean hasAnyLag = false; + for (int lag = 1; lag < coeff.length; lag++) { + if (!coeff[lag].isZero()) { + hasAnyLag = true; + if (lag > maxLag) maxLag = lag; + } + } + if (!hasAnyLag || maxLag <= 0) { + return null; + } + + int order = maxLag; + AutoStackingNumber[] coeffByLag = new AutoStackingNumber[order]; + for (int lag = 1; lag <= order; lag++) { + coeffByLag[lag - 1] = coeff[lag]; + } + + long[] bounds = resolveLoopBounds(node); + if (bounds == null) { + return null; + } + long min = bounds[0]; + long max = bounds[1]; + long recurrenceStart = min; + if (recurrenceStart < order) { + recurrenceStart = order; + } + if (recurrenceStart > max) { + return null; + } + + AutoStackingNumber[] seed = new AutoStackingNumber[order]; + long seedStart = recurrenceStart - order; + for (int i = 0; i < order; i++) { + long idxSeed = seedStart + i; + Object vObj = targetArray.get(idxSeed); + AutoStackingNumber v = typeSystem.toAutoStackingNumber(vObj); + if (v == null) { + return null; + } + seed[i] = v; + } + + return new PatternHandler.LinearRecurrencePattern( + leftAccess.array, + order, + coeffByLag, + constant[0], + recurrenceStart, + seedStart, + seed + ); + } + + private boolean collectLinearTerms( + Expr expr, + String targetArrayName, + String iterator, + AutoStackingNumber[] coeffByLag, + AutoStackingNumber[] constant, + AutoStackingNumber sign + ) { + if (expr == null) return false; + + if (expr instanceof BinaryOp) { + BinaryOp bin = (BinaryOp) expr; + if ("+".equals(bin.op)) { + return collectLinearTerms(bin.left, targetArrayName, iterator, coeffByLag, constant, sign) && + collectLinearTerms(bin.right, targetArrayName, iterator, coeffByLag, constant, sign); + } + if ("-".equals(bin.op)) { + return collectLinearTerms(bin.left, targetArrayName, iterator, coeffByLag, constant, sign) && + collectLinearTerms(bin.right, targetArrayName, iterator, coeffByLag, constant, sign.multiply(AutoStackingNumber.fromLong(-1L))); + } + if ("*".equals(bin.op)) { + TermRef ref = extractIndexedTargetTerm(bin.left, targetArrayName, iterator); + AutoStackingNumber scalar = toNumericLiteral(bin.right); + if (ref == null || scalar == null) { + ref = extractIndexedTargetTerm(bin.right, targetArrayName, iterator); + scalar = toNumericLiteral(bin.left); + } + if (ref != null && scalar != null) { + AutoStackingNumber c = sign.multiply(scalar); + coeffByLag[ref.lag] = coeffByLag[ref.lag].add(c); + return true; + } + return false; + } + return false; + } + + TermRef ref = extractIndexedTargetTerm(expr, targetArrayName, iterator); + if (ref != null) { + coeffByLag[ref.lag] = coeffByLag[ref.lag].add(sign); + return true; + } + + AutoStackingNumber literal = toNumericLiteral(expr); + if (literal != null) { + constant[0] = constant[0].add(sign.multiply(literal)); + return true; + } + + return false; + } + + private static class TermRef { + final int lag; + TermRef(int lag) { this.lag = lag; } + } + + private TermRef extractIndexedTargetTerm(Expr expr, String targetArrayName, String iterator) { + if (!(expr instanceof IndexAccess)) { + return null; + } + IndexAccess access = (IndexAccess) expr; + if (!(access.array instanceof Identifier)) { + return null; + } + String arrayName = ((Identifier) access.array).name; + if (!targetArrayName.equals(arrayName)) { + return null; + } + int lag = extractLag(access.index, iterator); + if (lag <= 0 || lag > MAX_SUPPORTED_LAG) { + return null; + } + return new TermRef(lag); + } + + private int extractLag(Expr indexExpr, String iterator) { + if (indexExpr instanceof BinaryOp) { + BinaryOp bin = (BinaryOp) indexExpr; + if ("-".equals(bin.op) && bin.left instanceof Identifier && + iterator.equals(((Identifier) bin.left).name)) { + AutoStackingNumber n = toNumericLiteral(bin.right); + if (n == null) return -1; + long lag = n.longValue(); + if (lag <= 0 || lag > Integer.MAX_VALUE) return -1; + return (int) lag; + } + } + return -1; + } + + private AutoStackingNumber toNumericLiteral(Expr expr) { + if (expr instanceof IntLiteral) { + return ((IntLiteral) expr).value; + } + if (expr instanceof FloatLiteral) { + return ((FloatLiteral) expr).value; + } + if (expr instanceof Unary) { + Unary unary = (Unary) expr; + if ("-".equals(unary.op)) { + AutoStackingNumber inner = toNumericLiteral(unary.operand); + if (inner == null) return null; + return AutoStackingNumber.fromLong(0L).subtract(inner); + } + if ("+".equals(unary.op)) { + return toNumericLiteral(unary.operand); + } + } + return null; + } + + private long[] resolveLoopBounds(For node) { + if (node == null) return null; + if (node.range != null) { + Object startObj = dispatcher.dispatch(node.range.start); + Object endObj = dispatcher.dispatch(node.range.end); + long start = expressionHandler.toLong(startObj); + long end = expressionHandler.toLong(endObj); + return new long[]{Math.min(start, end), Math.max(start, end)}; + } + if (node.arraySource != null) { + Object sourceObj = dispatcher.dispatch(node.arraySource); + sourceObj = typeSystem.unwrap(sourceObj); + if (sourceObj instanceof NaturalArray) { + NaturalArray sourceArr = (NaturalArray) sourceObj; + if (sourceArr.size() > 0) { + return new long[]{0L, sourceArr.size() - 1L}; + } + } else if (sourceObj instanceof List) { + List list = (List) sourceObj; + if (!list.isEmpty()) { + return new long[]{0L, list.size() - 1L}; + } + } + } + return null; + } + + public List extractMultiArraySequencePatterns(For node) { + List results = new ArrayList(); + if (node == null || node.body == null || node.body.statements == null) { + return results; + } + + List statements = node.body.statements; + if (statements.size() < 2) { + return results; + } + + List orderedTargets = new ArrayList(); + List orderedAssignments = new ArrayList(); + + for (Stmt stmt : statements) { + if (!(stmt instanceof Assignment)) { + return new ArrayList(); + } + + Assignment assign = (Assignment) stmt; + if (assign.isDeclaration || !(assign.left instanceof IndexAccess)) { + return new ArrayList(); + } + + IndexAccess indexAccess = (IndexAccess) assign.left; + if (!(indexAccess.array instanceof Identifier) || !(indexAccess.index instanceof Identifier)) { + return new ArrayList(); + } + + Identifier index = (Identifier) indexAccess.index; + if (!node.iterator.equals(index.name)) { + return new ArrayList(); + } + + String targetName = ((Identifier) indexAccess.array).name; + if (orderedTargets.contains(targetName)) { + return new ArrayList(); + } + + orderedTargets.add(targetName); + orderedAssignments.add(assign); + } + + for (int i = 0; i < orderedAssignments.size(); i++) { + Assignment assign = orderedAssignments.get(i); + IndexAccess indexAccess = (IndexAccess) assign.left; + Identifier targetArray = (Identifier) indexAccess.array; + + Set refs = new HashSet(); + collectIndexedArrayRefs(assign.right, node.iterator, refs); + + for (String ref : refs) { + int refIndex = orderedTargets.indexOf(ref); + if (refIndex == -1 || refIndex > i) { + return new ArrayList(); + } + } + + List steps = new ArrayList(); + steps.add(new SequencePattern.Step(null, assign.right)); + SequencePattern.Pattern pattern = new SequencePattern.Pattern(steps, targetArray, node.iterator); + results.add(new PatternHandler.PatternResult(PatternHandler.PatternType.SEQUENCE, pattern, targetArray)); + } + + return results; + } + + private void collectIndexedArrayRefs(Expr expr, String iterator, Set refs) { + if (expr == null || refs == null) { + return; + } + + if (expr instanceof IndexAccess) { + IndexAccess access = (IndexAccess) expr; + if (access.array instanceof Identifier && access.index instanceof Identifier) { + Identifier idx = (Identifier) access.index; + if (iterator.equals(idx.name)) { + refs.add(((Identifier) access.array).name); + } + } + collectIndexedArrayRefs(access.array, iterator, refs); + collectIndexedArrayRefs(access.index, iterator, refs); + return; + } + + if (expr instanceof BinaryOp) { + BinaryOp bin = (BinaryOp) expr; + collectIndexedArrayRefs(bin.left, iterator, refs); + collectIndexedArrayRefs(bin.right, iterator, refs); + return; + } + + if (expr instanceof Unary) { + collectIndexedArrayRefs(((Unary) expr).operand, iterator, refs); + return; + } + + if (expr instanceof MethodCall) { + MethodCall call = (MethodCall) expr; + if (call.arguments != null) { + for (Expr arg : call.arguments) { + collectIndexedArrayRefs(arg, iterator, refs); + } + } + return; + } + + if (expr instanceof TypeCast) { + collectIndexedArrayRefs(((TypeCast) expr).expression, iterator, refs); + return; + } + + if (expr instanceof PropertyAccess) { + PropertyAccess prop = (PropertyAccess) expr; + collectIndexedArrayRefs(prop.left, iterator, refs); + collectIndexedArrayRefs(prop.right, iterator, refs); + return; + } + + if (expr instanceof Tuple) { + Tuple tuple = (Tuple) expr; + if (tuple.elements != null) { + for (Expr elem : tuple.elements) { + collectIndexedArrayRefs(elem, iterator, refs); + } + } + return; + } + + if (expr instanceof Array) { + Array array = (Array) expr; + if (array.elements != null) { + for (Expr elem : array.elements) { + collectIndexedArrayRefs(elem, iterator, refs); + } + } + } + } + + public Object executeOutputAwareLoop(For node, OutputAwarePattern.OutputPattern pattern) { + ExecutionContext ctx = dispatcher.getCurrentContext(); + + try { + NaturalArray arr = createArrayFromOutputPattern(node, pattern.computation, ctx); + + ctx.enterOptimizedLoop(); + + if (node.range != null) { + executeOutputRangeLoop(ctx, node, arr, pattern.outputCalls); + } else if (node.arraySource != null) { + executeOutputArrayLoop(ctx, node, arr, pattern.outputCalls); + } + return arr; + } finally { + ctx.exitOptimizedLoop(); + } + } + + public NaturalArray createArrayFromOutputPattern(For node, Object computation, ExecutionContext ctx) { + if (computation instanceof SequencePattern.Pattern) { + SequencePattern.Pattern seqPattern = (SequencePattern.Pattern) computation; + + Range range = node.range; + if (range == null && node.arraySource != null) { + Object sourceObj = dispatcher.dispatch(node.arraySource); + sourceObj = typeSystem.unwrap(sourceObj); + + if (sourceObj instanceof NaturalArray) { + NaturalArray sourceArr = (NaturalArray) sourceObj; + long size = sourceArr.size(); + + Expr start = ASTFactory.createIntLiteral(0, null); + Expr end = ASTFactory.createIntLiteral((int)(size - 1), null); + range = ASTFactory.createRange(null, start, end, null, null); + } + } + + if (range == null) { + throw new ProgramError("Cannot create array from pattern: no range specified"); + } + + NaturalArray arr = new NaturalArray(range, dispatcher, ctx); + + if (seqPattern.isSimple()) { + SequenceFormula formula = SequenceFormula.createSimple( + 0, arr.size() - 1, + seqPattern.getFinalExpression(), + node.iterator + ); + arr.addSequenceFormula(formula); + } else { + SequenceFormula formula = SequenceFormula.createFromSequence( + 0, arr.size() - 1, node.iterator, + seqPattern.getTempVarNames(), + seqPattern.getTempExpressions(), + seqPattern.getFinalExpression() + ); + arr.addSequenceFormula(formula); + } + + return arr; + + } else if (computation instanceof ConditionalPattern) { + ConditionalPattern condPattern = (ConditionalPattern) computation; + + Range range = node.range; + if (range == null && node.arraySource != null) { + Object sourceObj = dispatcher.dispatch(node.arraySource); + sourceObj = typeSystem.unwrap(sourceObj); + + if (sourceObj instanceof NaturalArray) { + NaturalArray sourceArr = (NaturalArray) sourceObj; + long size = sourceArr.size(); + + Expr start = ASTFactory.createIntLiteral(0, null); + Expr end = ASTFactory.createIntLiteral((int)(size - 1), null); + range = ASTFactory.createRange(null, start, end, null, null); + } + } + + if (range == null) { + throw new ProgramError("Cannot create array from pattern: no range specified"); + } + + NaturalArray arr = new NaturalArray(range, dispatcher, ctx); + + List conditions = new ArrayList(); + List> branchStatements = new ArrayList>(); + + for (ConditionalPattern.Branch branch : condPattern.branches) { + conditions.add(branch.condition); + branchStatements.add(branch.statements); + } + + ConditionalFormula formula = new ConditionalFormula( + 0, arr.size() - 1, node.iterator, + conditions, + branchStatements, + condPattern.elseStatements + ); + arr.addConditionalFormula(formula); + + return arr; + } + + throw new ProgramError("Unknown computation pattern type"); + } + + public void executeOutputRangeLoop(ExecutionContext ctx, For node, + NaturalArray arr, List outputCalls) { + try { + Object startObj = dispatcher.dispatch(node.range.start); + Object endObj = dispatcher.dispatch(node.range.end); + startObj = typeSystem.unwrap(startObj); + endObj = typeSystem.unwrap(endObj); + + long start = expressionHandler.toLong(startObj); + long end = expressionHandler.toLong(endObj); + long step = arrayOperationHandler.calculateRangeStep(node.range); + + for (long i = start; i <= end; i += step) { + Object value = arr.get(i); + + arr.recordOutput(i, value); + + ctx.setVariable(node.iterator, value); + + for (MethodCall outputCall : outputCalls) { + MethodCall evalCall = new MethodCall(); + evalCall.name = outputCall.name; + evalCall.arguments = new ArrayList(); + + for (Expr arg : outputCall.arguments) { + if (arg instanceof Identifier && + "_".equals(((Identifier) arg).name)) { + evalCall.arguments.add(new ValueExpr(value)); + } else { + evalCall.arguments.add(arg); + } + } + + dispatcher.dispatch(evalCall); + } + } + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Output range loop execution failed", e); + } + } + + public void executeOutputArrayLoop(ExecutionContext ctx, For node, + NaturalArray arr, List outputCalls) { + try { + Object sourceObj = dispatcher.dispatch(node.arraySource); + sourceObj = typeSystem.unwrap(sourceObj); + + long size = 0; + if (sourceObj instanceof NaturalArray) { + size = ((NaturalArray) sourceObj).size(); + } else if (sourceObj instanceof List) { + size = ((List) sourceObj).size(); + } else { + throw new ProgramError("Cannot iterate over: " + + (sourceObj != null ? sourceObj.getClass().getSimpleName() : "null")); + } + + for (long i = 0; i < size; i++) { + Object value = arr.get(i); + + arr.recordOutput(i, value); + + ctx.setVariable(node.iterator, value); + + for (MethodCall outputCall : outputCalls) { + MethodCall evalCall = new MethodCall(); + evalCall.name = outputCall.name; + evalCall.arguments = new ArrayList(); + + for (Expr arg : outputCall.arguments) { + if (arg instanceof Identifier && + "_".equals(((Identifier) arg).name)) { + evalCall.arguments.add(new ValueExpr(value)); + } else { + evalCall.arguments.add(arg); + } + } + + dispatcher.dispatch(evalCall); + } + } + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Output array loop execution failed", e); + } + } + + public List extractConditionalPatterns(StmtIf ifStmt, String iterator) { + try { + return ConditionalPattern.extractAll(ifStmt, iterator); + } catch (Exception e) { + DebugSystem.debug("OPTIMIZER", "Failed to extract conditional pattern: " + e.getMessage()); + return new ArrayList(); + } + } +} diff --git a/src/main/java/cod/interpreter/handler/PatternHandler.java b/src/main/java/cod/interpreter/handler/PatternHandler.java new file mode 100644 index 00000000..1d2c6a24 --- /dev/null +++ b/src/main/java/cod/interpreter/handler/PatternHandler.java @@ -0,0 +1,288 @@ +package cod.interpreter.handler; + +import cod.ast.node.*; +import cod.debug.DebugSystem; +import cod.error.InternalError; +import cod.error.ProgramError; +import cod.interpreter.InterpreterVisitor; +import cod.math.AutoStackingNumber; +import cod.range.NaturalArray; +import cod.range.formula.ConditionalFormula; +import cod.range.formula.LinearRecurrenceFormula; +import cod.range.formula.SequenceFormula; +import cod.range.pattern.ConditionalPattern; +import cod.range.pattern.SequencePattern; + +import java.util.*; + +public class PatternHandler { + public enum PatternType { + CONDITIONAL, + SEQUENCE, + LINEAR_RECURRENCE + } + + public static class PatternResult { + public final PatternType type; + public final Object pattern; + public final Expr targetArray; + + public PatternResult(PatternType type, Object pattern, Expr targetArray) { + if (type == null) { + throw new InternalError("PatternResult constructed with null type"); + } + this.type = type; + this.pattern = pattern; + this.targetArray = targetArray; + } + } + + public static class LinearRecurrencePattern { + public final Expr targetArray; + public final int order; + public final AutoStackingNumber[] coefficientsByLag; + public final AutoStackingNumber constantTerm; + public final long recurrenceStart; + public final long seedStart; + public final AutoStackingNumber[] seedValues; + + public LinearRecurrencePattern( + Expr targetArray, + int order, + AutoStackingNumber[] coefficientsByLag, + AutoStackingNumber constantTerm, + long recurrenceStart, + long seedStart, + AutoStackingNumber[] seedValues + ) { + this.targetArray = targetArray; + this.order = order; + this.coefficientsByLag = coefficientsByLag; + this.constantTerm = constantTerm; + this.recurrenceStart = recurrenceStart; + this.seedStart = seedStart; + this.seedValues = seedValues; + } + } + + private final InterpreterVisitor dispatcher; + private final TypeHandler typeSystem; + private final ExpressionHandler expressionHandler; + private final ArrayOperationHandler arrayOperationHandler; + + public PatternHandler( + InterpreterVisitor dispatcher, + TypeHandler typeSystem, + ExpressionHandler expressionHandler, + ArrayOperationHandler arrayOperationHandler + ) { + if (dispatcher == null) throw new InternalError("PatternHandler dispatcher is null"); + if (typeSystem == null) throw new InternalError("PatternHandler typeSystem is null"); + if (expressionHandler == null) throw new InternalError("PatternHandler expressionHandler is null"); + if (arrayOperationHandler == null) throw new InternalError("PatternHandler arrayOperationHandler is null"); + this.dispatcher = dispatcher; + this.typeSystem = typeSystem; + this.expressionHandler = expressionHandler; + this.arrayOperationHandler = arrayOperationHandler; + } + + public Object applyPatterns(For node, List patterns) { + if (node == null) { + throw new InternalError("applyPatterns called with null node"); + } + if (patterns == null) { + throw new InternalError("applyPatterns called with null patterns"); + } + + try { + List targetArrays = new ArrayList(); + List> groupedPatterns = new ArrayList>(); + Map arrayIdToGroupIndex = new HashMap(); + + for (PatternResult result : patterns) { + if (result == null || result.targetArray == null) { + continue; + } + + Object resolvedArray = dispatcher.dispatch(result.targetArray); + resolvedArray = typeSystem.unwrap(resolvedArray); + + if (!(resolvedArray instanceof NaturalArray)) { + DebugSystem.debug("OPTIMIZER", "Array not optimizable, falling back to normal execution"); + return arrayOperationHandler.executeForLoopNormally(node); + } + + NaturalArray naturalArray = (NaturalArray) resolvedArray; + int arrayId = naturalArray.getArrayId(); + Integer existingGroup = arrayIdToGroupIndex.get(arrayId); + int groupIndex = existingGroup != null ? existingGroup.intValue() : -1; + + if (groupIndex == -1) { + targetArrays.add(naturalArray); + List newGroup = new ArrayList(); + newGroup.add(result); + groupedPatterns.add(newGroup); + arrayIdToGroupIndex.put(arrayId, targetArrays.size() - 1); + } else { + groupedPatterns.get(groupIndex).add(result); + } + } + + if (targetArrays.isEmpty()) { + DebugSystem.debug("OPTIMIZER", "No target arrays found, falling back to normal execution"); + return arrayOperationHandler.executeForLoopNormally(node); + } + + long start = 0, end = 0; + boolean boundsFound = false; + + if (node.range != null) { + Object startObj = dispatcher.dispatch(node.range.start); + Object endObj = dispatcher.dispatch(node.range.end); + start = expressionHandler.toLong(startObj); + end = expressionHandler.toLong(endObj); + boundsFound = true; + } else if (node.arraySource != null) { + Object sourceObj = dispatcher.dispatch(node.arraySource); + if (sourceObj instanceof NaturalArray) { + NaturalArray sourceArr = (NaturalArray) sourceObj; + if (sourceArr.size() > 0) { + start = 0; + end = sourceArr.size() - 1; + boundsFound = true; + } + } + } + + if (!boundsFound) { + DebugSystem.debug("OPTIMIZER", "Could not determine bounds, falling back to normal execution"); + return arrayOperationHandler.executeForLoopNormally(node); + } + + long min = Math.min(start, end); + long max = Math.max(start, end); + + for (int arrayIndex = 0; arrayIndex < targetArrays.size(); arrayIndex++) { + NaturalArray arr = targetArrays.get(arrayIndex); + List arrayPatterns = groupedPatterns.get(arrayIndex); + + for (PatternResult result : arrayPatterns) { + if (result.type == PatternType.SEQUENCE) { + applySequencePattern(arr, (SequencePattern.Pattern) result.pattern, min, max, node.iterator); + } else if (result.type == PatternType.CONDITIONAL) { + applyConditionalPattern(arr, (ConditionalPattern) result.pattern, min, max, node.iterator); + } else if (result.type == PatternType.LINEAR_RECURRENCE) { + applyLinearRecurrencePattern(arr, (LinearRecurrencePattern) result.pattern, min, max, node.iterator); + } + } + } + + return targetArrays.get(targetArrays.size() - 1); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Pattern application failed, falling back to normal execution", e); + } + } + + public void applyConditionalPattern(NaturalArray arr, ConditionalPattern pattern, + long min, long max, String iterator) { + if (pattern == null) { + throw new InternalError("applyConditionalPattern called with null pattern"); + } + if (arr == null) { + throw new InternalError("applyConditionalPattern called with null array"); + } + + try { + List conditions = new ArrayList(); + List> branchStatements = new ArrayList>(); + + for (ConditionalPattern.Branch branch : pattern.branches) { + conditions.add(branch.condition); + branchStatements.add(branch.statements); + } + + ConditionalFormula formula = new ConditionalFormula( + min, max, iterator, + conditions, + branchStatements, + pattern.elseStatements + ); + arr.addConditionalFormula(formula); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Failed to apply conditional pattern", e); + } + } + + public void applySequencePattern(NaturalArray arr, + SequencePattern.Pattern pattern, + long min, long max, String iterator) { + if (pattern == null) { + throw new InternalError("applySequencePattern called with null pattern"); + } + if (arr == null) { + throw new InternalError("applySequencePattern called with null array"); + } + + try { + SequenceFormula formula; + + if (pattern.isSimple()) { + formula = SequenceFormula.createSimple(min, max, pattern.getFinalExpression(), iterator); + } else { + formula = SequenceFormula.createFromSequence( + min, max, iterator, + pattern.getTempVarNames(), + pattern.getTempExpressions(), + pattern.getFinalExpression() + ); + } + + arr.addSequenceFormula(formula); + + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Failed to apply sequence pattern", e); + } + } + + public void applyLinearRecurrencePattern( + NaturalArray arr, + LinearRecurrencePattern pattern, + long min, + long max, + String iterator + ) { + if (arr == null) { + throw new InternalError("applyLinearRecurrencePattern called with null array"); + } + if (pattern == null) { + throw new InternalError("applyLinearRecurrencePattern called with null pattern"); + } + try { + long start = Math.max(min, pattern.seedStart); + long end = max; + if (end < start) { + return; + } + LinearRecurrenceFormula formula = new LinearRecurrenceFormula( + start, + end, + pattern.recurrenceStart, + pattern.coefficientsByLag, + pattern.constantTerm, + pattern.seedValues, + pattern.seedStart + ); + arr.addLinearRecurrenceFormula(formula); + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError("Failed to apply linear recurrence pattern", e); + } + } +} diff --git a/src/main/java/cod/ir/IRArtifactCodec.java b/src/main/java/cod/ir/IRArtifactCodec.java index 32fd874b..ed3162e6 100644 --- a/src/main/java/cod/ir/IRArtifactCodec.java +++ b/src/main/java/cod/ir/IRArtifactCodec.java @@ -1,14 +1,14 @@ package cod.ir; import cod.ast.node.Type; -import cod.ptac.CodPTACArtifact; -import cod.ptac.CodPTACFlag; -import cod.ptac.CodPTACFunction; -import cod.ptac.CodPTACInstruction; -import cod.ptac.CodPTACOperand; -import cod.ptac.CodPTACOperandKind; -import cod.ptac.CodPTACOpcode; -import cod.ptac.CodPTACUnit; +import cod.ptac.Artifact; +import cod.ptac.Flag; +import cod.ptac.Function; +import cod.ptac.Instruction; +import cod.ptac.Operand; +import cod.ptac.OperandKind; +import cod.ptac.Opcode; +import cod.ptac.Unit; import java.io.DataInput; import java.io.DataOutput; @@ -24,12 +24,12 @@ final class IRArtifactCodec { private IRArtifactCodec() {} - static void writeArtifact(DataOutput out, CodPTACArtifact artifact) throws IOException { + static void writeArtifact(DataOutput out, Artifact artifact) throws IOException { IRCodec.writeHeader(out); IRCodec.writeValue(out, encodeArtifact(artifact), 0); } - static CodPTACArtifact readArtifact(DataInput in) throws IOException { + static Artifact readArtifact(DataInput in) throws IOException { IRCodec.readHeader(in); Object value = IRCodec.readValue(in, 0); if (!(value instanceof Map)) { @@ -38,7 +38,7 @@ static CodPTACArtifact readArtifact(DataInput in) throws IOException { return decodeArtifact(castMap(value, "artifact")); } - private static Map encodeArtifact(CodPTACArtifact artifact) { + private static Map encodeArtifact(Artifact artifact) { Map out = new LinkedHashMap(); out.put("schemaVersion", Integer.valueOf(ARTIFACT_SCHEMA_VERSION)); out.put("version", Integer.valueOf(artifact.version)); @@ -49,7 +49,7 @@ private static Map encodeArtifact(CodPTACArtifact artifact) { return out; } - private static Map encodeUnit(CodPTACUnit unit) { + private static Map encodeUnit(Unit unit) { if (unit == null) return null; Map out = new LinkedHashMap(); out.put("unitName", unit.unitName); @@ -59,10 +59,10 @@ private static Map encodeUnit(CodPTACUnit unit) { return out; } - private static List encodeFunctions(List functions) { + private static List encodeFunctions(List functions) { if (functions == null) return null; List out = new ArrayList(functions.size()); - for (CodPTACFunction fn : functions) { + for (Function fn : functions) { Map value = new LinkedHashMap(); value.put("name", fn == null ? null : fn.name); value.put("parameters", fn == null ? null : new ArrayList(safeStringList(fn.parameters))); @@ -74,10 +74,10 @@ private static List encodeFunctions(List functions) { return out; } - private static List encodeInstructions(List instructions) { + private static List encodeInstructions(List instructions) { if (instructions == null) return null; List out = new ArrayList(instructions.size()); - for (CodPTACInstruction instruction : instructions) { + for (Instruction instruction : instructions) { Map value = new LinkedHashMap(); value.put("opcode", instruction == null || instruction.opcode == null ? null : instruction.opcode.name()); value.put("dest", instruction == null ? null : instruction.dest); @@ -88,10 +88,10 @@ private static List encodeInstructions(List instruct return out; } - private static List encodeOperands(List operands) { + private static List encodeOperands(List operands) { if (operands == null) return null; List out = new ArrayList(operands.size()); - for (CodPTACOperand operand : operands) { + for (Operand operand : operands) { Map value = new LinkedHashMap(); value.put("kind", operand == null || operand.kind == null ? null : operand.kind.name()); value.put("value", operand == null ? null : operand.value); @@ -100,22 +100,22 @@ private static List encodeOperands(List operands) { return out; } - private static List encodeFlags(EnumSet flags) { + private static List encodeFlags(EnumSet flags) { if (flags == null) return null; List out = new ArrayList(flags.size()); - for (CodPTACFlag flag : flags) { + for (Flag flag : flags) { out.add(flag == null ? null : flag.name()); } return out; } - private static CodPTACArtifact decodeArtifact(Map map) throws IOException { + private static Artifact decodeArtifact(Map map) throws IOException { int schemaVersion = readInt(map.get("schemaVersion"), "schemaVersion"); if (schemaVersion != ARTIFACT_SCHEMA_VERSION) { throw new IOException("Unsupported artifact schema version: " + schemaVersion); } - CodPTACArtifact artifact = new CodPTACArtifact(); + Artifact artifact = new Artifact(); artifact.version = readInt(map.get("version"), "version"); artifact.unitName = asString(map.get("unitName"), "unitName"); artifact.className = asString(map.get("className"), "className"); @@ -124,10 +124,10 @@ private static CodPTACArtifact decodeArtifact(Map map) throws IO return artifact; } - private static CodPTACUnit decodeUnit(Object value) throws IOException { + private static Unit decodeUnit(Object value) throws IOException { if (value == null) return null; Map map = castMap(value, "unit"); - CodPTACUnit unit = new CodPTACUnit(); + Unit unit = new Unit(); unit.unitName = asString(map.get("unitName"), "unit.unitName"); unit.className = asString(map.get("className"), "unit.className"); unit.entryFunction = asString(map.get("entryFunction"), "unit.entryFunction"); @@ -135,13 +135,13 @@ private static CodPTACUnit decodeUnit(Object value) throws IOException { return unit; } - private static List decodeFunctions(Object value) throws IOException { - if (value == null) return new ArrayList(); + private static List decodeFunctions(Object value) throws IOException { + if (value == null) return new ArrayList(); List list = castList(value, "unit.functions"); - List out = new ArrayList(list.size()); + List out = new ArrayList(list.size()); for (Object item : list) { Map map = castMap(item, "function"); - CodPTACFunction function = new CodPTACFunction(); + Function function = new Function(); function.name = asString(map.get("name"), "function.name"); function.parameters = asStringList(map.get("parameters"), "function.parameters"); function.instructions = decodeInstructions(map.get("instructions")); @@ -152,65 +152,65 @@ private static List decodeFunctions(Object value) throws IOExce return out; } - private static List decodeInstructions(Object value) throws IOException { - if (value == null) return new ArrayList(); + private static List decodeInstructions(Object value) throws IOException { + if (value == null) return new ArrayList(); List list = castList(value, "function.instructions"); - List out = new ArrayList(list.size()); + List out = new ArrayList(list.size()); for (Object item : list) { Map map = castMap(item, "instruction"); String opcodeName = asString(map.get("opcode"), "instruction.opcode"); - CodPTACOpcode opcode = parseEnum(CodPTACOpcode.class, opcodeName, "instruction.opcode"); + Opcode opcode = parseEnum(Opcode.class, opcodeName, "instruction.opcode"); String dest = asString(map.get("dest"), "instruction.dest"); - List operands = decodeOperands(map.get("operands")); - EnumSet flags = decodeFlags(map.get("flags")); - out.add(new CodPTACInstruction(opcode, dest, operands, flags)); + List operands = decodeOperands(map.get("operands")); + EnumSet flags = decodeFlags(map.get("flags")); + out.add(new Instruction(opcode, dest, operands, flags)); } return out; } - private static List decodeOperands(Object value) throws IOException { - if (value == null) return new ArrayList(); + private static List decodeOperands(Object value) throws IOException { + if (value == null) return new ArrayList(); List list = castList(value, "instruction.operands"); - List out = new ArrayList(list.size()); + List out = new ArrayList(list.size()); for (Object item : list) { Map map = castMap(item, "operand"); String kindName = asString(map.get("kind"), "operand.kind"); - CodPTACOperandKind kind = parseEnum(CodPTACOperandKind.class, kindName, "operand.kind"); + OperandKind kind = parseEnum(OperandKind.class, kindName, "operand.kind"); Object operandValue = map.get("value"); out.add(createOperand(kind, operandValue)); } return out; } - private static EnumSet decodeFlags(Object value) throws IOException { - EnumSet out = EnumSet.noneOf(CodPTACFlag.class); + private static EnumSet decodeFlags(Object value) throws IOException { + EnumSet out = EnumSet.noneOf(Flag.class); if (value == null) return out; List list = castList(value, "instruction.flags"); for (Object item : list) { String flagName = asString(item, "instruction.flag"); - CodPTACFlag flag = parseEnum(CodPTACFlag.class, flagName, "instruction.flag"); + Flag flag = parseEnum(Flag.class, flagName, "instruction.flag"); out.add(flag); } return out; } - private static CodPTACOperand createOperand(CodPTACOperandKind kind, Object value) throws IOException { + private static Operand createOperand(OperandKind kind, Object value) throws IOException { if (kind == null) { throw new IOException("operand.kind is null"); } switch (kind) { case REGISTER: - return CodPTACOperand.register(asString(value, "operand.value")); + return Operand.register(asString(value, "operand.value")); case IMMEDIATE: - return CodPTACOperand.immediate(value); + return Operand.immediate(value); case LABEL: - return CodPTACOperand.label(asString(value, "operand.value")); + return Operand.label(asString(value, "operand.value")); case FUNCTION: - return CodPTACOperand.function(asString(value, "operand.value")); + return Operand.function(asString(value, "operand.value")); case SLOT: - return CodPTACOperand.slot(asString(value, "operand.value")); + return Operand.slot(asString(value, "operand.value")); case IDENTIFIER: - return CodPTACOperand.identifier(asString(value, "operand.value")); + return Operand.identifier(asString(value, "operand.value")); default: throw new IOException("Unsupported operand kind: " + kind); } diff --git a/src/main/java/cod/ir/IRManager.java b/src/main/java/cod/ir/IRManager.java index 9bfe5f72..943988ed 100644 --- a/src/main/java/cod/ir/IRManager.java +++ b/src/main/java/cod/ir/IRManager.java @@ -1,9 +1,9 @@ package cod.ir; import cod.ast.node.Type; -import cod.ptac.CodPTACArtifact; -import cod.ptac.CodPTACCompiler; -import cod.ptac.CodPTACUnit; +import cod.ptac.Artifact; +import cod.ptac.Compiler; +import cod.ptac.Unit; import java.io.File; import java.io.IOException; @@ -18,16 +18,16 @@ public class IRManager { private final IRWriter writer; private final IRReader reader; private final Map> cache; - private final Map> artifactCache; - private final CodPTACCompiler compiler; + private final Map> artifactCache; + private final Compiler compiler; public IRManager(String projectRoot) { this.projectRoot = projectRoot; this.writer = new IRWriter(); this.reader = new IRReader(); this.cache = new HashMap>(); - this.artifactCache = new HashMap>(); - this.compiler = new CodPTACCompiler(); + this.artifactCache = new HashMap>(); + this.compiler = new Compiler(); } public Type load(String unit, String className) { @@ -49,7 +49,7 @@ public Type load(String unit, String className) { } try { - CodPTACArtifact artifact = reader.readArtifact(file); + Artifact artifact = reader.readArtifact(file); if (artifact != null) { putArtifactCache(unit, className, artifact); Type type = artifact.typeSnapshot; @@ -70,19 +70,19 @@ public void save(String unit, Type type) { } File file = getIRFile(unit, type.name); try { - CodPTACArtifact artifact = compiler.compile(unit, type); + Artifact artifact = compiler.compile(unit, type); writer.writeArtifact(file, artifact); putCache(unit, type.name, type); putArtifactCache(unit, type.name, artifact); } catch (IOException ignored) {} } - public CodPTACArtifact loadArtifact(String unit, String className) { + public Artifact loadArtifact(String unit, String className) { if (unit == null || className == null) { return null; } - Map unitCache = artifactCache.get(unit); + Map unitCache = artifactCache.get(unit); if (unitCache != null && unitCache.containsKey(className)) { return unitCache.get(className); } @@ -93,7 +93,7 @@ public CodPTACArtifact loadArtifact(String unit, String className) { } try { - CodPTACArtifact artifact = reader.readArtifact(file); + Artifact artifact = reader.readArtifact(file); if (artifact != null) { putArtifactCache(unit, className, artifact); if (artifact.typeSnapshot != null) { @@ -106,12 +106,12 @@ public CodPTACArtifact loadArtifact(String unit, String className) { } } - public CodPTACUnit loadCodPTACUnit(String unit, String className) { - CodPTACArtifact artifact = loadArtifact(unit, className); + public Unit loadCodPTACUnit(String unit, String className) { + Artifact artifact = loadArtifact(unit, className); return artifact != null ? artifact.unit : null; } - public void saveArtifact(String unit, CodPTACArtifact artifact) { + public void saveArtifact(String unit, Artifact artifact) { if (artifact == null || unit == null || artifact.className == null) return; File file = getIRFile(unit, artifact.className); try { @@ -137,7 +137,7 @@ public Map getCacheStats() { stats.put("units", cache.size()); stats.put("classes", total); int artifacts = 0; - for (Map unitArtifacts : artifactCache.values()) { + for (Map unitArtifacts : artifactCache.values()) { artifacts += unitArtifacts.size(); } stats.put("artifacts", artifacts); @@ -153,10 +153,10 @@ private void putCache(String unit, String className, Type type) { unitCache.put(className, type); } - private void putArtifactCache(String unit, String className, CodPTACArtifact artifact) { - Map unitCache = artifactCache.get(unit); + private void putArtifactCache(String unit, String className, Artifact artifact) { + Map unitCache = artifactCache.get(unit); if (unitCache == null) { - unitCache = new HashMap(); + unitCache = new HashMap(); artifactCache.put(unit, unitCache); } unitCache.put(className, artifact); diff --git a/src/main/java/cod/ir/IRReader.java b/src/main/java/cod/ir/IRReader.java index b3df4457..ad41d07e 100644 --- a/src/main/java/cod/ir/IRReader.java +++ b/src/main/java/cod/ir/IRReader.java @@ -1,7 +1,7 @@ package cod.ir; import cod.ast.node.Type; -import cod.ptac.CodPTACArtifact; +import cod.ptac.Artifact; import java.io.BufferedInputStream; import java.io.DataInputStream; @@ -11,14 +11,14 @@ public final class IRReader { public Type read(File file) throws IOException { - CodPTACArtifact artifact = readArtifact(file); + Artifact artifact = readArtifact(file); if (artifact == null) { return null; } return artifact.typeSnapshot; } - public CodPTACArtifact readArtifact(File file) throws IOException { + public Artifact readArtifact(File file) throws IOException { if (file == null) { throw new IOException("IR source file is null"); } diff --git a/src/main/java/cod/ir/IRWriter.java b/src/main/java/cod/ir/IRWriter.java index 689e6bef..9d1a2996 100644 --- a/src/main/java/cod/ir/IRWriter.java +++ b/src/main/java/cod/ir/IRWriter.java @@ -1,7 +1,7 @@ package cod.ir; import cod.ast.node.Type; -import cod.ptac.CodPTACArtifact; +import cod.ptac.Artifact; import java.io.BufferedOutputStream; import java.io.DataOutputStream; @@ -11,13 +11,13 @@ public final class IRWriter { public void write(File file, Type type) throws IOException { - CodPTACArtifact artifact = new CodPTACArtifact(); + Artifact artifact = new Artifact(); artifact.className = type != null ? type.name : null; artifact.typeSnapshot = type; writeArtifact(file, artifact); } - public void writeArtifact(File file, CodPTACArtifact artifact) throws IOException { + public void writeArtifact(File file, Artifact artifact) throws IOException { if (file == null) { throw new IOException("IR target file is null"); } diff --git a/src/main/java/cod/ptac/CodPTACArtifact.java b/src/main/java/cod/ptac/Artifact.java similarity index 84% rename from src/main/java/cod/ptac/CodPTACArtifact.java rename to src/main/java/cod/ptac/Artifact.java index 165f2f48..dda56721 100644 --- a/src/main/java/cod/ptac/CodPTACArtifact.java +++ b/src/main/java/cod/ptac/Artifact.java @@ -3,13 +3,13 @@ import cod.ast.node.Type; -public final class CodPTACArtifact { +public final class Artifact { public static final int FORMAT_VERSION = 1; public int version = FORMAT_VERSION; public String unitName; public String className; - public CodPTACUnit unit; + public Unit unit; public Type typeSnapshot; public boolean hasExecutableUnit() { diff --git a/src/main/java/cod/ptac/CodPTACCompiler.java b/src/main/java/cod/ptac/CodPTACCompiler.java deleted file mode 100644 index ae8b67b8..00000000 --- a/src/main/java/cod/ptac/CodPTACCompiler.java +++ /dev/null @@ -1,27 +0,0 @@ -package cod.ptac; - -import cod.ast.node.Type; - -public final class CodPTACCompiler { - private final CodPTACLowerer lowerer; - private final CodPTACOptimizer optimizer; - - public CodPTACCompiler() { - this(false); - } - - public CodPTACCompiler(boolean enableOptionalLowering) { - this.lowerer = new CodPTACLowerer(); - this.optimizer = new CodPTACOptimizer(enableOptionalLowering); - } - - public CodPTACArtifact compile(String unitName, Type type) { - CodPTACArtifact artifact = new CodPTACArtifact(); - artifact.version = CodPTACArtifact.FORMAT_VERSION; - artifact.unitName = unitName; - artifact.className = type != null ? type.name : null; - artifact.typeSnapshot = type; - artifact.unit = optimizer.optimize(lowerer.lower(unitName, type)); - return artifact; - } -} diff --git a/src/main/java/cod/ptac/CodPTACInstruction.java b/src/main/java/cod/ptac/CodPTACInstruction.java deleted file mode 100644 index eeef6ffc..00000000 --- a/src/main/java/cod/ptac/CodPTACInstruction.java +++ /dev/null @@ -1,34 +0,0 @@ -package cod.ptac; - -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.List; - -public final class CodPTACInstruction { - public final CodPTACOpcode opcode; - public final String dest; - public final List operands; - public final EnumSet flags; - - public CodPTACInstruction( - CodPTACOpcode opcode, - String dest, - List operands, - EnumSet flags - ) { - this.opcode = opcode; - this.dest = dest; - this.operands = operands != null ? operands : new ArrayList(); - this.flags = flags != null ? flags : EnumSet.noneOf(CodPTACFlag.class); - } - - public CodPTACInstruction(CodPTACOpcode opcode, String dest, List operands) { - this(opcode, dest, operands, null); - } - - public CodPTACInstruction withFlag(CodPTACFlag flag) { - EnumSet copy = EnumSet.copyOf(this.flags); - copy.add(flag); - return new CodPTACInstruction(this.opcode, this.dest, this.operands, copy); - } -} diff --git a/src/main/java/cod/ptac/CodPTACOpcode.java b/src/main/java/cod/ptac/CodPTACOpcode.java deleted file mode 100644 index aeb2caa2..00000000 --- a/src/main/java/cod/ptac/CodPTACOpcode.java +++ /dev/null @@ -1,78 +0,0 @@ -package cod.ptac; - - -public enum CodPTACOpcode { - // Core TAC - NOP(CodPTACLayer.CORE_TAC), - ASSIGN(CodPTACLayer.CORE_TAC), - ADD(CodPTACLayer.CORE_TAC), - SUB(CodPTACLayer.CORE_TAC), - MUL(CodPTACLayer.CORE_TAC), - DIV(CodPTACLayer.CORE_TAC), - MOD(CodPTACLayer.CORE_TAC), - EQ(CodPTACLayer.CORE_TAC), - NE(CodPTACLayer.CORE_TAC), - GT(CodPTACLayer.CORE_TAC), - LT(CodPTACLayer.CORE_TAC), - GTE(CodPTACLayer.CORE_TAC), - LTE(CodPTACLayer.CORE_TAC), - BRANCH(CodPTACLayer.CORE_TAC), - BRANCH_IF(CodPTACLayer.CORE_TAC), - CALL(CodPTACLayer.CORE_TAC), - RETURN(CodPTACLayer.CORE_TAC), - LOAD(CodPTACLayer.CORE_TAC), - STORE(CodPTACLayer.CORE_TAC), - - // Coderive pattern ops - RANGE(CodPTACLayer.PATTERN), - RANGE_Q(CodPTACLayer.PATTERN), - RANGE_S(CodPTACLayer.PATTERN), - RANGE_L(CodPTACLayer.PATTERN), - RANGE_LS(CodPTACLayer.PATTERN), - MAP(CodPTACLayer.PATTERN), - FILTER(CodPTACLayer.PATTERN), - REDUCE(CodPTACLayer.PATTERN), - WHERE(CodPTACLayer.PATTERN), - SCAN(CodPTACLayer.PATTERN), - ZIP(CodPTACLayer.PATTERN), - TAKE(CodPTACLayer.PATTERN), - FILTER_MAP(CodPTACLayer.PATTERN), - FILTER_MAP_REDUCE(CodPTACLayer.PATTERN), - - // Lambda / recursion / closures - LAMBDA(CodPTACLayer.PATTERN), - CLOSURE(CodPTACLayer.PATTERN), - ANCESTOR(CodPTACLayer.PATTERN), - SELF(CodPTACLayer.PATTERN), - TAIL_CALL(CodPTACLayer.PATTERN), - - // Slot ops - SLOT_GET(CodPTACLayer.PATTERN), - SLOT_SET(CodPTACLayer.PATTERN), - SLOT_RET(CodPTACLayer.PATTERN), - SLOT_UNPACK(CodPTACLayer.PATTERN), - SLOT_DIV(CodPTACLayer.PATTERN), - - // Lazy ops - LAZY_GET(CodPTACLayer.PATTERN), - LAZY_SET(CodPTACLayer.PATTERN), - LAZY_COMMIT(CodPTACLayer.PATTERN), - LAZY_SIZE(CodPTACLayer.PATTERN), - LAZY_SLICE(CodPTACLayer.PATTERN), - - // Formula ops - FORMULA_SEQ(CodPTACLayer.PATTERN), - FORMULA_COND(CodPTACLayer.PATTERN), - FORMULA_RECUR(CodPTACLayer.PATTERN), - FORMULA_FUSE(CodPTACLayer.PATTERN); - - private final CodPTACLayer layer; - - CodPTACOpcode(CodPTACLayer layer) { - this.layer = layer; - } - - public CodPTACLayer getLayer() { - return layer; - } -} diff --git a/src/main/java/cod/ptac/CodPTACOperand.java b/src/main/java/cod/ptac/CodPTACOperand.java deleted file mode 100644 index b777fafd..00000000 --- a/src/main/java/cod/ptac/CodPTACOperand.java +++ /dev/null @@ -1,36 +0,0 @@ -package cod.ptac; - - -public final class CodPTACOperand { - public final CodPTACOperandKind kind; - public final Object value; - - private CodPTACOperand(CodPTACOperandKind kind, Object value) { - this.kind = kind; - this.value = value; - } - - public static CodPTACOperand register(String name) { - return new CodPTACOperand(CodPTACOperandKind.REGISTER, name); - } - - public static CodPTACOperand immediate(Object value) { - return new CodPTACOperand(CodPTACOperandKind.IMMEDIATE, value); - } - - public static CodPTACOperand label(String name) { - return new CodPTACOperand(CodPTACOperandKind.LABEL, name); - } - - public static CodPTACOperand function(String name) { - return new CodPTACOperand(CodPTACOperandKind.FUNCTION, name); - } - - public static CodPTACOperand slot(String name) { - return new CodPTACOperand(CodPTACOperandKind.SLOT, name); - } - - public static CodPTACOperand identifier(String name) { - return new CodPTACOperand(CodPTACOperandKind.IDENTIFIER, name); - } -} diff --git a/src/main/java/cod/ptac/CodPTACOptimizer.java b/src/main/java/cod/ptac/CodPTACOptimizer.java deleted file mode 100644 index 5142344f..00000000 --- a/src/main/java/cod/ptac/CodPTACOptimizer.java +++ /dev/null @@ -1,31 +0,0 @@ -package cod.ptac; - -import cod.ptac.opt.*; - -import java.util.ArrayList; -import java.util.List; - -public final class CodPTACOptimizer { - private final List passes; - - public CodPTACOptimizer() { - this(false); - } - - public CodPTACOptimizer(boolean enableOptionalLowering) { - this.passes = new ArrayList(); - this.passes.add(new CodPTACPatternFusionPass()); - this.passes.add(new CodPTACLazyRangePropagationPass()); - this.passes.add(new CodPTACConstantFoldingPass()); - this.passes.add(new CodPTACDeadTempEliminationPass()); - this.passes.add(new CodPTACOptionalPatternLoweringPass(enableOptionalLowering)); - } - - public CodPTACUnit optimize(CodPTACUnit unit) { - if (unit == null) return null; - for (CodPTACOptimizationPass pass : passes) { - pass.apply(unit); - } - return unit; - } -} diff --git a/src/main/java/cod/ptac/Compiler.java b/src/main/java/cod/ptac/Compiler.java new file mode 100644 index 00000000..226ccd6a --- /dev/null +++ b/src/main/java/cod/ptac/Compiler.java @@ -0,0 +1,27 @@ +package cod.ptac; + +import cod.ast.node.Type; + +public final class Compiler { + private final Lowerer lowerer; + private final Optimizer optimizer; + + public Compiler() { + this(false); + } + + public Compiler(boolean enableOptionalLowering) { + this.lowerer = new Lowerer(); + this.optimizer = new Optimizer(enableOptionalLowering); + } + + public Artifact compile(String unitName, Type type) { + Artifact artifact = new Artifact(); + artifact.version = Artifact.FORMAT_VERSION; + artifact.unitName = unitName; + artifact.className = type != null ? type.name : null; + artifact.typeSnapshot = type; + artifact.unit = optimizer.optimize(lowerer.lower(unitName, type)); + return artifact; + } +} diff --git a/src/main/java/cod/ptac/CodPTACExecutor.java b/src/main/java/cod/ptac/Executor.java similarity index 58% rename from src/main/java/cod/ptac/CodPTACExecutor.java rename to src/main/java/cod/ptac/Executor.java index 0ad9d9df..4bbd28af 100644 --- a/src/main/java/cod/ptac/CodPTACExecutor.java +++ b/src/main/java/cod/ptac/Executor.java @@ -10,27 +10,27 @@ import java.util.List; import java.util.Map; -public final class CodPTACExecutor { - private final CodPTACOptions options; +public final class Executor { + private final Options options; private static final Object FALLBACK_SENTINEL = new Object(); - private static final class CodPTACRange { + private static final class Range { final BigInteger start; final BigInteger end; final BigInteger step; - CodPTACRange(BigInteger start, BigInteger end, BigInteger step) { + Range(BigInteger start, BigInteger end, BigInteger step) { this.start = start; this.end = end; this.step = step; } } - public CodPTACExecutor(CodPTACOptions options) { - this.options = options != null ? options : CodPTACOptions.current(); + public Executor(Options options) { + this.options = options != null ? options : Options.current(); } - public Object execute(CodPTACArtifact artifact, Interpreter fallbackInterpreter) { + public Object execute(Artifact artifact, Interpreter fallbackInterpreter) { if (artifact == null) { throw new ProgramError("Cannot execute null CodP-TAC artifact"); } @@ -39,7 +39,7 @@ public Object execute(CodPTACArtifact artifact, Interpreter fallbackInterpreter) return fallback(artifact, fallbackInterpreter, "No executable CodP-TAC unit in artifact"); } - CodPTACFunction entry = findEntry(artifact.unit); + Function entry = findEntry(artifact.unit); if (entry == null) { return fallback(artifact, fallbackInterpreter, "No entry function found in CodP-TAC unit"); } @@ -48,11 +48,11 @@ public Object execute(CodPTACArtifact artifact, Interpreter fallbackInterpreter) } private Object executeFunction( - CodPTACUnit unit, - CodPTACFunction function, + Unit unit, + Function function, List args, Interpreter fallbackInterpreter, - CodPTACArtifact artifact + Artifact artifact ) { Map registers = new HashMap(); if (function.parameters != null) { @@ -64,13 +64,13 @@ private Object executeFunction( if (function.instructions == null) return null; - for (CodPTACInstruction inst : function.instructions) { + for (Instruction inst : function.instructions) { if (inst == null) continue; Object result = runInstruction(unit, inst, registers, fallbackInterpreter, artifact); if (result == FALLBACK_SENTINEL) { return FALLBACK_SENTINEL; } - if (inst.opcode == CodPTACOpcode.RETURN) { + if (inst.opcode == Opcode.RETURN) { return result; } } @@ -78,13 +78,13 @@ private Object executeFunction( } private Object runInstruction( - CodPTACUnit unit, - CodPTACInstruction inst, + Unit unit, + Instruction inst, Map registers, Interpreter fallbackInterpreter, - CodPTACArtifact artifact + Artifact artifact ) { - if (inst.opcode == CodPTACOpcode.ASSIGN) { + if (inst.opcode == Opcode.ASSIGN) { Object value = operandValue(inst.operands, 0, registers); registers.put(inst.dest, value); return value; @@ -106,65 +106,65 @@ private Object runInstruction( return out; } - if (inst.opcode == CodPTACOpcode.RANGE - || inst.opcode == CodPTACOpcode.RANGE_Q - || inst.opcode == CodPTACOpcode.RANGE_S - || inst.opcode == CodPTACOpcode.RANGE_L - || inst.opcode == CodPTACOpcode.RANGE_LS) { + if (inst.opcode == Opcode.RANGE + || inst.opcode == Opcode.RANGE_Q + || inst.opcode == Opcode.RANGE_S + || inst.opcode == Opcode.RANGE_L + || inst.opcode == Opcode.RANGE_LS) { Object start = operandValue(inst.operands, 0, registers); Object end = operandValue(inst.operands, 1, registers); Object stepVal = inst.operands != null && inst.operands.size() > 2 ? operandValue(inst.operands, 2, registers) : 1; - CodPTACRange range = new CodPTACRange(toBigInt(start), toBigInt(end), toBigInt(stepVal)); + Range range = new Range(toBigInt(start), toBigInt(end), toBigInt(stepVal)); if (inst.dest != null) registers.put(inst.dest, range); return range; } - if (inst.opcode == CodPTACOpcode.TAKE) { - CodPTACRange source = asRange(operandValue(inst.operands, 0, registers)); + if (inst.opcode == Opcode.TAKE) { + Range source = asRange(operandValue(inst.operands, 0, registers)); BigInteger n = toBigInt(operandValue(inst.operands, 1, registers)); List out = take(source, n); if (inst.dest != null) registers.put(inst.dest, out); return out; } - if (inst.opcode == CodPTACOpcode.FILTER - || inst.opcode == CodPTACOpcode.MAP - || inst.opcode == CodPTACOpcode.FILTER_MAP - || inst.opcode == CodPTACOpcode.REDUCE - || inst.opcode == CodPTACOpcode.SCAN - || inst.opcode == CodPTACOpcode.ZIP - || inst.opcode == CodPTACOpcode.WHERE - || inst.opcode == CodPTACOpcode.FILTER_MAP_REDUCE - || inst.opcode == CodPTACOpcode.LAZY_GET - || inst.opcode == CodPTACOpcode.LAZY_SET - || inst.opcode == CodPTACOpcode.LAZY_COMMIT - || inst.opcode == CodPTACOpcode.LAZY_SIZE - || inst.opcode == CodPTACOpcode.LAZY_SLICE - || inst.opcode == CodPTACOpcode.SLOT_GET - || inst.opcode == CodPTACOpcode.SLOT_SET - || inst.opcode == CodPTACOpcode.SLOT_RET - || inst.opcode == CodPTACOpcode.SLOT_UNPACK - || inst.opcode == CodPTACOpcode.SLOT_DIV - || inst.opcode == CodPTACOpcode.ANCESTOR - || inst.opcode == CodPTACOpcode.SELF - || inst.opcode == CodPTACOpcode.TAIL_CALL - || inst.opcode == CodPTACOpcode.CLOSURE - || inst.opcode == CodPTACOpcode.FORMULA_SEQ - || inst.opcode == CodPTACOpcode.FORMULA_COND - || inst.opcode == CodPTACOpcode.FORMULA_RECUR - || inst.opcode == CodPTACOpcode.FORMULA_FUSE - || inst.opcode == CodPTACOpcode.STORE - || inst.opcode == CodPTACOpcode.LOAD - || inst.opcode == CodPTACOpcode.BRANCH - || inst.opcode == CodPTACOpcode.BRANCH_IF) { + if (inst.opcode == Opcode.FILTER + || inst.opcode == Opcode.MAP + || inst.opcode == Opcode.FILTER_MAP + || inst.opcode == Opcode.REDUCE + || inst.opcode == Opcode.SCAN + || inst.opcode == Opcode.ZIP + || inst.opcode == Opcode.WHERE + || inst.opcode == Opcode.FILTER_MAP_REDUCE + || inst.opcode == Opcode.LAZY_GET + || inst.opcode == Opcode.LAZY_SET + || inst.opcode == Opcode.LAZY_COMMIT + || inst.opcode == Opcode.LAZY_SIZE + || inst.opcode == Opcode.LAZY_SLICE + || inst.opcode == Opcode.SLOT_GET + || inst.opcode == Opcode.SLOT_SET + || inst.opcode == Opcode.SLOT_RET + || inst.opcode == Opcode.SLOT_UNPACK + || inst.opcode == Opcode.SLOT_DIV + || inst.opcode == Opcode.ANCESTOR + || inst.opcode == Opcode.SELF + || inst.opcode == Opcode.TAIL_CALL + || inst.opcode == Opcode.CLOSURE + || inst.opcode == Opcode.FORMULA_SEQ + || inst.opcode == Opcode.FORMULA_COND + || inst.opcode == Opcode.FORMULA_RECUR + || inst.opcode == Opcode.FORMULA_FUSE + || inst.opcode == Opcode.STORE + || inst.opcode == Opcode.LOAD + || inst.opcode == Opcode.BRANCH + || inst.opcode == Opcode.BRANCH_IF) { return fallback(artifact, fallbackInterpreter, "Opcode not yet natively executed: " + inst.opcode); } - if (inst.opcode == CodPTACOpcode.CALL) { + if (inst.opcode == Opcode.CALL) { String functionName = String.valueOf(operandValue(inst.operands, 0, registers)); - CodPTACFunction target = findFunction(unit, functionName); + Function target = findFunction(unit, functionName); if (target == null) { return fallback(artifact, fallbackInterpreter, "Unknown function: " + functionName); } @@ -180,14 +180,14 @@ private Object runInstruction( return result; } - if (inst.opcode == CodPTACOpcode.RETURN) { + if (inst.opcode == Opcode.RETURN) { return operandValue(inst.operands, 0, registers); } return null; } - private Object fallback(CodPTACArtifact artifact, Interpreter fallbackInterpreter, String reason) { + private Object fallback(Artifact artifact, Interpreter fallbackInterpreter, String reason) { if (!options.isFallbackEnabled()) { throw new ProgramError("CodP-TAC execution failed without fallback: " + reason); } @@ -205,62 +205,62 @@ private Object fallback(CodPTACArtifact artifact, Interpreter fallbackInterprete throw new ProgramError("CodP-TAC fallback unavailable: " + reason); } - private CodPTACFunction findEntry(CodPTACUnit unit) { + private Function findEntry(Unit unit) { if (unit.entryFunction != null) { - CodPTACFunction explicit = findFunction(unit, unit.entryFunction); + Function explicit = findFunction(unit, unit.entryFunction); if (explicit != null) return explicit; } return findFunction(unit, "main"); } - private CodPTACFunction findFunction(CodPTACUnit unit, String name) { + private Function findFunction(Unit unit, String name) { if (unit == null || unit.functions == null || name == null) return null; - for (CodPTACFunction fn : unit.functions) { + for (Function fn : unit.functions) { if (fn != null && name.equals(fn.name)) return fn; } return null; } - private Object operandValue(List operands, int index, Map registers) { + private Object operandValue(List operands, int index, Map registers) { if (operands == null || index >= operands.size()) return null; - CodPTACOperand operand = operands.get(index); + Operand operand = operands.get(index); if (operand == null) return null; - if (operand.kind == CodPTACOperandKind.REGISTER && operand.value instanceof String) { + if (operand.kind == OperandKind.REGISTER && operand.value instanceof String) { return registers.get(operand.value); } return operand.value; } - private boolean isMath(CodPTACOpcode opcode) { - return opcode == CodPTACOpcode.ADD - || opcode == CodPTACOpcode.SUB - || opcode == CodPTACOpcode.MUL - || opcode == CodPTACOpcode.DIV - || opcode == CodPTACOpcode.MOD; + private boolean isMath(Opcode opcode) { + return opcode == Opcode.ADD + || opcode == Opcode.SUB + || opcode == Opcode.MUL + || opcode == Opcode.DIV + || opcode == Opcode.MOD; } - private boolean isCompare(CodPTACOpcode opcode) { - return opcode == CodPTACOpcode.EQ - || opcode == CodPTACOpcode.NE - || opcode == CodPTACOpcode.GT - || opcode == CodPTACOpcode.LT - || opcode == CodPTACOpcode.GTE - || opcode == CodPTACOpcode.LTE; + private boolean isCompare(Opcode opcode) { + return opcode == Opcode.EQ + || opcode == Opcode.NE + || opcode == Opcode.GT + || opcode == Opcode.LT + || opcode == Opcode.GTE + || opcode == Opcode.LTE; } - private Object evaluateMath(CodPTACOpcode opcode, Object a, Object b) { + private Object evaluateMath(Opcode opcode, Object a, Object b) { BigInteger left = toBigInt(a); BigInteger right = toBigInt(b); - if (opcode == CodPTACOpcode.ADD) return left.add(right); - if (opcode == CodPTACOpcode.SUB) return left.subtract(right); - if (opcode == CodPTACOpcode.MUL) return left.multiply(right); - if (opcode == CodPTACOpcode.DIV) { + if (opcode == Opcode.ADD) return left.add(right); + if (opcode == Opcode.SUB) return left.subtract(right); + if (opcode == Opcode.MUL) return left.multiply(right); + if (opcode == Opcode.DIV) { if (right.equals(BigInteger.ZERO)) { throw new ProgramError("CodP-TAC division by zero"); } return left.divide(right); } - if (opcode == CodPTACOpcode.MOD) { + if (opcode == Opcode.MOD) { if (right.equals(BigInteger.ZERO)) { throw new ProgramError("CodP-TAC modulo by zero"); } @@ -269,25 +269,25 @@ private Object evaluateMath(CodPTACOpcode opcode, Object a, Object b) { return BigInteger.ZERO; } - private Boolean evaluateCompare(CodPTACOpcode opcode, Object a, Object b) { + private Boolean evaluateCompare(Opcode opcode, Object a, Object b) { BigInteger left = toBigInt(a); BigInteger right = toBigInt(b); int cmp = left.compareTo(right); - if (opcode == CodPTACOpcode.EQ) return cmp == 0; - if (opcode == CodPTACOpcode.NE) return cmp != 0; - if (opcode == CodPTACOpcode.GT) return cmp > 0; - if (opcode == CodPTACOpcode.LT) return cmp < 0; - if (opcode == CodPTACOpcode.GTE) return cmp >= 0; - if (opcode == CodPTACOpcode.LTE) return cmp <= 0; + if (opcode == Opcode.EQ) return cmp == 0; + if (opcode == Opcode.NE) return cmp != 0; + if (opcode == Opcode.GT) return cmp > 0; + if (opcode == Opcode.LT) return cmp < 0; + if (opcode == Opcode.GTE) return cmp >= 0; + if (opcode == Opcode.LTE) return cmp <= 0; return false; } - private CodPTACRange asRange(Object value) { - if (value instanceof CodPTACRange) return (CodPTACRange) value; - return new CodPTACRange(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); + private Range asRange(Object value) { + if (value instanceof Range) return (Range) value; + return new Range(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); } - private List take(CodPTACRange range, BigInteger n) { + private List take(Range range, BigInteger n) { List out = new ArrayList(); if (range == null || n == null || n.compareTo(BigInteger.ZERO) <= 0) return out; diff --git a/src/main/java/cod/ptac/CodPTACFlag.java b/src/main/java/cod/ptac/Flag.java similarity index 72% rename from src/main/java/cod/ptac/CodPTACFlag.java rename to src/main/java/cod/ptac/Flag.java index 5ff2f640..1f7cb675 100644 --- a/src/main/java/cod/ptac/CodPTACFlag.java +++ b/src/main/java/cod/ptac/Flag.java @@ -1,7 +1,7 @@ package cod.ptac; -public enum CodPTACFlag { +public enum Flag { LAZY, TAIL, PURE, diff --git a/src/main/java/cod/ptac/CodPTACFunction.java b/src/main/java/cod/ptac/Function.java similarity index 63% rename from src/main/java/cod/ptac/CodPTACFunction.java rename to src/main/java/cod/ptac/Function.java index f95a572e..a32de63b 100644 --- a/src/main/java/cod/ptac/CodPTACFunction.java +++ b/src/main/java/cod/ptac/Function.java @@ -3,10 +3,10 @@ import java.util.ArrayList; import java.util.List; -public final class CodPTACFunction { +public final class Function { public String name; public List parameters = new ArrayList(); - public List instructions = new ArrayList(); + public List instructions = new ArrayList(); public boolean lambdaBlock; public int closureLevel; } diff --git a/src/main/java/cod/ptac/Instruction.java b/src/main/java/cod/ptac/Instruction.java new file mode 100644 index 00000000..5aa8fedc --- /dev/null +++ b/src/main/java/cod/ptac/Instruction.java @@ -0,0 +1,34 @@ +package cod.ptac; + +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; + +public final class Instruction { + public final Opcode opcode; + public final String dest; + public final List operands; + public final EnumSet flags; + + public Instruction( + Opcode opcode, + String dest, + List operands, + EnumSet flags + ) { + this.opcode = opcode; + this.dest = dest; + this.operands = operands != null ? operands : new ArrayList(); + this.flags = flags != null ? flags : EnumSet.noneOf(Flag.class); + } + + public Instruction(Opcode opcode, String dest, List operands) { + this(opcode, dest, operands, null); + } + + public Instruction withFlag(Flag flag) { + EnumSet copy = EnumSet.copyOf(this.flags); + copy.add(flag); + return new Instruction(this.opcode, this.dest, this.operands, copy); + } +} diff --git a/src/main/java/cod/ptac/CodPTACLayer.java b/src/main/java/cod/ptac/Layer.java similarity index 64% rename from src/main/java/cod/ptac/CodPTACLayer.java rename to src/main/java/cod/ptac/Layer.java index 737906e9..5b18c15a 100644 --- a/src/main/java/cod/ptac/CodPTACLayer.java +++ b/src/main/java/cod/ptac/Layer.java @@ -1,7 +1,7 @@ package cod.ptac; -public enum CodPTACLayer { +public enum Layer { CORE_TAC, PATTERN } diff --git a/src/main/java/cod/ptac/CodPTACLowerer.java b/src/main/java/cod/ptac/Lowerer.java similarity index 57% rename from src/main/java/cod/ptac/CodPTACLowerer.java rename to src/main/java/cod/ptac/Lowerer.java index 2465a17e..bac3a95d 100644 --- a/src/main/java/cod/ptac/CodPTACLowerer.java +++ b/src/main/java/cod/ptac/Lowerer.java @@ -8,13 +8,13 @@ import java.util.Arrays; import java.util.List; -public final class CodPTACLowerer { +public final class Lowerer { private int tempCounter = 0; private int patternCounter = 0; private int lambdaCounter = 0; - public CodPTACUnit lower(String unitName, Type type) { - CodPTACUnit unit = new CodPTACUnit(); + public Unit lower(String unitName, Type type) { + Unit unit = new Unit(); unit.unitName = unitName; unit.className = type != null ? type.name : null; @@ -34,8 +34,8 @@ public CodPTACUnit lower(String unitName, Type type) { return unit; } - private CodPTACFunction lowerMethod(Method method, CodPTACUnit unit) { - CodPTACFunction fn = new CodPTACFunction(); + private Function lowerMethod(Method method, Unit unit) { + Function fn = new Function(); fn.name = method != null ? method.methodName : "anonymous"; if (method != null && method.parameters != null) { for (Param param : method.parameters) { @@ -49,22 +49,22 @@ private CodPTACFunction lowerMethod(Method method, CodPTACUnit unit) { for (Stmt stmt : method.body) { lowerStmt(stmt, fn, unit); } - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.RETURN, + fn.instructions.add(new Instruction( + Opcode.RETURN, null, - Arrays.asList(CodPTACOperand.immediate(null)) + Arrays.asList(Operand.immediate(null)) )); return fn; } - private void lowerStmt(Stmt stmt, CodPTACFunction fn, CodPTACUnit unit) { + private void lowerStmt(Stmt stmt, Function fn, Unit unit) { if (stmt == null) return; if (stmt instanceof Var) { Var var = (Var) stmt; - CodPTACOperand value = lowerExpr(var.value, fn, unit); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.ASSIGN, + Operand value = lowerExpr(var.value, fn, unit); + fn.instructions.add(new Instruction( + Opcode.ASSIGN, var.name, Arrays.asList(value) )); @@ -73,19 +73,19 @@ private void lowerStmt(Stmt stmt, CodPTACFunction fn, CodPTACUnit unit) { if (stmt instanceof Assignment) { Assignment assign = (Assignment) stmt; - CodPTACOperand rhs = lowerExpr(assign.right, fn, unit); + Operand rhs = lowerExpr(assign.right, fn, unit); if (assign.left instanceof Identifier) { - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.ASSIGN, + fn.instructions.add(new Instruction( + Opcode.ASSIGN, ((Identifier) assign.left).name, Arrays.asList(rhs) )); } else if (assign.left instanceof IndexAccess) { IndexAccess access = (IndexAccess) assign.left; - CodPTACOperand arr = lowerExpr(access.array, fn, unit); - CodPTACOperand idx = lowerExpr(access.index, fn, unit); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.LAZY_SET, + Operand arr = lowerExpr(access.array, fn, unit); + Operand idx = lowerExpr(access.index, fn, unit); + fn.instructions.add(new Instruction( + Opcode.LAZY_SET, null, Arrays.asList(arr, idx, rhs) )); @@ -95,11 +95,11 @@ private void lowerStmt(Stmt stmt, CodPTACFunction fn, CodPTACUnit unit) { if (stmt instanceof SlotAssignment) { SlotAssignment slot = (SlotAssignment) stmt; - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.SLOT_SET, + fn.instructions.add(new Instruction( + Opcode.SLOT_SET, null, Arrays.asList( - CodPTACOperand.slot(slot.slotName), + Operand.slot(slot.slotName), lowerExpr(slot.value, fn, unit) ) )); @@ -119,16 +119,16 @@ private void lowerStmt(Stmt stmt, CodPTACFunction fn, CodPTACUnit unit) { if (stmt instanceof ReturnSlotAssignment) { ReturnSlotAssignment ret = (ReturnSlotAssignment) stmt; if (ret.methodCall != null) { - CodPTACOperand result = lowerExpr(ret.methodCall, fn, unit); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.SLOT_UNPACK, + Operand result = lowerExpr(ret.methodCall, fn, unit); + fn.instructions.add(new Instruction( + Opcode.SLOT_UNPACK, null, Arrays.asList(result) )); } else if (ret.lambda != null) { - CodPTACOperand lambdaReg = lowerExpr(ret.lambda, fn, unit); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.SLOT_UNPACK, + Operand lambdaReg = lowerExpr(ret.lambda, fn, unit); + fn.instructions.add(new Instruction( + Opcode.SLOT_UNPACK, null, Arrays.asList(lambdaReg) )); @@ -145,36 +145,36 @@ private void lowerStmt(Stmt stmt, CodPTACFunction fn, CodPTACUnit unit) { StmtIf ifStmt = (StmtIf) stmt; String thenLabel = "L_then_" + nextTemp(); String endLabel = "L_end_" + nextTemp(); - CodPTACOperand cond = lowerExpr(ifStmt.condition, fn, unit); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.BRANCH_IF, + Operand cond = lowerExpr(ifStmt.condition, fn, unit); + fn.instructions.add(new Instruction( + Opcode.BRANCH_IF, null, - Arrays.asList(cond, CodPTACOperand.label(thenLabel)) + Arrays.asList(cond, Operand.label(thenLabel)) )); if (ifStmt.elseBlock != null && ifStmt.elseBlock.statements != null) { for (Stmt elseStmt : ifStmt.elseBlock.statements) { lowerStmt(elseStmt, fn, unit); } } - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.BRANCH, + fn.instructions.add(new Instruction( + Opcode.BRANCH, null, - Arrays.asList(CodPTACOperand.label(endLabel)) + Arrays.asList(Operand.label(endLabel)) )); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.NOP, + fn.instructions.add(new Instruction( + Opcode.NOP, thenLabel, - new ArrayList() + new ArrayList() )); if (ifStmt.thenBlock != null && ifStmt.thenBlock.statements != null) { for (Stmt thenStmt : ifStmt.thenBlock.statements) { lowerStmt(thenStmt, fn, unit); } } - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.NOP, + fn.instructions.add(new Instruction( + Opcode.NOP, endLabel, - new ArrayList() + new ArrayList() )); return; } @@ -190,10 +190,10 @@ private void lowerStmt(Stmt stmt, CodPTACFunction fn, CodPTACUnit unit) { } if (stmt instanceof Exit) { - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.RETURN, + fn.instructions.add(new Instruction( + Opcode.RETURN, null, - Arrays.asList(CodPTACOperand.immediate(null)) + Arrays.asList(Operand.immediate(null)) )); return; } @@ -203,21 +203,21 @@ private void lowerStmt(Stmt stmt, CodPTACFunction fn, CodPTACUnit unit) { return; } - fn.instructions.add(new CodPTACInstruction(CodPTACOpcode.NOP, null, new ArrayList())); + fn.instructions.add(new Instruction(Opcode.NOP, null, new ArrayList())); } - private void lowerFor(For node, CodPTACFunction fn, CodPTACUnit unit) { + private void lowerFor(For node, Function fn, Unit unit) { if (node == null || node.range == null) return; String rangeReg = nextPattern(); - CodPTACOpcode rangeOpcode = selectRangeOpcode(node.range); - List rangeOps = new ArrayList(); + Opcode rangeOpcode = selectRangeOpcode(node.range); + List rangeOps = new ArrayList(); rangeOps.add(lowerExpr(node.range.start, fn, unit)); rangeOps.add(lowerExpr(node.range.end, fn, unit)); if (node.range.step != null) { rangeOps.add(lowerExpr(node.range.step, fn, unit)); } - fn.instructions.add(new CodPTACInstruction(rangeOpcode, rangeReg, rangeOps)); + fn.instructions.add(new Instruction(rangeOpcode, rangeReg, rangeOps)); List body = node.body != null ? node.body.statements : null; if (body == null || body.isEmpty()) return; @@ -231,13 +231,13 @@ private void lowerFor(For node, CodPTACFunction fn, CodPTACUnit unit) { if (seq != null && seq.isOptimizable() && cond != null && cond.isOptimizable()) { String condLambda = lowerConditionLambda(cond, unit); String mapLambda = lowerSequenceLambda(seq, unit); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.FILTER_MAP, + fn.instructions.add(new Instruction( + Opcode.FILTER_MAP, nextPattern(), Arrays.asList( - CodPTACOperand.register(rangeReg), - CodPTACOperand.function(condLambda), - CodPTACOperand.function(mapLambda) + Operand.register(rangeReg), + Operand.function(condLambda), + Operand.function(mapLambda) ) )); return; @@ -245,20 +245,20 @@ private void lowerFor(For node, CodPTACFunction fn, CodPTACUnit unit) { if (cond != null && cond.isOptimizable()) { String condLambda = lowerConditionLambda(cond, unit); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.FILTER, + fn.instructions.add(new Instruction( + Opcode.FILTER, nextPattern(), - Arrays.asList(CodPTACOperand.register(rangeReg), CodPTACOperand.function(condLambda)) + Arrays.asList(Operand.register(rangeReg), Operand.function(condLambda)) )); return; } if (seq != null && seq.isOptimizable()) { String lambdaName = lowerSequenceLambda(seq, unit); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.MAP, + fn.instructions.add(new Instruction( + Opcode.MAP, nextPattern(), - Arrays.asList(CodPTACOperand.register(rangeReg), CodPTACOperand.function(lambdaName)) + Arrays.asList(Operand.register(rangeReg), Operand.function(lambdaName)) )); return; } @@ -268,9 +268,9 @@ private void lowerFor(For node, CodPTACFunction fn, CodPTACUnit unit) { } } - private String lowerConditionLambda(ConditionalPattern pattern, CodPTACUnit unit) { + private String lowerConditionLambda(ConditionalPattern pattern, Unit unit) { String lambdaName = nextLambdaName("cond"); - CodPTACFunction lambda = new CodPTACFunction(); + Function lambda = new Function(); lambda.name = lambdaName; lambda.lambdaBlock = true; lambda.parameters.add(pattern.indexVar != null ? pattern.indexVar : "p0"); @@ -278,28 +278,28 @@ private String lowerConditionLambda(ConditionalPattern pattern, CodPTACUnit unit if (pattern.branches != null && !pattern.branches.isEmpty()) { ConditionalPattern.Branch first = pattern.branches.get(0); if (first != null && first.condition != null) { - CodPTACOperand condition = lowerExpr(first.condition, lambda, unit); - lambda.instructions.add(new CodPTACInstruction( - CodPTACOpcode.RETURN, + Operand condition = lowerExpr(first.condition, lambda, unit); + lambda.instructions.add(new Instruction( + Opcode.RETURN, null, Arrays.asList(condition) )); } } if (lambda.instructions.isEmpty()) { - lambda.instructions.add(new CodPTACInstruction( - CodPTACOpcode.RETURN, + lambda.instructions.add(new Instruction( + Opcode.RETURN, null, - Arrays.asList(CodPTACOperand.immediate(Boolean.TRUE)) + Arrays.asList(Operand.immediate(Boolean.TRUE)) )); } unit.functions.add(lambda); return lambdaName; } - private String lowerSequenceLambda(SequencePattern.Pattern pattern, CodPTACUnit unit) { + private String lowerSequenceLambda(SequencePattern.Pattern pattern, Unit unit) { String lambdaName = nextLambdaName("seq"); - CodPTACFunction lambda = new CodPTACFunction(); + Function lambda = new Function(); lambda.name = lambdaName; lambda.lambdaBlock = true; lambda.parameters.add(pattern.indexVar != null ? pattern.indexVar : "p0"); @@ -309,16 +309,16 @@ private String lowerSequenceLambda(SequencePattern.Pattern pattern, CodPTACUnit if (step == null) { continue; } - CodPTACOperand value = lowerExpr(step.expression, lambda, unit); + Operand value = lowerExpr(step.expression, lambda, unit); if (step.tempVar != null) { - lambda.instructions.add(new CodPTACInstruction( - CodPTACOpcode.ASSIGN, + lambda.instructions.add(new Instruction( + Opcode.ASSIGN, step.tempVar, Arrays.asList(value) )); } else { - lambda.instructions.add(new CodPTACInstruction( - CodPTACOpcode.RETURN, + lambda.instructions.add(new Instruction( + Opcode.RETURN, null, Arrays.asList(value) )); @@ -326,92 +326,92 @@ private String lowerSequenceLambda(SequencePattern.Pattern pattern, CodPTACUnit } } if (lambda.instructions.isEmpty()) { - lambda.instructions.add(new CodPTACInstruction( - CodPTACOpcode.RETURN, + lambda.instructions.add(new Instruction( + Opcode.RETURN, null, - Arrays.asList(CodPTACOperand.immediate(null)) + Arrays.asList(Operand.immediate(null)) )); } unit.functions.add(lambda); return lambdaName; } - private CodPTACOperand lowerExpr(Expr expr, CodPTACFunction fn, CodPTACUnit unit) { - if (expr == null) return CodPTACOperand.immediate(null); + private Operand lowerExpr(Expr expr, Function fn, Unit unit) { + if (expr == null) return Operand.immediate(null); - if (expr instanceof IntLiteral) return CodPTACOperand.immediate(((IntLiteral) expr).value); - if (expr instanceof FloatLiteral) return CodPTACOperand.immediate(((FloatLiteral) expr).value); - if (expr instanceof BoolLiteral) return CodPTACOperand.immediate(((BoolLiteral) expr).value); - if (expr instanceof TextLiteral) return CodPTACOperand.immediate(((TextLiteral) expr).value); - if (expr instanceof NoneLiteral) return CodPTACOperand.immediate(null); - if (expr instanceof Identifier) return CodPTACOperand.register(((Identifier) expr).name); + if (expr instanceof IntLiteral) return Operand.immediate(((IntLiteral) expr).value); + if (expr instanceof FloatLiteral) return Operand.immediate(((FloatLiteral) expr).value); + if (expr instanceof BoolLiteral) return Operand.immediate(((BoolLiteral) expr).value); + if (expr instanceof TextLiteral) return Operand.immediate(((TextLiteral) expr).value); + if (expr instanceof NoneLiteral) return Operand.immediate(null); + if (expr instanceof Identifier) return Operand.register(((Identifier) expr).name); if (expr instanceof Range) { Range range = (Range) expr; String dest = nextPattern(); - CodPTACOpcode op = selectRangeOpcode(range); - List ops = new ArrayList(); + Opcode op = selectRangeOpcode(range); + List ops = new ArrayList(); ops.add(lowerExpr(range.start, fn, unit)); ops.add(lowerExpr(range.end, fn, unit)); if (range.step != null) ops.add(lowerExpr(range.step, fn, unit)); - fn.instructions.add(new CodPTACInstruction(op, dest, ops)); - return CodPTACOperand.register(dest); + fn.instructions.add(new Instruction(op, dest, ops)); + return Operand.register(dest); } if (expr instanceof IndexAccess) { IndexAccess access = (IndexAccess) expr; String dest = nextTemp(); - fn.instructions.add(new CodPTACInstruction( - CodPTACOpcode.LAZY_GET, + fn.instructions.add(new Instruction( + Opcode.LAZY_GET, dest, Arrays.asList( lowerExpr(access.array, fn, unit), lowerExpr(access.index, fn, unit) ) )); - return CodPTACOperand.register(dest); + return Operand.register(dest); } if (expr instanceof BinaryOp) { BinaryOp binary = (BinaryOp) expr; - CodPTACOperand left = lowerExpr(binary.left, fn, unit); - CodPTACOperand right = lowerExpr(binary.right, fn, unit); + Operand left = lowerExpr(binary.left, fn, unit); + Operand right = lowerExpr(binary.right, fn, unit); String dest = nextTemp(); - fn.instructions.add(new CodPTACInstruction( + fn.instructions.add(new Instruction( mapBinary(binary.op), dest, Arrays.asList(left, right) )); - return CodPTACOperand.register(dest); + return Operand.register(dest); } if (expr instanceof MethodCall) { MethodCall call = (MethodCall) expr; - List ops = new ArrayList(); + List ops = new ArrayList(); if (call.isSelfCall) { for (Expr arg : call.arguments) { ops.add(lowerExpr(arg, fn, unit)); } String dest = nextTemp(); - fn.instructions.add(new CodPTACInstruction(CodPTACOpcode.SELF, dest, ops)); - return CodPTACOperand.register(dest); + fn.instructions.add(new Instruction(Opcode.SELF, dest, ops)); + return Operand.register(dest); } - ops.add(CodPTACOperand.function(call.name)); + ops.add(Operand.function(call.name)); if (call.arguments != null) { for (Expr arg : call.arguments) { ops.add(lowerExpr(arg, fn, unit)); } } String dest = nextTemp(); - fn.instructions.add(new CodPTACInstruction(CodPTACOpcode.CALL, dest, ops)); - return CodPTACOperand.register(dest); + fn.instructions.add(new Instruction(Opcode.CALL, dest, ops)); + return Operand.register(dest); } if (expr instanceof Lambda) { Lambda lambdaNode = (Lambda) expr; String lambdaName = nextLambdaName("inline"); - CodPTACFunction lambda = new CodPTACFunction(); + Function lambda = new Function(); lambda.name = lambdaName; lambda.lambdaBlock = true; if (lambdaNode.parameters != null) { @@ -420,45 +420,45 @@ private CodPTACOperand lowerExpr(Expr expr, CodPTACFunction fn, CodPTACUnit unit } } if (lambdaNode.expressionBody != null) { - CodPTACOperand val = lowerExpr(lambdaNode.expressionBody, lambda, unit); - lambda.instructions.add(new CodPTACInstruction(CodPTACOpcode.RETURN, null, Arrays.asList(val))); + Operand val = lowerExpr(lambdaNode.expressionBody, lambda, unit); + lambda.instructions.add(new Instruction(Opcode.RETURN, null, Arrays.asList(val))); } else if (lambdaNode.body != null) { lowerStmt(lambdaNode.body, lambda, unit); } if (lambda.instructions.isEmpty()) { - lambda.instructions.add(new CodPTACInstruction( - CodPTACOpcode.RETURN, + lambda.instructions.add(new Instruction( + Opcode.RETURN, null, - Arrays.asList(CodPTACOperand.immediate(null)) + Arrays.asList(Operand.immediate(null)) )); } unit.functions.add(lambda); - return CodPTACOperand.function(lambdaName); + return Operand.function(lambdaName); } - return CodPTACOperand.identifier(String.valueOf(expr)); + return Operand.identifier(String.valueOf(expr)); } - private CodPTACOpcode selectRangeOpcode(Range range) { + private Opcode selectRangeOpcode(Range range) { if (range != null && (range.start instanceof TextLiteral || range.end instanceof TextLiteral)) { - return range.step == null ? CodPTACOpcode.RANGE_L : CodPTACOpcode.RANGE_LS; + return range.step == null ? Opcode.RANGE_L : Opcode.RANGE_LS; } - return range != null && range.step != null ? CodPTACOpcode.RANGE_S : CodPTACOpcode.RANGE; + return range != null && range.step != null ? Opcode.RANGE_S : Opcode.RANGE; } - private CodPTACOpcode mapBinary(String op) { - if ("+".equals(op)) return CodPTACOpcode.ADD; - if ("-".equals(op)) return CodPTACOpcode.SUB; - if ("*".equals(op)) return CodPTACOpcode.MUL; - if ("/".equals(op)) return CodPTACOpcode.DIV; - if ("%".equals(op)) return CodPTACOpcode.MOD; - if ("==".equals(op)) return CodPTACOpcode.EQ; - if ("!=".equals(op)) return CodPTACOpcode.NE; - if (">".equals(op)) return CodPTACOpcode.GT; - if ("<".equals(op)) return CodPTACOpcode.LT; - if (">=".equals(op)) return CodPTACOpcode.GTE; - if ("<=".equals(op)) return CodPTACOpcode.LTE; - return CodPTACOpcode.NOP; + private Opcode mapBinary(String op) { + if ("+".equals(op)) return Opcode.ADD; + if ("-".equals(op)) return Opcode.SUB; + if ("*".equals(op)) return Opcode.MUL; + if ("/".equals(op)) return Opcode.DIV; + if ("%".equals(op)) return Opcode.MOD; + if ("==".equals(op)) return Opcode.EQ; + if ("!=".equals(op)) return Opcode.NE; + if (">".equals(op)) return Opcode.GT; + if ("<".equals(op)) return Opcode.LT; + if (">=".equals(op)) return Opcode.GTE; + if ("<=".equals(op)) return Opcode.LTE; + return Opcode.NOP; } private String nextTemp() { @@ -473,9 +473,9 @@ private String nextLambdaName(String prefix) { return "lambda$" + prefix + "$" + (lambdaCounter++); } - private CodPTACFunction findFunction(CodPTACUnit unit, String name) { + private Function findFunction(Unit unit, String name) { if (unit == null || unit.functions == null) return null; - for (CodPTACFunction fn : unit.functions) { + for (Function fn : unit.functions) { if (fn != null && name.equals(fn.name)) return fn; } return null; diff --git a/src/main/java/cod/ptac/Opcode.java b/src/main/java/cod/ptac/Opcode.java new file mode 100644 index 00000000..b44e9d9b --- /dev/null +++ b/src/main/java/cod/ptac/Opcode.java @@ -0,0 +1,78 @@ +package cod.ptac; + + +public enum Opcode { + // Core TAC + NOP(Layer.CORE_TAC), + ASSIGN(Layer.CORE_TAC), + ADD(Layer.CORE_TAC), + SUB(Layer.CORE_TAC), + MUL(Layer.CORE_TAC), + DIV(Layer.CORE_TAC), + MOD(Layer.CORE_TAC), + EQ(Layer.CORE_TAC), + NE(Layer.CORE_TAC), + GT(Layer.CORE_TAC), + LT(Layer.CORE_TAC), + GTE(Layer.CORE_TAC), + LTE(Layer.CORE_TAC), + BRANCH(Layer.CORE_TAC), + BRANCH_IF(Layer.CORE_TAC), + CALL(Layer.CORE_TAC), + RETURN(Layer.CORE_TAC), + LOAD(Layer.CORE_TAC), + STORE(Layer.CORE_TAC), + + // Coderive pattern ops + RANGE(Layer.PATTERN), + RANGE_Q(Layer.PATTERN), + RANGE_S(Layer.PATTERN), + RANGE_L(Layer.PATTERN), + RANGE_LS(Layer.PATTERN), + MAP(Layer.PATTERN), + FILTER(Layer.PATTERN), + REDUCE(Layer.PATTERN), + WHERE(Layer.PATTERN), + SCAN(Layer.PATTERN), + ZIP(Layer.PATTERN), + TAKE(Layer.PATTERN), + FILTER_MAP(Layer.PATTERN), + FILTER_MAP_REDUCE(Layer.PATTERN), + + // Lambda / recursion / closures + LAMBDA(Layer.PATTERN), + CLOSURE(Layer.PATTERN), + ANCESTOR(Layer.PATTERN), + SELF(Layer.PATTERN), + TAIL_CALL(Layer.PATTERN), + + // Slot ops + SLOT_GET(Layer.PATTERN), + SLOT_SET(Layer.PATTERN), + SLOT_RET(Layer.PATTERN), + SLOT_UNPACK(Layer.PATTERN), + SLOT_DIV(Layer.PATTERN), + + // Lazy ops + LAZY_GET(Layer.PATTERN), + LAZY_SET(Layer.PATTERN), + LAZY_COMMIT(Layer.PATTERN), + LAZY_SIZE(Layer.PATTERN), + LAZY_SLICE(Layer.PATTERN), + + // Formula ops + FORMULA_SEQ(Layer.PATTERN), + FORMULA_COND(Layer.PATTERN), + FORMULA_RECUR(Layer.PATTERN), + FORMULA_FUSE(Layer.PATTERN); + + private final Layer layer; + + Opcode(Layer layer) { + this.layer = layer; + } + + public Layer getLayer() { + return layer; + } +} diff --git a/src/main/java/cod/ptac/Operand.java b/src/main/java/cod/ptac/Operand.java new file mode 100644 index 00000000..414c6e85 --- /dev/null +++ b/src/main/java/cod/ptac/Operand.java @@ -0,0 +1,36 @@ +package cod.ptac; + + +public final class Operand { + public final OperandKind kind; + public final Object value; + + private Operand(OperandKind kind, Object value) { + this.kind = kind; + this.value = value; + } + + public static Operand register(String name) { + return new Operand(OperandKind.REGISTER, name); + } + + public static Operand immediate(Object value) { + return new Operand(OperandKind.IMMEDIATE, value); + } + + public static Operand label(String name) { + return new Operand(OperandKind.LABEL, name); + } + + public static Operand function(String name) { + return new Operand(OperandKind.FUNCTION, name); + } + + public static Operand slot(String name) { + return new Operand(OperandKind.SLOT, name); + } + + public static Operand identifier(String name) { + return new Operand(OperandKind.IDENTIFIER, name); + } +} diff --git a/src/main/java/cod/ptac/CodPTACOperandKind.java b/src/main/java/cod/ptac/OperandKind.java similarity index 75% rename from src/main/java/cod/ptac/CodPTACOperandKind.java rename to src/main/java/cod/ptac/OperandKind.java index 08d09c3a..d5f4f972 100644 --- a/src/main/java/cod/ptac/CodPTACOperandKind.java +++ b/src/main/java/cod/ptac/OperandKind.java @@ -1,7 +1,7 @@ package cod.ptac; -public enum CodPTACOperandKind { +public enum OperandKind { REGISTER, IMMEDIATE, LABEL, diff --git a/src/main/java/cod/ptac/Optimizer.java b/src/main/java/cod/ptac/Optimizer.java new file mode 100644 index 00000000..4ca314bb --- /dev/null +++ b/src/main/java/cod/ptac/Optimizer.java @@ -0,0 +1,31 @@ +package cod.ptac; + +import cod.ptac.opt.*; + +import java.util.ArrayList; +import java.util.List; + +public final class Optimizer { + private final List passes; + + public Optimizer() { + this(false); + } + + public Optimizer(boolean enableOptionalLowering) { + this.passes = new ArrayList(); + this.passes.add(new PatternFusion()); + this.passes.add(new RangePropagation()); + this.passes.add(new ConstantFolding()); + this.passes.add(new DeadTempElimination()); + this.passes.add(new PatternLowering(enableOptionalLowering)); + } + + public Unit optimize(Unit unit) { + if (unit == null) return null; + for (Optimization pass : passes) { + pass.apply(unit); + } + return unit; + } +} diff --git a/src/main/java/cod/ptac/CodPTACOptions.java b/src/main/java/cod/ptac/Options.java similarity index 83% rename from src/main/java/cod/ptac/CodPTACOptions.java rename to src/main/java/cod/ptac/Options.java index d46c99fc..1f479bbb 100644 --- a/src/main/java/cod/ptac/CodPTACOptions.java +++ b/src/main/java/cod/ptac/Options.java @@ -1,6 +1,6 @@ package cod.ptac; -public final class CodPTACOptions { +public final class Options { public enum Mode { INTERPRETER, COMPILE_ONLY, @@ -10,12 +10,12 @@ public enum Mode { private final Mode mode; private final boolean fallbackEnabled; - private CodPTACOptions(Mode mode, boolean fallbackEnabled) { + private Options(Mode mode, boolean fallbackEnabled) { this.mode = mode; this.fallbackEnabled = fallbackEnabled; } - public static CodPTACOptions current() { + public static Options current() { String rawMode = firstNonEmpty( System.getProperty("cod.ptac.mode"), System.getenv("COD_PTAC_MODE") @@ -27,11 +27,11 @@ public static CodPTACOptions current() { Mode mode = parseMode(rawMode); boolean fallback = rawFallback == null || !"false".equalsIgnoreCase(rawFallback.trim()); - return new CodPTACOptions(mode, fallback); + return new Options(mode, fallback); } - public static CodPTACOptions compileExecuteWithFallback(boolean fallback) { - return new CodPTACOptions(Mode.COMPILE_EXECUTE, fallback); + public static Options compileExecuteWithFallback(boolean fallback) { + return new Options(Mode.COMPILE_EXECUTE, fallback); } public Mode getMode() { diff --git a/src/main/java/cod/ptac/CodPTACUnit.java b/src/main/java/cod/ptac/Unit.java similarity index 59% rename from src/main/java/cod/ptac/CodPTACUnit.java rename to src/main/java/cod/ptac/Unit.java index dcd9f72f..fb105e54 100644 --- a/src/main/java/cod/ptac/CodPTACUnit.java +++ b/src/main/java/cod/ptac/Unit.java @@ -3,9 +3,9 @@ import java.util.ArrayList; import java.util.List; -public final class CodPTACUnit { +public final class Unit { public String unitName; public String className; public String entryFunction; - public List functions = new ArrayList(); + public List functions = new ArrayList(); } diff --git a/src/main/java/cod/ptac/opt/CodPTACConstantFoldingPass.java b/src/main/java/cod/ptac/opt/CodPTACConstantFoldingPass.java deleted file mode 100644 index be398b36..00000000 --- a/src/main/java/cod/ptac/opt/CodPTACConstantFoldingPass.java +++ /dev/null @@ -1,59 +0,0 @@ -package cod.ptac.opt; - -import cod.ptac.*; - -import java.util.ArrayList; -import java.util.List; - -public final class CodPTACConstantFoldingPass implements CodPTACOptimizationPass { - @Override - public void apply(CodPTACUnit unit) { - if (unit == null || unit.functions == null) return; - - for (CodPTACFunction function : unit.functions) { - if (function == null || function.instructions == null) continue; - List rewritten = new ArrayList(); - - for (CodPTACInstruction inst : function.instructions) { - rewritten.add(fold(inst)); - } - function.instructions = rewritten; - } - } - - private CodPTACInstruction fold(CodPTACInstruction inst) { - if (inst == null || inst.operands == null || inst.operands.size() != 2) return inst; - - if (!isFoldable(inst.opcode)) return inst; - CodPTACOperand left = inst.operands.get(0); - CodPTACOperand right = inst.operands.get(1); - if (left.kind != CodPTACOperandKind.IMMEDIATE || right.kind != CodPTACOperandKind.IMMEDIATE) return inst; - if (!(left.value instanceof Number) || !(right.value instanceof Number)) return inst; - - double a = ((Number) left.value).doubleValue(); - double b = ((Number) right.value).doubleValue(); - Object folded = compute(inst.opcode, a, b); - if (folded == null) return inst; - - List operands = new ArrayList(); - operands.add(CodPTACOperand.immediate(folded)); - return new CodPTACInstruction(CodPTACOpcode.ASSIGN, inst.dest, operands, inst.flags); - } - - private boolean isFoldable(CodPTACOpcode opcode) { - return opcode == CodPTACOpcode.ADD - || opcode == CodPTACOpcode.SUB - || opcode == CodPTACOpcode.MUL - || opcode == CodPTACOpcode.DIV - || opcode == CodPTACOpcode.MOD; - } - - private Object compute(CodPTACOpcode opcode, double a, double b) { - if (opcode == CodPTACOpcode.ADD) return a + b; - if (opcode == CodPTACOpcode.SUB) return a - b; - if (opcode == CodPTACOpcode.MUL) return a * b; - if (opcode == CodPTACOpcode.DIV) return b == 0.0d ? null : a / b; - if (opcode == CodPTACOpcode.MOD) return b == 0.0d ? null : a % b; - return null; - } -} diff --git a/src/main/java/cod/ptac/opt/CodPTACDeadTempEliminationPass.java b/src/main/java/cod/ptac/opt/CodPTACDeadTempEliminationPass.java deleted file mode 100644 index f9c8a713..00000000 --- a/src/main/java/cod/ptac/opt/CodPTACDeadTempEliminationPass.java +++ /dev/null @@ -1,54 +0,0 @@ -package cod.ptac.opt; - -import cod.ptac.*; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -public final class CodPTACDeadTempEliminationPass implements CodPTACOptimizationPass { - @Override - public void apply(CodPTACUnit unit) { - if (unit == null || unit.functions == null) return; - - for (CodPTACFunction function : unit.functions) { - if (function == null || function.instructions == null) continue; - Set used = collectUsedRegisters(function.instructions); - List rewritten = new ArrayList(); - for (CodPTACInstruction inst : function.instructions) { - if (canDrop(inst, used)) continue; - rewritten.add(inst); - } - function.instructions = rewritten; - } - } - - private Set collectUsedRegisters(List instructions) { - Set used = new HashSet(); - for (CodPTACInstruction inst : instructions) { - if (inst == null || inst.operands == null) continue; - for (CodPTACOperand operand : inst.operands) { - if (operand != null && operand.kind == CodPTACOperandKind.REGISTER && operand.value instanceof String) { - used.add((String) operand.value); - } - } - } - return used; - } - - private boolean canDrop(CodPTACInstruction inst, Set used) { - if (inst == null || inst.dest == null) return false; - if (!inst.dest.startsWith("t")) return false; - if (hasSideEffect(inst.opcode)) return false; - return !used.contains(inst.dest); - } - - private boolean hasSideEffect(CodPTACOpcode opcode) { - return opcode == CodPTACOpcode.STORE - || opcode == CodPTACOpcode.CALL - || opcode == CodPTACOpcode.SLOT_SET - || opcode == CodPTACOpcode.LAZY_SET - || opcode == CodPTACOpcode.LAZY_COMMIT; - } -} diff --git a/src/main/java/cod/ptac/opt/CodPTACLazyRangePropagationPass.java b/src/main/java/cod/ptac/opt/CodPTACLazyRangePropagationPass.java deleted file mode 100644 index 8a6ebfea..00000000 --- a/src/main/java/cod/ptac/opt/CodPTACLazyRangePropagationPass.java +++ /dev/null @@ -1,41 +0,0 @@ -package cod.ptac.opt; - -import cod.ptac.*; - -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.List; - -public final class CodPTACLazyRangePropagationPass implements CodPTACOptimizationPass { - @Override - public void apply(CodPTACUnit unit) { - if (unit == null || unit.functions == null) return; - - for (CodPTACFunction function : unit.functions) { - if (function == null || function.instructions == null) continue; - List rewritten = new ArrayList(); - - for (CodPTACInstruction inst : function.instructions) { - rewritten.add(markLazyIfNeeded(inst)); - } - - function.instructions = rewritten; - } - } - - private CodPTACInstruction markLazyIfNeeded(CodPTACInstruction inst) { - if (inst == null) return null; - if (inst.opcode != CodPTACOpcode.RANGE - && inst.opcode != CodPTACOpcode.RANGE_Q - && inst.opcode != CodPTACOpcode.RANGE_S - && inst.opcode != CodPTACOpcode.RANGE_L - && inst.opcode != CodPTACOpcode.RANGE_LS) { - return inst; - } - EnumSet flags = inst.flags != null - ? EnumSet.copyOf(inst.flags) - : EnumSet.noneOf(CodPTACFlag.class); - flags.add(CodPTACFlag.LAZY); - return new CodPTACInstruction(inst.opcode, inst.dest, inst.operands, flags); - } -} diff --git a/src/main/java/cod/ptac/opt/CodPTACOptimizationPass.java b/src/main/java/cod/ptac/opt/CodPTACOptimizationPass.java deleted file mode 100644 index f5e7a148..00000000 --- a/src/main/java/cod/ptac/opt/CodPTACOptimizationPass.java +++ /dev/null @@ -1,7 +0,0 @@ -package cod.ptac.opt; - -import cod.ptac.CodPTACUnit; - -public interface CodPTACOptimizationPass { - void apply(CodPTACUnit unit); -} diff --git a/src/main/java/cod/ptac/opt/CodPTACOptionalPatternLoweringPass.java b/src/main/java/cod/ptac/opt/CodPTACOptionalPatternLoweringPass.java deleted file mode 100644 index 619fdf3d..00000000 --- a/src/main/java/cod/ptac/opt/CodPTACOptionalPatternLoweringPass.java +++ /dev/null @@ -1,43 +0,0 @@ -package cod.ptac.opt; - -import cod.ptac.*; - -import java.util.ArrayList; -import java.util.List; - -public final class CodPTACOptionalPatternLoweringPass implements CodPTACOptimizationPass { - private final boolean enabled; - - public CodPTACOptionalPatternLoweringPass(boolean enabled) { - this.enabled = enabled; - } - - @Override - public void apply(CodPTACUnit unit) { - if (!enabled || unit == null || unit.functions == null) return; - - for (CodPTACFunction function : unit.functions) { - if (function == null || function.instructions == null) continue; - - List rewritten = new ArrayList(); - for (CodPTACInstruction inst : function.instructions) { - if (inst == null) continue; - if (isUnsupportedPattern(inst.opcode)) { - List operands = new ArrayList(); - operands.add(CodPTACOperand.function(inst.opcode.name())); - rewritten.add(new CodPTACInstruction(CodPTACOpcode.CALL, inst.dest, operands, inst.flags)); - } else { - rewritten.add(inst); - } - } - function.instructions = rewritten; - } - } - - private boolean isUnsupportedPattern(CodPTACOpcode opcode) { - return opcode == CodPTACOpcode.ZIP - || opcode == CodPTACOpcode.SCAN - || opcode == CodPTACOpcode.FORMULA_RECUR - || opcode == CodPTACOpcode.FORMULA_FUSE; - } -} diff --git a/src/main/java/cod/ptac/opt/ConstantFolding.java b/src/main/java/cod/ptac/opt/ConstantFolding.java new file mode 100644 index 00000000..a80a5232 --- /dev/null +++ b/src/main/java/cod/ptac/opt/ConstantFolding.java @@ -0,0 +1,59 @@ +package cod.ptac.opt; + +import cod.ptac.*; + +import java.util.ArrayList; +import java.util.List; + +public final class ConstantFolding implements Optimization { + @Override + public void apply(Unit unit) { + if (unit == null || unit.functions == null) return; + + for (Function function : unit.functions) { + if (function == null || function.instructions == null) continue; + List rewritten = new ArrayList(); + + for (Instruction inst : function.instructions) { + rewritten.add(fold(inst)); + } + function.instructions = rewritten; + } + } + + private Instruction fold(Instruction inst) { + if (inst == null || inst.operands == null || inst.operands.size() != 2) return inst; + + if (!isFoldable(inst.opcode)) return inst; + Operand left = inst.operands.get(0); + Operand right = inst.operands.get(1); + if (left.kind != OperandKind.IMMEDIATE || right.kind != OperandKind.IMMEDIATE) return inst; + if (!(left.value instanceof Number) || !(right.value instanceof Number)) return inst; + + double a = ((Number) left.value).doubleValue(); + double b = ((Number) right.value).doubleValue(); + Object folded = compute(inst.opcode, a, b); + if (folded == null) return inst; + + List operands = new ArrayList(); + operands.add(Operand.immediate(folded)); + return new Instruction(Opcode.ASSIGN, inst.dest, operands, inst.flags); + } + + private boolean isFoldable(Opcode opcode) { + return opcode == Opcode.ADD + || opcode == Opcode.SUB + || opcode == Opcode.MUL + || opcode == Opcode.DIV + || opcode == Opcode.MOD; + } + + private Object compute(Opcode opcode, double a, double b) { + if (opcode == Opcode.ADD) return a + b; + if (opcode == Opcode.SUB) return a - b; + if (opcode == Opcode.MUL) return a * b; + if (opcode == Opcode.DIV) return b == 0.0d ? null : a / b; + if (opcode == Opcode.MOD) return b == 0.0d ? null : a % b; + return null; + } +} diff --git a/src/main/java/cod/ptac/opt/DeadTempElimination.java b/src/main/java/cod/ptac/opt/DeadTempElimination.java new file mode 100644 index 00000000..49272333 --- /dev/null +++ b/src/main/java/cod/ptac/opt/DeadTempElimination.java @@ -0,0 +1,54 @@ +package cod.ptac.opt; + +import cod.ptac.*; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public final class DeadTempElimination implements Optimization { + @Override + public void apply(Unit unit) { + if (unit == null || unit.functions == null) return; + + for (Function function : unit.functions) { + if (function == null || function.instructions == null) continue; + Set used = collectUsedRegisters(function.instructions); + List rewritten = new ArrayList(); + for (Instruction inst : function.instructions) { + if (canDrop(inst, used)) continue; + rewritten.add(inst); + } + function.instructions = rewritten; + } + } + + private Set collectUsedRegisters(List instructions) { + Set used = new HashSet(); + for (Instruction inst : instructions) { + if (inst == null || inst.operands == null) continue; + for (Operand operand : inst.operands) { + if (operand != null && operand.kind == OperandKind.REGISTER && operand.value instanceof String) { + used.add((String) operand.value); + } + } + } + return used; + } + + private boolean canDrop(Instruction inst, Set used) { + if (inst == null || inst.dest == null) return false; + if (!inst.dest.startsWith("t")) return false; + if (hasSideEffect(inst.opcode)) return false; + return !used.contains(inst.dest); + } + + private boolean hasSideEffect(Opcode opcode) { + return opcode == Opcode.STORE + || opcode == Opcode.CALL + || opcode == Opcode.SLOT_SET + || opcode == Opcode.LAZY_SET + || opcode == Opcode.LAZY_COMMIT; + } +} diff --git a/src/main/java/cod/ptac/opt/Optimization.java b/src/main/java/cod/ptac/opt/Optimization.java new file mode 100644 index 00000000..7a164d08 --- /dev/null +++ b/src/main/java/cod/ptac/opt/Optimization.java @@ -0,0 +1,7 @@ +package cod.ptac.opt; + +import cod.ptac.Unit; + +public interface Optimization { + void apply(Unit unit); +} diff --git a/src/main/java/cod/ptac/opt/CodPTACPatternFusionPass.java b/src/main/java/cod/ptac/opt/PatternFusion.java similarity index 50% rename from src/main/java/cod/ptac/opt/CodPTACPatternFusionPass.java rename to src/main/java/cod/ptac/opt/PatternFusion.java index 3da52459..a038bd0d 100644 --- a/src/main/java/cod/ptac/opt/CodPTACPatternFusionPass.java +++ b/src/main/java/cod/ptac/opt/PatternFusion.java @@ -5,26 +5,26 @@ import java.util.ArrayList; import java.util.List; -public final class CodPTACPatternFusionPass implements CodPTACOptimizationPass { +public final class PatternFusion implements Optimization { @Override - public void apply(CodPTACUnit unit) { + public void apply(Unit unit) { if (unit == null || unit.functions == null) return; - for (CodPTACFunction function : unit.functions) { + for (Function function : unit.functions) { if (function == null || function.instructions == null) continue; - List rewritten = new ArrayList(); + List rewritten = new ArrayList(); for (int i = 0; i < function.instructions.size(); i++) { - CodPTACInstruction current = function.instructions.get(i); - CodPTACInstruction next = i + 1 < function.instructions.size() + Instruction current = function.instructions.get(i); + Instruction next = i + 1 < function.instructions.size() ? function.instructions.get(i + 1) : null; if (canFuseFilterMap(current, next)) { - List fusedOps = new ArrayList(); + List fusedOps = new ArrayList(); fusedOps.add(current.operands.get(0)); // source fusedOps.add(current.operands.get(1)); // filter lambda fusedOps.add(next.operands.get(1)); // map lambda - rewritten.add(new CodPTACInstruction(CodPTACOpcode.FILTER_MAP, next.dest, fusedOps, next.flags)); + rewritten.add(new Instruction(Opcode.FILTER_MAP, next.dest, fusedOps, next.flags)); i++; continue; } @@ -35,13 +35,13 @@ public void apply(CodPTACUnit unit) { } } - private boolean canFuseFilterMap(CodPTACInstruction filter, CodPTACInstruction map) { + private boolean canFuseFilterMap(Instruction filter, Instruction map) { if (filter == null || map == null) return false; - if (filter.opcode != CodPTACOpcode.FILTER) return false; - if (map.opcode != CodPTACOpcode.MAP) return false; + if (filter.opcode != Opcode.FILTER) return false; + if (map.opcode != Opcode.MAP) return false; if (filter.dest == null) return false; if (map.operands == null || map.operands.isEmpty()) return false; - CodPTACOperand mapSource = map.operands.get(0); - return mapSource.kind == CodPTACOperandKind.REGISTER && filter.dest.equals(mapSource.value); + Operand mapSource = map.operands.get(0); + return mapSource.kind == OperandKind.REGISTER && filter.dest.equals(mapSource.value); } } diff --git a/src/main/java/cod/ptac/opt/PatternLowering.java b/src/main/java/cod/ptac/opt/PatternLowering.java new file mode 100644 index 00000000..7f9f5cf9 --- /dev/null +++ b/src/main/java/cod/ptac/opt/PatternLowering.java @@ -0,0 +1,43 @@ +package cod.ptac.opt; + +import cod.ptac.*; + +import java.util.ArrayList; +import java.util.List; + +public final class PatternLowering implements Optimization { + private final boolean enabled; + + public PatternLowering(boolean enabled) { + this.enabled = enabled; + } + + @Override + public void apply(Unit unit) { + if (!enabled || unit == null || unit.functions == null) return; + + for (Function function : unit.functions) { + if (function == null || function.instructions == null) continue; + + List rewritten = new ArrayList(); + for (Instruction inst : function.instructions) { + if (inst == null) continue; + if (isUnsupportedPattern(inst.opcode)) { + List operands = new ArrayList(); + operands.add(Operand.function(inst.opcode.name())); + rewritten.add(new Instruction(Opcode.CALL, inst.dest, operands, inst.flags)); + } else { + rewritten.add(inst); + } + } + function.instructions = rewritten; + } + } + + private boolean isUnsupportedPattern(Opcode opcode) { + return opcode == Opcode.ZIP + || opcode == Opcode.SCAN + || opcode == Opcode.FORMULA_RECUR + || opcode == Opcode.FORMULA_FUSE; + } +} diff --git a/src/main/java/cod/ptac/opt/RangePropagation.java b/src/main/java/cod/ptac/opt/RangePropagation.java new file mode 100644 index 00000000..d88d388b --- /dev/null +++ b/src/main/java/cod/ptac/opt/RangePropagation.java @@ -0,0 +1,41 @@ +package cod.ptac.opt; + +import cod.ptac.*; + +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; + +public final class RangePropagation implements Optimization { + @Override + public void apply(Unit unit) { + if (unit == null || unit.functions == null) return; + + for (Function function : unit.functions) { + if (function == null || function.instructions == null) continue; + List rewritten = new ArrayList(); + + for (Instruction inst : function.instructions) { + rewritten.add(markLazyIfNeeded(inst)); + } + + function.instructions = rewritten; + } + } + + private Instruction markLazyIfNeeded(Instruction inst) { + if (inst == null) return null; + if (inst.opcode != Opcode.RANGE + && inst.opcode != Opcode.RANGE_Q + && inst.opcode != Opcode.RANGE_S + && inst.opcode != Opcode.RANGE_L + && inst.opcode != Opcode.RANGE_LS) { + return inst; + } + EnumSet flags = inst.flags != null + ? EnumSet.copyOf(inst.flags) + : EnumSet.noneOf(Flag.class); + flags.add(Flag.LAZY); + return new Instruction(inst.opcode, inst.dest, inst.operands, flags); + } +} diff --git a/src/main/java/cod/runner/CodPTACParityRunner.java b/src/main/java/cod/runner/CodPTACParityRunner.java index a6fd8eb1..8ac91370 100644 --- a/src/main/java/cod/runner/CodPTACParityRunner.java +++ b/src/main/java/cod/runner/CodPTACParityRunner.java @@ -6,9 +6,9 @@ import cod.interpreter.Index; import cod.interpreter.Interpreter; import cod.ir.IRManager; -import cod.ptac.CodPTACArtifact; -import cod.ptac.CodPTACExecutor; -import cod.ptac.CodPTACOptions; +import cod.ptac.Artifact; +import cod.ptac.Executor; +import cod.ptac.Options; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.FilterInputStream; @@ -224,7 +224,7 @@ private String runCodPTACPath(String file) throws Exception { } manager.save(unitName, entryType); - CodPTACArtifact artifact = manager.loadArtifact(unitName, entryType.name); + Artifact artifact = manager.loadArtifact(unitName, entryType.name); if (artifact == null) { throw new Exception("Failed to load CodP-TAC artifact for: " + file); @@ -282,7 +282,7 @@ private String captureOutput(Interpreter interpreter, Program ast) { } } - private String captureOutputPTAC(CodPTACArtifact artifact, Interpreter interpreter) { + private String captureOutputPTAC(Artifact artifact, Interpreter interpreter) { PrintStream oldOut = System.out; PrintStream oldErr = System.err; java.io.InputStream oldIn = System.in; @@ -295,7 +295,7 @@ private String captureOutputPTAC(CodPTACArtifact artifact, Interpreter interpret System.setOut(outReplacement); System.setErr(errReplacement); System.setIn(new ByteArrayInputStream(DEFAULT_INPUT.getBytes(StandardCharsets.UTF_8))); - new CodPTACExecutor(CodPTACOptions.compileExecuteWithFallback(true)) + new Executor(Options.compileExecuteWithFallback(true)) .execute(artifact, interpreter); outReplacement.flush(); errReplacement.flush(); diff --git a/src/main/java/cod/runner/CommandRunner.java b/src/main/java/cod/runner/CommandRunner.java index ba07469d..d6e89ed3 100644 --- a/src/main/java/cod/runner/CommandRunner.java +++ b/src/main/java/cod/runner/CommandRunner.java @@ -6,21 +6,21 @@ import cod.interpreter.Interpreter; import cod.interpreter.Index; import cod.ir.IRManager; -import cod.ptac.CodPTACArtifact; -import cod.ptac.CodPTACExecutor; -import cod.ptac.CodPTACOptions; +import cod.ptac.Artifact; +import cod.ptac.Executor; +import cod.ptac.Options; public class CommandRunner extends BaseRunner { private final Interpreter interpreter; private IRManager irManager; - private final CodPTACOptions ptacOptions; + private final Options ptacOptions; private static final String NAME = "COMMAND"; public CommandRunner() { this.interpreter = new Interpreter(); - this.ptacOptions = CodPTACOptions.current(); + this.ptacOptions = Options.current(); } @Override @@ -191,14 +191,14 @@ private void executeInterpretation(Program ast) { if (ptacOptions.isCompileExecuteEnabled() && irManager != null && ast != null && ast.unit != null) { Type entryType = findMainType(ast); if (entryType != null) { - CodPTACArtifact artifact = irManager.loadArtifact(ast.unit.name, entryType.name); + Artifact artifact = irManager.loadArtifact(ast.unit.name, entryType.name); if (artifact == null) { irManager.save(ast.unit.name, entryType); artifact = irManager.loadArtifact(ast.unit.name, entryType.name); } if (artifact != null) { DebugSystem.info(NAME + LOG_TAG, "Executing using CodP-TAC executor"); - new CodPTACExecutor(ptacOptions).execute(artifact, interpreter); + new Executor(ptacOptions).execute(artifact, interpreter); DebugSystem.info(NAME + LOG_TAG, "Program interpretation completed"); return; } diff --git a/src/main/java/cod/runner/IRValidationRunner.java b/src/main/java/cod/runner/IRValidationRunner.java index 106e9d1f..79deedb6 100644 --- a/src/main/java/cod/runner/IRValidationRunner.java +++ b/src/main/java/cod/runner/IRValidationRunner.java @@ -5,7 +5,7 @@ import cod.ir.IRReader; import cod.ir.IRWriter; import cod.interpreter.Interpreter; -import cod.ptac.CodPTACArtifact; +import cod.ptac.Artifact; import cod.semantic.ImportResolver; import java.io.File; @@ -40,7 +40,7 @@ public void run(String[] args) throws Exception { IRWriter writer = new IRWriter(); IRReader reader = new IRReader(); writer.write(tmp, original); - CodPTACArtifact artifact = reader.readArtifact(tmp); + Artifact artifact = reader.readArtifact(tmp); Type loaded = artifact != null ? artifact.typeSnapshot : null; assertTrue(loaded != null, "Loaded type is null"); @@ -65,7 +65,7 @@ public void run(String[] args) throws Exception { Type managerLoaded = manager.load(program.unit.name, original.name); assertTrue(managerLoaded != null, "IRManager failed to load saved class"); assertTrue(equalsSafe(original.name, managerLoaded.name), "IRManager loaded wrong class"); - CodPTACArtifact managerArtifact = manager.loadArtifact(program.unit.name, original.name); + Artifact managerArtifact = manager.loadArtifact(program.unit.name, original.name); assertTrue(managerArtifact != null, "IRManager failed to load saved CodP-TAC artifact"); assertTrue(managerArtifact.unit != null, "CodP-TAC unit missing from artifact"); } @@ -94,7 +94,7 @@ private void validateInternalImportIRPath() throws Exception { IRManager manager = new IRManager(codProjectRoot); Type internalType = internalProgram.unit.types.get(0); manager.save(internalProgram.unit.name, internalType); - CodPTACArtifact artifact = manager.loadArtifact(internalProgram.unit.name, internalType.name); + Artifact artifact = manager.loadArtifact(internalProgram.unit.name, internalType.name); assertTrue(artifact != null, "Failed to save/load internal CodP-TAC artifact"); Interpreter internalMultiRangeInterpreter = new Interpreter(); @@ -107,7 +107,7 @@ private void validateInternalImportIRPath() throws Exception { Type internalMultiRangeType = internalMultiRangeProgram.unit.types.get(0); manager.save(internalMultiRangeProgram.unit.name, internalMultiRangeType); - CodPTACArtifact multiArtifact = manager.loadArtifact(internalMultiRangeProgram.unit.name, internalMultiRangeType.name); + Artifact multiArtifact = manager.loadArtifact(internalMultiRangeProgram.unit.name, internalMultiRangeType.name); assertTrue(multiArtifact != null, "Failed to save/load internal multi-range CodP-TAC artifact"); ImportResolver resolver = new ImportResolver(); diff --git a/src/main/java/cod/runner/TestRunner.java b/src/main/java/cod/runner/TestRunner.java index 3347c821..f46a57ff 100644 --- a/src/main/java/cod/runner/TestRunner.java +++ b/src/main/java/cod/runner/TestRunner.java @@ -6,9 +6,9 @@ import cod.debug.Linter; import cod.interpreter.Index; import cod.ir.IRManager; -import cod.ptac.CodPTACArtifact; -import cod.ptac.CodPTACExecutor; -import cod.ptac.CodPTACOptions; +import cod.ptac.Artifact; +import cod.ptac.Executor; +import cod.ptac.Options; import java.io.File; import java.io.FilterInputStream; @@ -36,11 +36,11 @@ public class TestRunner extends BaseRunner { private final Interpreter interpreter; private IRManager irManager; - private final CodPTACOptions ptacOptions; + private final Options ptacOptions; public TestRunner() { this.interpreter = new Interpreter(); - this.ptacOptions = CodPTACOptions.current(); + this.ptacOptions = Options.current(); } @Override @@ -273,14 +273,14 @@ private void executeWithManualInterpreter(Program ast) { if (ptacOptions.isCompileExecuteEnabled() && irManager != null && ast != null && ast.unit != null) { Type entryType = findMainType(ast); if (entryType != null) { - CodPTACArtifact artifact = irManager.loadArtifact(ast.unit.name, entryType.name); + Artifact artifact = irManager.loadArtifact(ast.unit.name, entryType.name); if (artifact == null) { irManager.save(ast.unit.name, entryType); artifact = irManager.loadArtifact(ast.unit.name, entryType.name); } if (artifact != null) { DebugSystem.info(NAME + LOG_TAG, "Executing using CodP-TAC executor"); - new CodPTACExecutor(ptacOptions).execute(artifact, interpreter); + new Executor(ptacOptions).execute(artifact, interpreter); double duration = DebugSystem.stopTimer("interpretation"); DebugSystem.info(NAME + LOG_TAG, String.format("Interpretation completed in %.3f ms", duration)); return; diff --git a/src/main/java/cod/semantic/ImportResolver.java b/src/main/java/cod/semantic/ImportResolver.java index a8c9dde7..b298bdca 100644 --- a/src/main/java/cod/semantic/ImportResolver.java +++ b/src/main/java/cod/semantic/ImportResolver.java @@ -9,7 +9,7 @@ import cod.debug.DebugSystem; import cod.interpreter.Index; import cod.ir.IRManager; -import cod.ptac.CodPTACArtifact; +import cod.ptac.Artifact; import java.util.*; import java.io.*; @@ -61,7 +61,7 @@ public class ImportResolver { // Cache for loaded TypeNodes (bytecode or parsed) private Map loadedTypes = createBoundedMap(LOADED_TYPES_CACHE_LIMIT); - private Map loadedArtifacts = createBoundedMap(LOADED_TYPES_CACHE_LIMIT); + private Map loadedArtifacts = createBoundedMap(LOADED_TYPES_CACHE_LIMIT); // Filesystem result cache private Map fileCache = createBoundedMap(FILE_CACHE_LIMIT); @@ -733,7 +733,7 @@ public Type resolveImport(String importName) throws Exception { // ========== TRY CODE-P-TAC ARTIFACT FIRST (FAST PATH) ========== if (irManager != null) { - CodPTACArtifact artifact = irManager.loadArtifact(unitName, className); + Artifact artifact = irManager.loadArtifact(unitName, className); if (artifact != null) { bytecodeCacheHits++; DebugSystem.debug("IR", "Loaded " + className + " CodP-TAC artifact from .codb (cache hit)");