Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified source_.jar
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
unit test.linearrecurrenceoptimization

share main() {
out("=== Vector linear recurrence optimization ===")
start := timer()

a := [0 to 200]
b := [0 to 200]

a[0] = 1
a[1] = 2
b[0] = 0
b[1] = 1

for i of [2 to 200] {
a[i] = a[i-1] + b[i-1]
b[i] = a[i-1] - b[i-1]
}

out("pair1-a[10]=" + a[10] + " expected=48")
out("pair1-b[10]=" + b[10] + " expected=16")
out("pair1-a[20]=" + a[20] + " expected=1536")
out("pair1-b[20]=" + b[20] + " expected=512")

c := [0 to 200]
d := [0 to 200]

c[0] = 1
c[1] = 2
d[0] = 3
d[1] = 4

for i of [2 to 200] {
c[i] = c[i-1] + d[i-2]
d[i] = 2 * c[i-2] - d[i-1] + 3
}

out("pair2-c[10]=" + c[10] + " expected=95")
out("pair2-d[10]=" + d[10] + " expected=41")
out("pair2-c[20]=" + c[20] + " expected=3070")
out("pair2-d[20]=" + d[20] + " expected=1367")
out("elapsed_ms=" + (timer() - start))
}
282 changes: 279 additions & 3 deletions src/main/java/cod/interpreter/handler/LoopOptimizationHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
public class LoopOptimizationHandler {
private static final int LAZY_THRESHOLD = 10;
private static final int MAX_SUPPORTED_LAG = 64;
private static final int MIN_VECTOR_SEQUENCES = 2;
private static final int MAX_VECTOR_SEQUENCES = 64;

private final InterpreterVisitor dispatcher;
private final TypeHandler typeSystem;
Expand Down Expand Up @@ -218,6 +220,17 @@ public Object tryOptimizedExecution(For node, int loopId) {
}
}

List<PatternHandler.PatternResult> vectorRecurrencePatterns = extractVectorLinearRecurrencePatterns(node);
if (!vectorRecurrencePatterns.isEmpty()) {
try {
Object result = patternHandler.applyPatterns(node, vectorRecurrencePatterns);
ArrayTracker.markLoopOptimized(loopId);
return result;
} catch (Exception e) {
DebugSystem.debug("OPTIMIZER", "Vector recurrence pattern failed: " + e.getMessage());
}
}

List<PatternHandler.PatternResult> multiArrayPatterns = extractMultiArraySequencePatterns(node);
if (!multiArrayPatterns.isEmpty()) {
try {
Expand Down Expand Up @@ -282,6 +295,155 @@ public Object tryOptimizedExecution(For node, int loopId) {
return null;
}

public List<PatternHandler.PatternResult> extractVectorLinearRecurrencePatterns(For node) {
List<PatternHandler.PatternResult> results = new ArrayList<PatternHandler.PatternResult>();
if (node == null || node.body == null || node.body.statements == null) {
return results;
}

List<Assignment> assignments = collectVectorAssignments(node);
if (assignments.size() < MIN_VECTOR_SEQUENCES || assignments.size() > MAX_VECTOR_SEQUENCES) {
return results;
}

List<String> orderedTargets = new ArrayList<String>();
List<Expr> targetExprs = new ArrayList<Expr>();
List<NaturalArray> targetArrays = new ArrayList<NaturalArray>();
Map<String, Integer> targetIndexByName = new LinkedHashMap<String, Integer>();

for (Assignment assignment : assignments) {
IndexAccess leftAccess = (IndexAccess) assignment.left;
Identifier targetId = (Identifier) leftAccess.array;
String targetName = targetId.name;
if (targetIndexByName.containsKey(targetName)) {
return new ArrayList<PatternHandler.PatternResult>();
}
Object resolved = dispatcher.dispatch(targetId);
resolved = typeSystem.unwrap(resolved);
if (!(resolved instanceof NaturalArray)) {
return new ArrayList<PatternHandler.PatternResult>();
}
orderedTargets.add(targetName);
targetExprs.add(targetId);
targetArrays.add((NaturalArray) resolved);
targetIndexByName.put(targetName, targetIndexByName.size());
}

long expectedSize = targetArrays.get(0).size();
for (int i = 1; i < targetArrays.size(); i++) {
if (targetArrays.get(i).size() != expectedSize) {
return new ArrayList<PatternHandler.PatternResult>();
}
}

int dimension = assignments.size();
// coeffByLag[targetRow][lag][sourceColumn]
AutoStackingNumber[][][] coeffByLag = new AutoStackingNumber[dimension][MAX_SUPPORTED_LAG + 1][dimension];
AutoStackingNumber[] constants = new AutoStackingNumber[dimension];
for (int row = 0; row < dimension; row++) {
constants[row] = AutoStackingNumber.fromLong(0L);
for (int lag = 0; lag <= MAX_SUPPORTED_LAG; lag++) {
for (int col = 0; col < dimension; col++) {
coeffByLag[row][lag][col] = AutoStackingNumber.fromLong(0L);
}
}
}

int maxLag = 0;
for (int row = 0; row < assignments.size(); row++) {
Assignment assign = assignments.get(row);
AutoStackingNumber[] constantRef = new AutoStackingNumber[]{AutoStackingNumber.fromLong(0L)};
if (!collectVectorLinearTerms(
assign.right,
targetIndexByName,
node.iterator,
coeffByLag[row],
constantRef,
AutoStackingNumber.fromLong(1L))) {
return new ArrayList<PatternHandler.PatternResult>();
}
constants[row] = constantRef[0];
}

for (int row = 0; row < dimension; row++) {
boolean hasDependency = false;
for (int lag = 1; lag <= MAX_SUPPORTED_LAG; lag++) {
for (int col = 0; col < dimension; col++) {
if (!coeffByLag[row][lag][col].isZero()) {
hasDependency = true;
if (lag > maxLag) maxLag = lag;
}
}
}
if (!hasDependency) {
return new ArrayList<PatternHandler.PatternResult>();
}
}
if (maxLag <= 0) {
return new ArrayList<PatternHandler.PatternResult>();
}

long[] bounds = resolveLoopBounds(node);
if (bounds == null) {
return new ArrayList<PatternHandler.PatternResult>();
}
long min = bounds[0];
long max = bounds[1];
long recurrenceStart = min;
if (recurrenceStart < maxLag) {
recurrenceStart = maxLag;
}
if (recurrenceStart > max) {
return new ArrayList<PatternHandler.PatternResult>();
}

long seedStart = recurrenceStart - maxLag;
AutoStackingNumber[][] seedValues = new AutoStackingNumber[dimension][maxLag];
for (int seq = 0; seq < dimension; seq++) {
NaturalArray arr = targetArrays.get(seq);
for (int offset = 0; offset < maxLag; offset++) {
long seedIndex = seedStart + offset;
Object seedObj = arr.get(seedIndex);
AutoStackingNumber seedNum = typeSystem.toAutoStackingNumber(seedObj);
if (seedNum == null) {
return new ArrayList<PatternHandler.PatternResult>();
}
seedValues[seq][offset] = seedNum;
}
}

AutoStackingNumber[][] flatCoefficients = new AutoStackingNumber[dimension][dimension * maxLag];
for (int row = 0; row < dimension; row++) {
for (int lag = 1; lag <= maxLag; lag++) {
for (int col = 0; col < dimension; col++) {
int flatCol = ((lag - 1) * dimension) + col;
flatCoefficients[row][flatCol] = coeffByLag[row][lag][col];
}
}
}

PatternHandler.VectorRecurrencePattern pattern = new PatternHandler.VectorRecurrencePattern(
targetExprs,
dimension,
maxLag,
flatCoefficients,
constants,
recurrenceStart,
seedStart,
seedValues,
targetIndexByName
);

for (Expr targetExpr : targetExprs) {
results.add(new PatternHandler.PatternResult(
PatternHandler.PatternType.VECTOR_LINEAR_RECURRENCE,
pattern,
targetExpr
));
}
return results;
}

public PatternHandler.LinearRecurrencePattern extractLinearRecurrencePattern(For node) {
if (node == null || node.body == null || node.body.statements == null) {
return null;
Expand Down Expand Up @@ -367,9 +529,9 @@ public PatternHandler.LinearRecurrencePattern extractLinearRecurrencePattern(For
AutoStackingNumber[] seed = new AutoStackingNumber[order];
long seedStart = recurrenceStart - order;
for (int i = 0; i < order; i++) {
long idxSeed = seedStart + i;
Object vObj = targetArray.get(idxSeed);
AutoStackingNumber v = typeSystem.toAutoStackingNumber(vObj);
long seedIndex = seedStart + i;
Object seedValue = targetArray.get(seedIndex);
AutoStackingNumber v = typeSystem.toAutoStackingNumber(seedValue);
if (v == null) {
return null;
}
Expand Down Expand Up @@ -444,6 +606,15 @@ private static class TermRef {
TermRef(int lag) { this.lag = lag; }
}

private static class VectorTermRef {
final int lag;
final int sequenceIndex;
VectorTermRef(int lag, int sequenceIndex) {
this.lag = lag;
this.sequenceIndex = sequenceIndex;
}
}

private TermRef extractIndexedTargetTerm(Expr expr, String targetArrayName, String iterator) {
if (!(expr instanceof IndexAccess)) {
return null;
Expand All @@ -463,6 +634,111 @@ private TermRef extractIndexedTargetTerm(Expr expr, String targetArrayName, Stri
return new TermRef(lag);
}

private List<Assignment> collectVectorAssignments(For node) {
List<Assignment> assignments = new ArrayList<Assignment>();
if (node == null || node.body == null || node.body.statements == null) {
return assignments;
}
for (Stmt stmt : node.body.statements) {
if (!(stmt instanceof Assignment)) {
return new ArrayList<Assignment>();
}
Assignment assign = (Assignment) stmt;
if (assign.isDeclaration || !(assign.left instanceof IndexAccess)) {
return new ArrayList<Assignment>();
}
IndexAccess access = (IndexAccess) assign.left;
if (!(access.array instanceof Identifier) || !(access.index instanceof Identifier)) {
return new ArrayList<Assignment>();
}
Identifier idx = (Identifier) access.index;
if (!node.iterator.equals(idx.name)) {
return new ArrayList<Assignment>();
}
assignments.add(assign);
}
return assignments;
}

private boolean collectVectorLinearTerms(
Expr expr,
Map<String, Integer> targetIndexByName,
String iterator,
AutoStackingNumber[][] coefficientsByLagAndSequence,
AutoStackingNumber[] constant,
AutoStackingNumber sign
) {
if (expr == null) return false;

if (expr instanceof BinaryOp) {
BinaryOp bin = (BinaryOp) expr;
if ("+".equals(bin.op)) {
return collectVectorLinearTerms(bin.left, targetIndexByName, iterator, coefficientsByLagAndSequence, constant, sign) &&
collectVectorLinearTerms(bin.right, targetIndexByName, iterator, coefficientsByLagAndSequence, constant, sign);
}
if ("-".equals(bin.op)) {
return collectVectorLinearTerms(bin.left, targetIndexByName, iterator, coefficientsByLagAndSequence, constant, sign) &&
collectVectorLinearTerms(bin.right, targetIndexByName, iterator, coefficientsByLagAndSequence, constant,
sign.multiply(AutoStackingNumber.fromLong(-1L)));
}
if ("*".equals(bin.op)) {
VectorTermRef ref = extractIndexedVectorTerm(bin.left, targetIndexByName, iterator);
AutoStackingNumber scalar = toNumericLiteral(bin.right);
if (ref == null || scalar == null) {
ref = extractIndexedVectorTerm(bin.right, targetIndexByName, iterator);
scalar = toNumericLiteral(bin.left);
}
if (ref != null && scalar != null) {
AutoStackingNumber delta = sign.multiply(scalar);
coefficientsByLagAndSequence[ref.lag][ref.sequenceIndex] =
coefficientsByLagAndSequence[ref.lag][ref.sequenceIndex].add(delta);
return true;
}
return false;
}
return false;
}

VectorTermRef ref = extractIndexedVectorTerm(expr, targetIndexByName, iterator);
if (ref != null) {
coefficientsByLagAndSequence[ref.lag][ref.sequenceIndex] =
coefficientsByLagAndSequence[ref.lag][ref.sequenceIndex].add(sign);
return true;
}

AutoStackingNumber literal = toNumericLiteral(expr);
if (literal != null) {
constant[0] = constant[0].add(sign.multiply(literal));
return true;
}

return false;
}

private VectorTermRef extractIndexedVectorTerm(
Expr expr,
Map<String, Integer> targetIndexByName,
String iterator
) {
if (!(expr instanceof IndexAccess)) {
return null;
}
IndexAccess access = (IndexAccess) expr;
if (!(access.array instanceof Identifier)) {
return null;
}
String arrayName = ((Identifier) access.array).name;
Integer sequenceIndex = targetIndexByName.get(arrayName);
if (sequenceIndex == null) {
return null;
}
int lag = extractLag(access.index, iterator);
if (lag <= 0 || lag > MAX_SUPPORTED_LAG) {
return null;
}
return new VectorTermRef(lag, sequenceIndex.intValue());
}

private int extractLag(Expr indexExpr, String iterator) {
if (indexExpr instanceof BinaryOp) {
BinaryOp bin = (BinaryOp) indexExpr;
Expand Down
Loading