Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 66 additions & 85 deletions arimo/compiler/backend/IRLower.arm
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ public class IRLower {
private isIntegerExpr(expr: Expr) : Boolean {
if (expr.kind == ExprKind.INT_LIT) { return true; }
if (expr.kind == ExprKind.BOOL_LIT) { return true; }
if (expr.kind == ExprKind.IDENT) {
String cls = this.varClassOf(expr.strVal);
return cls == "Integer" || cls == "Boolean";
}
if (expr.kind == ExprKind.FIELD) {
String cls = this.inferClass(expr);
return cls == "Integer" || cls == "Boolean";
}
if (expr.kind == ExprKind.BINOP) {
Integer bop = expr.op;
if (bop == BinaryOp.ADD || bop == BinaryOp.SUB ||
Expand Down Expand Up @@ -2923,109 +2931,82 @@ public class IRLower {
this.emit(IRInstr.ret(buf));
}

// Simplified i64_to_str: fixed 32-byte buffer, write digits right-to-left.
// Avoids pre-count loop and minimizes slot count for SafeRegAlloc.
private generateI64ToStr() {
this.beginFn("__arimo_i64_to_str", IRType.PTR);
this.addParamToLast("n", IRType.I64);
this.resetFnContext();
this.emit(IRInstr.label("entry"));
IRValue nv = IRValue.reg("n", IRType.I64);
IRValue nv = IRValue.reg("n", IRType.I64);

// Allocate fixed 32-byte buffer
IRValue buf = this.emitHeapAlloc(IRValue.ofInt(32, IRType.I64));

// Special case: n == 0
// n == 0 → write "0" and return
this.emit(IRInstr.cmp(nv, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.branch(IROpcode.JNE, "its_nonzero"));
IRValue zeroBuf = this.emitHeapAlloc(IRValue.ofInt(4, IRType.I64));
String zp0R = this.newReg();
IRValue zp0 = IRValue.reg(zp0R, IRType.PTR);
this.emit(IRInstr.mov(zp0, zeroBuf));
this.emit(IRInstr.store(IRValue.ofInt(48, IRType.I64), zp0, IRType.I8)); // '0'=48
String zp1R = this.newReg();
IRValue zp1 = IRValue.reg(zp1R, IRType.PTR);
this.emit(IRInstr.binop(IROpcode.ADD, zp1, zeroBuf, IRValue.ofInt(1, IRType.I64)));
this.emit(IRInstr.store(IRValue.ofInt(0, IRType.I64), zp1, IRType.I8));
this.emit(IRInstr.ret(zeroBuf));

this.emit(IRInstr.label("its_nonzero"));
// neg = 0; abs_n = n; if n < 0: neg=1, abs_n = 0-n
IRValue neg = IRValue.reg("its_neg", IRType.I64);
IRValue abs_n = IRValue.reg("its_absn", IRType.I64);
this.emit(IRInstr.mov(neg, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.branch(IROpcode.JNE, "its_nz"));
this.emit(IRInstr.store(IRValue.ofInt(48, IRType.I8), buf, IRType.I8));
String b1R = this.newReg(); IRValue b1 = IRValue.reg(b1R, IRType.PTR);
this.emit(IRInstr.binop(IROpcode.ADD, b1, buf, IRValue.ofInt(1, IRType.I64)));
this.emit(IRInstr.store(IRValue.ofInt(0, IRType.I8), b1, IRType.I8));
this.emit(IRInstr.ret(buf));

this.emit(IRInstr.label("its_nz"));
// abs_n = n; neg = false
IRValue abs_n = IRValue.reg("its_an", IRType.I64);
this.emit(IRInstr.mov(abs_n, nv));
IRValue neg = IRValue.reg("its_ng", IRType.I64);
this.emit(IRInstr.mov(neg, IRValue.ofInt(0, IRType.I64)));
// if n < 0: neg = true, abs_n = 0 - n
this.emit(IRInstr.cmp(nv, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.branch(IROpcode.JGE, "its_pos"));
this.emit(IRInstr.mov(neg, IRValue.ofInt(1, IRType.I64)));
String anR = this.newReg();
IRValue an = IRValue.reg(anR, IRType.I64);
String anR = this.newReg(); IRValue an = IRValue.reg(anR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.SUB, an, IRValue.ofInt(0, IRType.I64), nv));
this.emit(IRInstr.mov(abs_n, an));
this.emit(IRInstr.label("its_pos"));

// Count digits
IRValue cnt = IRValue.reg("its_cnt", IRType.I64);
IRValue tmp = IRValue.reg("its_tmp", IRType.I64);
this.emit(IRInstr.mov(cnt, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.mov(tmp, abs_n));
this.emit(IRInstr.label("its_count"));
this.emit(IRInstr.cmp(tmp, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.branch(IROpcode.JLE, "its_counted"));
this.emit(IRInstr.binop(IROpcode.ADD, cnt, cnt, IRValue.ofInt(1, IRType.I64)));
String dvrR = this.newReg();
IRValue dvr = IRValue.reg(dvrR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.DIV, dvr, tmp, IRValue.ofInt(10, IRType.I64)));
this.emit(IRInstr.mov(tmp, dvr));
this.emit(IRInstr.jmp("its_count"));
this.emit(IRInstr.label("its_counted"));

// total = cnt + neg
String totR = this.newReg();
IRValue tot = IRValue.reg(totR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.ADD, tot, cnt, neg));
String aszR = this.newReg();
IRValue asz = IRValue.reg(aszR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.ADD, asz, tot, IRValue.ofInt(1, IRType.I64)));

IRValue buf = this.emitHeapAlloc(asz);

// null terminate
String nlPR = this.newReg();
IRValue nlP = IRValue.reg(nlPR, IRType.PTR);
this.emit(IRInstr.binop(IROpcode.ADD, nlP, buf, tot));
this.emit(IRInstr.store(IRValue.ofInt(0, IRType.I64), nlP, IRType.I8));
// Write '-' at buf[0] if negative
this.emit(IRInstr.cmp(neg, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.branch(IROpcode.JE, "its_wr"));
this.emit(IRInstr.store(IRValue.ofInt(45, IRType.I8), buf, IRType.I8));

// write '-' if neg
// pos = 30 (write from end of buffer); start pos = 30 if neg else 31
this.emit(IRInstr.label("its_wr"));
IRValue pos = IRValue.reg("its_ps", IRType.I64);
this.emit(IRInstr.mov(pos, IRValue.ofInt(31, IRType.I64)));
this.emit(IRInstr.cmp(neg, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.branch(IROpcode.JE, "its_no_neg"));
String np0R = this.newReg();
IRValue np0 = IRValue.reg(np0R, IRType.PTR);
this.emit(IRInstr.mov(np0, buf));
this.emit(IRInstr.store(IRValue.ofInt(45, IRType.I64), np0, IRType.I8)); // '-'=45
this.emit(IRInstr.label("its_no_neg"));

// write digits right to left: pos = tot-1
IRValue pos = IRValue.reg("its_pos", IRType.I64);
IRValue abs2 = IRValue.reg("its_ab2", IRType.I64);
this.emit(IRInstr.binop(IROpcode.SUB, pos, tot, IRValue.ofInt(1, IRType.I64)));
this.emit(IRInstr.mov(abs2, abs_n));
this.emit(IRInstr.label("its_write"));
this.emit(IRInstr.cmp(abs2, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.branch(IROpcode.JLE, "its_written"));
String modR = this.newReg();
IRValue md = IRValue.reg(modR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.MOD, md, abs2, IRValue.ofInt(10, IRType.I64)));
String digR = this.newReg();
IRValue dig = IRValue.reg(digR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.ADD, dig, md, IRValue.ofInt(48, IRType.I64))); // '0'=48
String dPR = this.newReg();
IRValue dP = IRValue.reg(dPR, IRType.PTR);
this.emit(IRInstr.binop(IROpcode.ADD, dP, buf, pos));
this.emit(IRInstr.store(dig, dP, IRType.I8));
String divR = this.newReg();
IRValue dv = IRValue.reg(divR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.DIV, dv, abs2, IRValue.ofInt(10, IRType.I64)));
this.emit(IRInstr.mov(abs2, dv));
this.emit(IRInstr.branch(IROpcode.JE, "its_lp"));
this.emit(IRInstr.mov(pos, IRValue.ofInt(30, IRType.I64)));

// Loop: while abs_n > 0, write digit, shift
this.emit(IRInstr.label("its_lp"));
this.emit(IRInstr.cmp(abs_n, IRValue.ofInt(0, IRType.I64)));
this.emit(IRInstr.branch(IROpcode.JLE, "its_dn"));
// digit = abs_n % 10 + '0'
String modR = this.newReg(); IRValue md = IRValue.reg(modR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.MOD, md, abs_n, IRValue.ofInt(10, IRType.I64)));
String dR = this.newReg(); IRValue dv = IRValue.reg(dR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.ADD, dv, md, IRValue.ofInt(48, IRType.I64)));
String dpR = this.newReg(); IRValue dp = IRValue.reg(dpR, IRType.PTR);
this.emit(IRInstr.binop(IROpcode.ADD, dp, buf, pos));
this.emit(IRInstr.store(dv, dp, IRType.I8));
// abs_n /= 10
String divR = this.newReg(); IRValue dv2 = IRValue.reg(divR, IRType.I64);
this.emit(IRInstr.binop(IROpcode.DIV, dv2, abs_n, IRValue.ofInt(10, IRType.I64)));
this.emit(IRInstr.mov(abs_n, dv2));
// pos--
this.emit(IRInstr.binop(IROpcode.SUB, pos, pos, IRValue.ofInt(1, IRType.I64)));
this.emit(IRInstr.jmp("its_write"));
this.emit(IRInstr.label("its_written"));
this.emit(IRInstr.ret(buf));
this.emit(IRInstr.jmp("its_lp"));

// Return pointer to first digit: buf+pos+1 (right after last written char)
this.emit(IRInstr.label("its_dn"));
String retR = this.newReg(); IRValue retV = IRValue.reg(retR, IRType.PTR);
this.emit(IRInstr.binop(IROpcode.ADD, retV, buf, pos));
String ret2R = this.newReg(); IRValue ret2 = IRValue.reg(ret2R, IRType.PTR);
this.emit(IRInstr.binop(IROpcode.ADD, ret2, retV, IRValue.ofInt(1, IRType.I64)));
this.emit(IRInstr.ret(ret2));
}

// ===== Entry point =====
Expand Down
31 changes: 28 additions & 3 deletions arimo/compiler/backend/IRToX64.arm
Original file line number Diff line number Diff line change
Expand Up @@ -789,9 +789,28 @@ public class IRToX64 {
if (op == IROpcode.ADD || op == IROpcode.SUB ||
op == IROpcode.AND || op == IROpcode.OR || op == IROpcode.XOR ||
op == IROpcode.MUL) {
List<Integer> regs = this.safeRa.allocScratch3();
IRValue a = instr.operands.get(0) as IRValue;
IRValue b = instr.operands.get(1) as IRValue;
// ADD with IMM_INT: use LEA to encode displacement directly.
if (op == IROpcode.ADD && b.kind == IRValueKind.IMM_INT) {
List<Integer> regs = this.safeRa.allocScratch2();
Integer ra = this.safeLoadVal(a, regs.get(0) as Integer);
Integer dr = regs.get(1) as Integer;
this.enc.leaMem(dr, ra, b.immInt);
this.safeStoreDst(instr.dst.name, dr);
return;
}
// SUB with IMM_INT: use subRI to avoid scratch register for immediate.
if (op == IROpcode.SUB && b.kind == IRValueKind.IMM_INT) {
List<Integer> regs = this.safeRa.allocScratch2();
Integer ra = this.safeLoadVal(a, regs.get(0) as Integer);
Integer dr = regs.get(1) as Integer;
if (dr != ra) { this.enc.movRR(dr, ra); }
this.enc.subRI(dr, b.immInt);
this.safeStoreDst(instr.dst.name, dr);
return;
}
List<Integer> regs = this.safeRa.allocScratch3();
Integer ra = this.safeLoadVal(a, regs.get(0) as Integer);
Integer rb = this.safeLoadVal(b, regs.get(1) as Integer);
Integer dr = regs.get(2) as Integer;
Expand Down Expand Up @@ -872,13 +891,19 @@ public class IRToX64 {
// --- CMP ---

if (op == IROpcode.CMP) {
List<Integer> regs = this.safeRa.allocScratch2();
IRValue a = instr.operands.get(0) as IRValue;
IRValue b = instr.operands.get(1) as IRValue;
// CMP with IMM: use cmpRI to avoid scratch register for the immediate
if (b.kind == IRValueKind.IMM_INT) {
Integer sr = this.safeRa.allocScratch();
Integer ra = this.safeLoadVal(a, sr);
this.enc.cmpRI(ra, b.immInt);
return;
}
List<Integer> regs = this.safeRa.allocScratch2();
Integer ra = this.safeLoadVal(a, regs.get(0) as Integer);
Integer rb = this.safeLoadVal(b, regs.get(1) as Integer);
this.enc.cmpRR(ra, rb);
// CMP has no dst — flags are set, consumed by following branch
return;
}

Expand Down
Loading