diff --git a/source_.jar b/source_.jar index 6390be5c..d8e3af8b 100644 Binary files a/source_.jar and b/source_.jar differ diff --git a/src/main/cod/demo/src/main/test/linearrecurrenceoptimization/VectorLinearRecurrenceOptimization.cod b/src/main/cod/demo/src/main/test/linearrecurrenceoptimization/VectorLinearRecurrenceOptimization.cod new file mode 100644 index 00000000..a9d5cec8 --- /dev/null +++ b/src/main/cod/demo/src/main/test/linearrecurrenceoptimization/VectorLinearRecurrenceOptimization.cod @@ -0,0 +1,43 @@ +unit test.linearrecurrenceoptimization + +share main() { + out("=== Vector linear recurrence optimization ===") + start := timer() + + a := [0 to 200] + b := [0 to 200] + + a[0] = 1 + a[1] = 2 + b[0] = 0 + b[1] = 1 + + for i of [2 to 200] { + a[i] = a[i-1] + b[i-1] + b[i] = a[i-1] - b[i-1] + } + + out("pair1-a[10]=" + a[10] + " expected=48") + out("pair1-b[10]=" + b[10] + " expected=16") + out("pair1-a[20]=" + a[20] + " expected=1536") + out("pair1-b[20]=" + b[20] + " expected=512") + + c := [0 to 200] + d := [0 to 200] + + c[0] = 1 + c[1] = 2 + d[0] = 3 + d[1] = 4 + + for i of [2 to 200] { + c[i] = c[i-1] + d[i-2] + d[i] = 2 * c[i-2] - d[i-1] + 3 + } + + out("pair2-c[10]=" + c[10] + " expected=95") + out("pair2-d[10]=" + d[10] + " expected=41") + out("pair2-c[20]=" + c[20] + " expected=3070") + out("pair2-d[20]=" + d[20] + " expected=1367") + out("elapsed_ms=" + (timer() - start)) +} diff --git a/src/main/java/cod/interpreter/handler/LoopOptimizationHandler.java b/src/main/java/cod/interpreter/handler/LoopOptimizationHandler.java index 06916039..3b269601 100644 --- a/src/main/java/cod/interpreter/handler/LoopOptimizationHandler.java +++ b/src/main/java/cod/interpreter/handler/LoopOptimizationHandler.java @@ -22,6 +22,8 @@ public class LoopOptimizationHandler { private static final int LAZY_THRESHOLD = 10; private static final int MAX_SUPPORTED_LAG = 64; + private static final int MIN_VECTOR_SEQUENCES = 2; + private static final int MAX_VECTOR_SEQUENCES = 64; private final InterpreterVisitor dispatcher; private final TypeHandler typeSystem; @@ -218,6 +220,17 @@ public Object tryOptimizedExecution(For node, int loopId) { } } + List vectorRecurrencePatterns = extractVectorLinearRecurrencePatterns(node); + if (!vectorRecurrencePatterns.isEmpty()) { + try { + Object result = patternHandler.applyPatterns(node, vectorRecurrencePatterns); + ArrayTracker.markLoopOptimized(loopId); + return result; + } catch (Exception e) { + DebugSystem.debug("OPTIMIZER", "Vector recurrence pattern failed: " + e.getMessage()); + } + } + List multiArrayPatterns = extractMultiArraySequencePatterns(node); if (!multiArrayPatterns.isEmpty()) { try { @@ -282,6 +295,155 @@ public Object tryOptimizedExecution(For node, int loopId) { return null; } + public List extractVectorLinearRecurrencePatterns(For node) { + List results = new ArrayList(); + if (node == null || node.body == null || node.body.statements == null) { + return results; + } + + List assignments = collectVectorAssignments(node); + if (assignments.size() < MIN_VECTOR_SEQUENCES || assignments.size() > MAX_VECTOR_SEQUENCES) { + return results; + } + + List orderedTargets = new ArrayList(); + List targetExprs = new ArrayList(); + List targetArrays = new ArrayList(); + Map targetIndexByName = new LinkedHashMap(); + + for (Assignment assignment : assignments) { + IndexAccess leftAccess = (IndexAccess) assignment.left; + Identifier targetId = (Identifier) leftAccess.array; + String targetName = targetId.name; + if (targetIndexByName.containsKey(targetName)) { + return new ArrayList(); + } + Object resolved = dispatcher.dispatch(targetId); + resolved = typeSystem.unwrap(resolved); + if (!(resolved instanceof NaturalArray)) { + return new ArrayList(); + } + orderedTargets.add(targetName); + targetExprs.add(targetId); + targetArrays.add((NaturalArray) resolved); + targetIndexByName.put(targetName, targetIndexByName.size()); + } + + long expectedSize = targetArrays.get(0).size(); + for (int i = 1; i < targetArrays.size(); i++) { + if (targetArrays.get(i).size() != expectedSize) { + return new ArrayList(); + } + } + + int dimension = assignments.size(); + // coeffByLag[targetRow][lag][sourceColumn] + AutoStackingNumber[][][] coeffByLag = new AutoStackingNumber[dimension][MAX_SUPPORTED_LAG + 1][dimension]; + AutoStackingNumber[] constants = new AutoStackingNumber[dimension]; + for (int row = 0; row < dimension; row++) { + constants[row] = AutoStackingNumber.fromLong(0L); + for (int lag = 0; lag <= MAX_SUPPORTED_LAG; lag++) { + for (int col = 0; col < dimension; col++) { + coeffByLag[row][lag][col] = AutoStackingNumber.fromLong(0L); + } + } + } + + int maxLag = 0; + for (int row = 0; row < assignments.size(); row++) { + Assignment assign = assignments.get(row); + AutoStackingNumber[] constantRef = new AutoStackingNumber[]{AutoStackingNumber.fromLong(0L)}; + if (!collectVectorLinearTerms( + assign.right, + targetIndexByName, + node.iterator, + coeffByLag[row], + constantRef, + AutoStackingNumber.fromLong(1L))) { + return new ArrayList(); + } + constants[row] = constantRef[0]; + } + + for (int row = 0; row < dimension; row++) { + boolean hasDependency = false; + for (int lag = 1; lag <= MAX_SUPPORTED_LAG; lag++) { + for (int col = 0; col < dimension; col++) { + if (!coeffByLag[row][lag][col].isZero()) { + hasDependency = true; + if (lag > maxLag) maxLag = lag; + } + } + } + if (!hasDependency) { + return new ArrayList(); + } + } + if (maxLag <= 0) { + return new ArrayList(); + } + + long[] bounds = resolveLoopBounds(node); + if (bounds == null) { + return new ArrayList(); + } + long min = bounds[0]; + long max = bounds[1]; + long recurrenceStart = min; + if (recurrenceStart < maxLag) { + recurrenceStart = maxLag; + } + if (recurrenceStart > max) { + return new ArrayList(); + } + + long seedStart = recurrenceStart - maxLag; + AutoStackingNumber[][] seedValues = new AutoStackingNumber[dimension][maxLag]; + for (int seq = 0; seq < dimension; seq++) { + NaturalArray arr = targetArrays.get(seq); + for (int offset = 0; offset < maxLag; offset++) { + long seedIndex = seedStart + offset; + Object seedObj = arr.get(seedIndex); + AutoStackingNumber seedNum = typeSystem.toAutoStackingNumber(seedObj); + if (seedNum == null) { + return new ArrayList(); + } + seedValues[seq][offset] = seedNum; + } + } + + AutoStackingNumber[][] flatCoefficients = new AutoStackingNumber[dimension][dimension * maxLag]; + for (int row = 0; row < dimension; row++) { + for (int lag = 1; lag <= maxLag; lag++) { + for (int col = 0; col < dimension; col++) { + int flatCol = ((lag - 1) * dimension) + col; + flatCoefficients[row][flatCol] = coeffByLag[row][lag][col]; + } + } + } + + PatternHandler.VectorRecurrencePattern pattern = new PatternHandler.VectorRecurrencePattern( + targetExprs, + dimension, + maxLag, + flatCoefficients, + constants, + recurrenceStart, + seedStart, + seedValues, + targetIndexByName + ); + + for (Expr targetExpr : targetExprs) { + results.add(new PatternHandler.PatternResult( + PatternHandler.PatternType.VECTOR_LINEAR_RECURRENCE, + pattern, + targetExpr + )); + } + return results; + } + public PatternHandler.LinearRecurrencePattern extractLinearRecurrencePattern(For node) { if (node == null || node.body == null || node.body.statements == null) { return null; @@ -367,9 +529,9 @@ public PatternHandler.LinearRecurrencePattern extractLinearRecurrencePattern(For AutoStackingNumber[] seed = new AutoStackingNumber[order]; long seedStart = recurrenceStart - order; for (int i = 0; i < order; i++) { - long idxSeed = seedStart + i; - Object vObj = targetArray.get(idxSeed); - AutoStackingNumber v = typeSystem.toAutoStackingNumber(vObj); + long seedIndex = seedStart + i; + Object seedValue = targetArray.get(seedIndex); + AutoStackingNumber v = typeSystem.toAutoStackingNumber(seedValue); if (v == null) { return null; } @@ -444,6 +606,15 @@ private static class TermRef { TermRef(int lag) { this.lag = lag; } } + private static class VectorTermRef { + final int lag; + final int sequenceIndex; + VectorTermRef(int lag, int sequenceIndex) { + this.lag = lag; + this.sequenceIndex = sequenceIndex; + } + } + private TermRef extractIndexedTargetTerm(Expr expr, String targetArrayName, String iterator) { if (!(expr instanceof IndexAccess)) { return null; @@ -463,6 +634,111 @@ private TermRef extractIndexedTargetTerm(Expr expr, String targetArrayName, Stri return new TermRef(lag); } + private List collectVectorAssignments(For node) { + List assignments = new ArrayList(); + if (node == null || node.body == null || node.body.statements == null) { + return assignments; + } + for (Stmt stmt : node.body.statements) { + if (!(stmt instanceof Assignment)) { + return new ArrayList(); + } + Assignment assign = (Assignment) stmt; + if (assign.isDeclaration || !(assign.left instanceof IndexAccess)) { + return new ArrayList(); + } + IndexAccess access = (IndexAccess) assign.left; + if (!(access.array instanceof Identifier) || !(access.index instanceof Identifier)) { + return new ArrayList(); + } + Identifier idx = (Identifier) access.index; + if (!node.iterator.equals(idx.name)) { + return new ArrayList(); + } + assignments.add(assign); + } + return assignments; + } + + private boolean collectVectorLinearTerms( + Expr expr, + Map targetIndexByName, + String iterator, + AutoStackingNumber[][] coefficientsByLagAndSequence, + AutoStackingNumber[] constant, + AutoStackingNumber sign + ) { + if (expr == null) return false; + + if (expr instanceof BinaryOp) { + BinaryOp bin = (BinaryOp) expr; + if ("+".equals(bin.op)) { + return collectVectorLinearTerms(bin.left, targetIndexByName, iterator, coefficientsByLagAndSequence, constant, sign) && + collectVectorLinearTerms(bin.right, targetIndexByName, iterator, coefficientsByLagAndSequence, constant, sign); + } + if ("-".equals(bin.op)) { + return collectVectorLinearTerms(bin.left, targetIndexByName, iterator, coefficientsByLagAndSequence, constant, sign) && + collectVectorLinearTerms(bin.right, targetIndexByName, iterator, coefficientsByLagAndSequence, constant, + sign.multiply(AutoStackingNumber.fromLong(-1L))); + } + if ("*".equals(bin.op)) { + VectorTermRef ref = extractIndexedVectorTerm(bin.left, targetIndexByName, iterator); + AutoStackingNumber scalar = toNumericLiteral(bin.right); + if (ref == null || scalar == null) { + ref = extractIndexedVectorTerm(bin.right, targetIndexByName, iterator); + scalar = toNumericLiteral(bin.left); + } + if (ref != null && scalar != null) { + AutoStackingNumber delta = sign.multiply(scalar); + coefficientsByLagAndSequence[ref.lag][ref.sequenceIndex] = + coefficientsByLagAndSequence[ref.lag][ref.sequenceIndex].add(delta); + return true; + } + return false; + } + return false; + } + + VectorTermRef ref = extractIndexedVectorTerm(expr, targetIndexByName, iterator); + if (ref != null) { + coefficientsByLagAndSequence[ref.lag][ref.sequenceIndex] = + coefficientsByLagAndSequence[ref.lag][ref.sequenceIndex].add(sign); + return true; + } + + AutoStackingNumber literal = toNumericLiteral(expr); + if (literal != null) { + constant[0] = constant[0].add(sign.multiply(literal)); + return true; + } + + return false; + } + + private VectorTermRef extractIndexedVectorTerm( + Expr expr, + Map targetIndexByName, + String iterator + ) { + if (!(expr instanceof IndexAccess)) { + return null; + } + IndexAccess access = (IndexAccess) expr; + if (!(access.array instanceof Identifier)) { + return null; + } + String arrayName = ((Identifier) access.array).name; + Integer sequenceIndex = targetIndexByName.get(arrayName); + if (sequenceIndex == null) { + return null; + } + int lag = extractLag(access.index, iterator); + if (lag <= 0 || lag > MAX_SUPPORTED_LAG) { + return null; + } + return new VectorTermRef(lag, sequenceIndex.intValue()); + } + private int extractLag(Expr indexExpr, String iterator) { if (indexExpr instanceof BinaryOp) { BinaryOp bin = (BinaryOp) indexExpr; diff --git a/src/main/java/cod/interpreter/handler/PatternHandler.java b/src/main/java/cod/interpreter/handler/PatternHandler.java index 1d2c6a24..331dac45 100644 --- a/src/main/java/cod/interpreter/handler/PatternHandler.java +++ b/src/main/java/cod/interpreter/handler/PatternHandler.java @@ -10,6 +10,7 @@ import cod.range.formula.ConditionalFormula; import cod.range.formula.LinearRecurrenceFormula; import cod.range.formula.SequenceFormula; +import cod.range.formula.VectorRecurrenceFormula; import cod.range.pattern.ConditionalPattern; import cod.range.pattern.SequencePattern; @@ -19,7 +20,8 @@ public class PatternHandler { public enum PatternType { CONDITIONAL, SEQUENCE, - LINEAR_RECURRENCE + LINEAR_RECURRENCE, + VECTOR_LINEAR_RECURRENCE } public static class PatternResult { @@ -65,6 +67,40 @@ public LinearRecurrencePattern( } } + public static class VectorRecurrencePattern { + public final List targetArrays; + public final int dimension; + public final int order; + public final AutoStackingNumber[][] coefficients; + public final AutoStackingNumber[] constantTerms; + public final long recurrenceStart; + public final long seedStart; + public final AutoStackingNumber[][] seedValues; + public final Map targetIndexByName; + + public VectorRecurrencePattern( + List targetArrays, + int dimension, + int order, + AutoStackingNumber[][] coefficients, + AutoStackingNumber[] constantTerms, + long recurrenceStart, + long seedStart, + AutoStackingNumber[][] seedValues, + Map targetIndexByName + ) { + this.targetArrays = targetArrays; + this.dimension = dimension; + this.order = order; + this.coefficients = coefficients; + this.constantTerms = constantTerms; + this.recurrenceStart = recurrenceStart; + this.seedStart = seedStart; + this.seedValues = seedValues; + this.targetIndexByName = targetIndexByName; + } + } + private final InterpreterVisitor dispatcher; private final TypeHandler typeSystem; private final ExpressionHandler expressionHandler; @@ -95,6 +131,10 @@ public Object applyPatterns(For node, List patterns) { } try { + if (isVectorRecurrencePatternSet(patterns)) { + return applyVectorRecurrencePatterns(node, patterns); + } + List targetArrays = new ArrayList(); List> groupedPatterns = new ArrayList>(); Map arrayIdToGroupIndex = new HashMap(); @@ -285,4 +325,103 @@ public void applyLinearRecurrencePattern( throw new InternalError("Failed to apply linear recurrence pattern", e); } } + + private boolean isVectorRecurrencePatternSet(List patterns) { + if (patterns == null || patterns.isEmpty()) return false; + for (PatternResult result : patterns) { + if (result == null || result.type != PatternType.VECTOR_LINEAR_RECURRENCE) { + return false; + } + } + return true; + } + + private Object applyVectorRecurrencePatterns(For node, List patterns) { + PatternResult first = patterns.get(0); + if (!(first.pattern instanceof VectorRecurrencePattern)) { + throw new InternalError("Invalid vector recurrence pattern payload"); + } + VectorRecurrencePattern pattern = (VectorRecurrencePattern) first.pattern; + + long start = 0L; + long end = 0L; + boolean boundsFound = false; + + if (node.range != null) { + Object startObj = dispatcher.dispatch(node.range.start); + Object endObj = dispatcher.dispatch(node.range.end); + start = expressionHandler.toLong(startObj); + end = expressionHandler.toLong(endObj); + boundsFound = true; + } else if (node.arraySource != null) { + Object sourceObj = dispatcher.dispatch(node.arraySource); + sourceObj = typeSystem.unwrap(sourceObj); + if (sourceObj instanceof NaturalArray) { + NaturalArray sourceArr = (NaturalArray) sourceObj; + if (sourceArr.size() > 0) { + start = 0L; + end = sourceArr.size() - 1L; + boundsFound = true; + } + } else if (sourceObj instanceof List) { + List sourceList = (List) sourceObj; + if (!sourceList.isEmpty()) { + start = 0L; + end = sourceList.size() - 1L; + boundsFound = true; + } + } + } + + if (!boundsFound) { + DebugSystem.debug("OPTIMIZER", "Vector recurrence: unable to resolve loop bounds"); + return arrayOperationHandler.executeForLoopNormally(node); + } + + long min = Math.min(start, end); + long max = Math.max(start, end); + long formulaStart = Math.max(min, pattern.seedStart); + long formulaEnd = max; + if (formulaEnd < formulaStart) { + return arrayOperationHandler.executeForLoopNormally(node); + } + + VectorRecurrenceFormula formula = new VectorRecurrenceFormula( + formulaStart, + formulaEnd, + pattern.recurrenceStart, + pattern.seedStart, + pattern.dimension, + pattern.order, + pattern.coefficients, + pattern.constantTerms, + pattern.seedValues + ); + + List attachedArrays = new ArrayList(); + for (Expr targetExpr : pattern.targetArrays) { + Object resolvedArray = dispatcher.dispatch(targetExpr); + resolvedArray = typeSystem.unwrap(resolvedArray); + if (!(resolvedArray instanceof NaturalArray)) { + DebugSystem.debug("OPTIMIZER", "Vector recurrence target not NaturalArray; fallback"); + return arrayOperationHandler.executeForLoopNormally(node); + } + NaturalArray arr = (NaturalArray) resolvedArray; + if (!(targetExpr instanceof Identifier)) { + return arrayOperationHandler.executeForLoopNormally(node); + } + String name = ((Identifier) targetExpr).name; + Integer seqIndex = pattern.targetIndexByName.get(name); + if (seqIndex == null) { + return arrayOperationHandler.executeForLoopNormally(node); + } + arr.addVectorRecurrenceFormula(formula, seqIndex.intValue()); + attachedArrays.add(arr); + } + + if (attachedArrays.isEmpty()) { + return arrayOperationHandler.executeForLoopNormally(node); + } + return attachedArrays.get(attachedArrays.size() - 1); + } } diff --git a/src/main/java/cod/range/NaturalArray.java b/src/main/java/cod/range/NaturalArray.java index 91519054..bc572ab4 100644 --- a/src/main/java/cod/range/NaturalArray.java +++ b/src/main/java/cod/range/NaturalArray.java @@ -63,6 +63,7 @@ public class NaturalArray { private List sequenceFormulas = new ArrayList(); private List conditionalFormulas = new ArrayList(); private List linearRecurrenceFormulas = new ArrayList(); + private List vectorRecurrenceFormulas = new ArrayList(); private Map computedCache = new HashMap(); // Pending updates for lazy assignment @@ -236,6 +237,16 @@ public int compareTo(PendingRangeUpdate other) { } } + private static class VectorRecurrenceBinding { + final VectorRecurrenceFormula formula; + final int sequenceIndex; + + VectorRecurrenceBinding(VectorRecurrenceFormula formula, int sequenceIndex) { + this.formula = formula; + this.sequenceIndex = sequenceIndex; + } + } + // ========== CONSTRUCTORS ========== public NaturalArray(Range range, Evaluator evaluator, ExecutionContext context) { @@ -832,6 +843,15 @@ public Object get(long index) { } // Then linear recurrence formulas + Object vectorRecurrenceResult = evaluateVectorRecurrenceFormulas(index); + if (vectorRecurrenceResult != null) { + lastIndex = index; + lastValue = vectorRecurrenceResult; + updateRecentCache(index, vectorRecurrenceResult); + return maybeConvert(vectorRecurrenceResult); + } + + // Then scalar linear recurrence formulas Object recurrenceResult = evaluateLinearRecurrenceFormulas(index); if (recurrenceResult != null) { lastIndex = index; @@ -1437,6 +1457,22 @@ public void addLinearRecurrenceFormula(LinearRecurrenceFormula formula) { clearCache(); } + public void addVectorRecurrenceFormula(VectorRecurrenceFormula formula, int sequenceIndex) { + if (formula == null) { + throw new InternalError("Attempted to add null VectorRecurrenceFormula"); + } + if (sequenceIndex < 0 || sequenceIndex >= formula.dimension) { + throw new ProgramError("Invalid vector recurrence sequence index: " + sequenceIndex); + } + + if (tracked) { + ArrayTracker.recordFormulaApplication(this); + } + + vectorRecurrenceFormulas.add(new VectorRecurrenceBinding(formula, sequenceIndex)); + clearCache(); + } + public void clearCache() { if (computedCache != null) { computedCache.clear(); @@ -1537,6 +1573,35 @@ private Object evaluateLinearRecurrenceFormulas(long index) { return null; } + private Object evaluateVectorRecurrenceFormulas(long index) { + if (vectorRecurrenceFormulas.isEmpty()) return null; + + for (int i = vectorRecurrenceFormulas.size() - 1; i >= 0; i--) { + VectorRecurrenceBinding binding = vectorRecurrenceFormulas.get(i); + if (binding == null || binding.formula == null) { + throw new InternalError("Null VectorRecurrenceFormula binding in list"); + } + if (binding.formula.contains(index)) { + try { + Object result = binding.formula.evaluate(index, binding.sequenceIndex); + if (result != null) { + if (computedCache == null) { + computedCache = new HashMap(); + } + computedCache.put(index, result); + } + return result; + } catch (ProgramError e) { + throw e; + } catch (Exception e) { + throw new InternalError( + "Vector recurrence formula evaluation failed at index " + index, e); + } + } + } + return null; + } + // ========== OUTPUT CACHING METHODS ========== public void recordOutput(long index, Object value) { @@ -1694,6 +1759,9 @@ public String toString() { if (!linearRecurrenceFormulas.isEmpty()) { sb.append("\n Linear recurrence formulas: ").append(linearRecurrenceFormulas.size()); } + if (!vectorRecurrenceFormulas.isEmpty()) { + sb.append("\n Vector recurrence formulas: ").append(vectorRecurrenceFormulas.size()); + } return sb.toString(); } catch (ProgramError e) { diff --git a/src/main/java/cod/range/formula/LinearRecurrenceFormula.java b/src/main/java/cod/range/formula/LinearRecurrenceFormula.java index d0bc02be..bfe21854 100644 --- a/src/main/java/cod/range/formula/LinearRecurrenceFormula.java +++ b/src/main/java/cod/range/formula/LinearRecurrenceFormula.java @@ -1,6 +1,7 @@ package cod.range.formula; import cod.math.AutoStackingNumber; +import java.util.Arrays; public class LinearRecurrenceFormula { public final long start; @@ -14,6 +15,8 @@ public class LinearRecurrenceFormula { private final boolean hasConstantTerm; private final LinearRecurrenceFormula newerFormula; private final LinearRecurrenceFormula olderFormula; + private transient long rollingIndex = Long.MIN_VALUE; + private transient AutoStackingNumber[] rollingState = null; private static final AutoStackingNumber ZERO = AutoStackingNumber.fromLong(0L); private static final AutoStackingNumber ONE = AutoStackingNumber.fromLong(1L); @@ -37,6 +40,7 @@ public LinearRecurrenceFormula( this.hasConstantTerm = !this.constantTerm.isZero(); this.newerFormula = null; this.olderFormula = null; + resetRollingState(); } private LinearRecurrenceFormula(long start, long end, @@ -53,6 +57,7 @@ private LinearRecurrenceFormula(long start, long end, this.hasConstantTerm = false; this.newerFormula = newerFormula; this.olderFormula = olderFormula; + resetRollingState(); } public static LinearRecurrenceFormula compose(LinearRecurrenceFormula newerFormula, LinearRecurrenceFormula olderFormula) { @@ -94,10 +99,38 @@ public Object evaluate(long index) { return null; } + synchronized (this) { + if (rollingState != null && index == rollingIndex) { + return rollingState[0]; + } + if (rollingState != null && index == rollingIndex + 1L) { + advanceRollingState(); + rollingIndex = index; + return rollingState[0]; + } + } + long lastSeedIndex = recurrenceStart - 1L; long steps = index - lastSeedIndex; int dim = hasConstantTerm ? order + 1 : order; + AutoStackingNumber[][] transition = buildTransition(dim); + AutoStackingNumber[] state = buildBaseState(dim); + + AutoStackingNumber[][] power = matrixPow(transition, steps); + AutoStackingNumber[] result = multiply(power, state); + synchronized (this) { + rollingState = copyState(result); + rollingIndex = index; + } + return result[0]; + } + + private boolean isComposite() { + return newerFormula != null || olderFormula != null; + } + + private AutoStackingNumber[][] buildTransition(int dim) { AutoStackingNumber[][] transition = new AutoStackingNumber[dim][dim]; for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) { @@ -118,6 +151,10 @@ public Object evaluate(long index) { transition[last][last] = ONE; } + return transition; + } + + private AutoStackingNumber[] buildBaseState(int dim) { AutoStackingNumber[] state = new AutoStackingNumber[dim]; for (int j = 0; j < order; j++) { state[j] = seedValues[order - 1 - j]; @@ -125,14 +162,37 @@ public Object evaluate(long index) { if (hasConstantTerm) { state[dim - 1] = ONE; } + return state; + } - AutoStackingNumber[][] power = matrixPow(transition, steps); - AutoStackingNumber[] result = multiply(power, state); - return result[0]; + private AutoStackingNumber[] copyState(AutoStackingNumber[] state) { + return Arrays.copyOf(state, state.length); } - private boolean isComposite() { - return newerFormula != null || olderFormula != null; + private void advanceRollingState() { + AutoStackingNumber next = hasConstantTerm ? constantTerm : ZERO; + for (int lag = 1; lag <= order; lag++) { + AutoStackingNumber coeff = coefficientsByLag[lag - 1]; + if (coeff != null && !coeff.isZero()) { + next = next.add(coeff.multiply(rollingState[lag - 1])); + } + } + + AutoStackingNumber[] nextState = new AutoStackingNumber[rollingState.length]; + nextState[0] = next; + for (int i = 1; i < order; i++) { + nextState[i] = rollingState[i - 1]; + } + + if (hasConstantTerm) { + nextState[nextState.length - 1] = ONE; + } + rollingState = nextState; + } + + private void resetRollingState() { + rollingIndex = Long.MIN_VALUE; + rollingState = null; } private AutoStackingNumber[][] matrixPow(AutoStackingNumber[][] base, long exp) { diff --git a/src/main/java/cod/range/formula/VectorRecurrenceFormula.java b/src/main/java/cod/range/formula/VectorRecurrenceFormula.java new file mode 100644 index 00000000..b91d9c25 --- /dev/null +++ b/src/main/java/cod/range/formula/VectorRecurrenceFormula.java @@ -0,0 +1,269 @@ +package cod.range.formula; + +import cod.math.AutoStackingNumber; + +import java.util.Arrays; + +public class VectorRecurrenceFormula { + public final long start; + public final long end; + public final long recurrenceStart; + public final long seedStartIndex; + public final int dimension; + public final int order; + public final AutoStackingNumber[][] coefficients; + public final AutoStackingNumber[] constant; + public final AutoStackingNumber[][] seedValues; + private final boolean hasConstantTerm; + private transient long rollingIndex = Long.MIN_VALUE; + private transient AutoStackingNumber[] rollingState = null; + + private static final AutoStackingNumber ZERO = AutoStackingNumber.fromLong(0L); + private static final AutoStackingNumber ONE = AutoStackingNumber.fromLong(1L); + + public VectorRecurrenceFormula( + long start, + long end, + long recurrenceStart, + long seedStartIndex, + int dimension, + int order, + AutoStackingNumber[][] coefficients, + AutoStackingNumber[] constant, + AutoStackingNumber[][] seedValues + ) { + this.start = start; + this.end = end; + this.recurrenceStart = recurrenceStart; + this.seedStartIndex = seedStartIndex; + this.dimension = dimension; + this.order = order; + this.coefficients = coefficients; + this.constant = constant != null ? constant : zerosVector(dimension); + this.seedValues = seedValues; + this.hasConstantTerm = hasNonZeroConstant(this.constant); + resetRollingState(); + } + + public boolean contains(long index) { + return index >= start && index <= end; + } + + public synchronized Object evaluate(long index, int sequenceIndex) { + if (sequenceIndex < 0 || sequenceIndex >= dimension) { + return null; + } + if (order <= 0 || dimension <= 0 || coefficients == null || seedValues == null) { + return null; + } + if (seedValues.length != dimension) { + return null; + } + for (int i = 0; i < dimension; i++) { + if (seedValues[i] == null || seedValues[i].length != order) { + return null; + } + } + + if (index < recurrenceStart) { + Integer seedOffset = validateOffsetInIntRange(index - seedStartIndex); + if (seedOffset == null) { + return null; + } + if (seedOffset.intValue() >= order) { + return null; + } + return seedValues[sequenceIndex][seedOffset.intValue()]; + } + + if (rollingState != null && index == rollingIndex) { + return rollingState[sequenceIndex]; + } + if (rollingState != null && index == rollingIndex + 1L) { + advanceRollingState(); + rollingIndex = index; + return rollingState[sequenceIndex]; + } + + long lastSeedIndex = recurrenceStart - 1L; + long steps = index - lastSeedIndex; + int baseDim = dimension * order; + int matrixDim = hasConstantTerm ? baseDim + 1 : baseDim; + + AutoStackingNumber[][] transition = buildTransition(baseDim, matrixDim); + AutoStackingNumber[] state = buildBaseState(baseDim, matrixDim); + AutoStackingNumber[][] power = matrixPow(transition, steps); + AutoStackingNumber[] result = multiply(power, state); + + rollingState = Arrays.copyOf(result, baseDim); + rollingIndex = index; + return rollingState[sequenceIndex]; + } + + private AutoStackingNumber[][] buildTransition(int baseDim, int matrixDim) { + AutoStackingNumber[][] transition = new AutoStackingNumber[matrixDim][matrixDim]; + for (int i = 0; i < matrixDim; i++) { + for (int j = 0; j < matrixDim; j++) { + transition[i][j] = ZERO; + } + } + + for (int row = 0; row < dimension; row++) { + AutoStackingNumber[] coeffRow = coefficients[row]; + if (coeffRow == null || coeffRow.length != baseDim) { + continue; + } + for (int col = 0; col < baseDim; col++) { + AutoStackingNumber c = coeffRow[col]; + transition[row][col] = c != null ? c : ZERO; + } + } + + for (int block = 1; block < order; block++) { + for (int seq = 0; seq < dimension; seq++) { + int row = (block * dimension) + seq; + int col = ((block - 1) * dimension) + seq; + transition[row][col] = ONE; + } + } + + if (hasConstantTerm) { + int constCol = matrixDim - 1; + for (int row = 0; row < dimension; row++) { + AutoStackingNumber c = constant[row]; + transition[row][constCol] = c != null ? c : ZERO; + } + transition[constCol][constCol] = ONE; + } + return transition; + } + + private AutoStackingNumber[] buildBaseState(int baseDim, int matrixDim) { + AutoStackingNumber[] state = new AutoStackingNumber[matrixDim]; + for (int i = 0; i < matrixDim; i++) { + state[i] = ZERO; + } + for (int block = 0; block < order; block++) { + long sourceIndex = (recurrenceStart - 1L) - block; + Integer seedOffset = validateOffsetInIntRange(sourceIndex - seedStartIndex); + if (seedOffset == null) { + return null; + } + for (int seq = 0; seq < dimension; seq++) { + state[(block * dimension) + seq] = seedValues[seq][seedOffset.intValue()]; + } + } + if (hasConstantTerm) { + state[matrixDim - 1] = ONE; + } + return state; + } + + private void advanceRollingState() { + int baseDim = dimension * order; + AutoStackingNumber[] nextState = new AutoStackingNumber[baseDim]; + + for (int row = 0; row < dimension; row++) { + AutoStackingNumber sum = constant[row] != null ? constant[row] : ZERO; + AutoStackingNumber[] coeffRow = coefficients[row]; + for (int col = 0; col < baseDim; col++) { + AutoStackingNumber c = coeffRow[col]; + if (c != null && !c.isZero()) { + sum = sum.add(c.multiply(rollingState[col])); + } + } + nextState[row] = sum; + } + + for (int block = 1; block < order; block++) { + for (int seq = 0; seq < dimension; seq++) { + nextState[(block * dimension) + seq] = rollingState[((block - 1) * dimension) + seq]; + } + } + + rollingState = nextState; + } + + private AutoStackingNumber[][] matrixPow(AutoStackingNumber[][] base, long exp) { + int dim = base.length; + AutoStackingNumber[][] result = identity(dim); + AutoStackingNumber[][] current = base; + long e = exp; + while (e > 0) { + if ((e & 1L) == 1L) { + result = multiply(result, current); + } + e >>= 1; + if (e > 0) { + current = multiply(current, current); + } + } + return result; + } + + private AutoStackingNumber[][] identity(int dim) { + AutoStackingNumber[][] id = new AutoStackingNumber[dim][dim]; + for (int i = 0; i < dim; i++) { + for (int j = 0; j < dim; j++) { + id[i][j] = (i == j) ? ONE : ZERO; + } + } + return id; + } + + private AutoStackingNumber[][] multiply(AutoStackingNumber[][] a, AutoStackingNumber[][] b) { + int n = a.length; + AutoStackingNumber[][] out = new AutoStackingNumber[n][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + AutoStackingNumber sum = ZERO; + for (int k = 0; k < n; k++) { + sum = sum.add(a[i][k].multiply(b[k][j])); + } + out[i][j] = sum; + } + } + return out; + } + + private AutoStackingNumber[] multiply(AutoStackingNumber[][] a, AutoStackingNumber[] v) { + int n = a.length; + AutoStackingNumber[] out = new AutoStackingNumber[n]; + for (int i = 0; i < n; i++) { + AutoStackingNumber sum = ZERO; + for (int k = 0; k < n; k++) { + sum = sum.add(a[i][k].multiply(v[k])); + } + out[i] = sum; + } + return out; + } + + private static boolean hasNonZeroConstant(AutoStackingNumber[] constantVector) { + if (constantVector == null) return false; + for (AutoStackingNumber n : constantVector) { + if (n != null && !n.isZero()) return true; + } + return false; + } + + private static AutoStackingNumber[] zerosVector(int length) { + AutoStackingNumber[] out = new AutoStackingNumber[length]; + for (int i = 0; i < length; i++) { + out[i] = ZERO; + } + return out; + } + + private Integer validateOffsetInIntRange(long value) { + if (value < 0L || value > Integer.MAX_VALUE) { + return null; + } + return Integer.valueOf((int) value); + } + + private void resetRollingState() { + rollingIndex = Long.MIN_VALUE; + rollingState = null; + } +}