Skip to content

Commit c418921

Browse files
Dale-Blackclaude
andcommitted
ND array transpilation: zeros, ones, indexing, size (column-major flat)
Representation: flat JS Array + _size metadata (column-major, matching Julia). A[i,j] → A[(j-1)*nrows + (i-1)] — same stride math as WASM compilers. Construction: - zeros(m,n), ones(m,n), fill(v,m,n) → jl_ndarray runtime helper - Creates flat Array with _size property for shape metadata Indexing: - A[i,j] read → column-major flat index via _size[0] - A[i,j] = val → same stride math for write - A[i,j,k] read → 3D column-major stride - Handlers in ALL three code paths: compile_invoke, compile_call SSA, compile_call GlobalRef Size: - size(A) → A._size.slice() (or [A.length] for 1D) - size(A,d) → A._size[d-1] Runtime helper: - jl_ndarray(fill_val, dims) — creates flat array with _size Tests: 8 e2e Node.js tests — zeros, ones, length, size(1), size(2), set+get, column-major verification, unrolled matmul. Known limitation: nested for loops (for i; for j; for k) have a pre-existing JST bug where inner break exits outer while loop. Single for loops work correctly. Filed as TODO. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2025e0b commit c418921

3 files changed

Lines changed: 132 additions & 12 deletions

File tree

src/compiler/codegen.jl

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -878,16 +878,21 @@ function compile_call(ctx::JSCompilationContext, expr::Expr)
878878
# Array indexing: arr[i] → arr[i-1]
879879
# Range indexing: arr[a:b] → arr.slice(a-1, b)
880880
if length(call_args) == 2
881-
idx_arg = args[3] # The index argument (before compilation)
881+
idx_arg = args[3]
882882
idx_type = nothing
883883
if idx_arg isa Core.SSAValue
884884
idx_type = try ctx.code_info.ssavaluetypes[idx_arg.id] catch; nothing end
885885
end
886-
# Check if index is a range (UnitRange) → slice
887886
if idx_type !== nothing && idx_type isa DataType && idx_type <: AbstractRange
888887
return "$(call_args[1]).slice(($(call_args[2])).start-1,($(call_args[2])).stop)"
889888
end
890889
return "$(call_args[1])[($(call_args[2])) - 1]"
890+
elseif length(call_args) == 3
891+
# A[i,j] → column-major flat index: (j-1)*nrows + (i-1)
892+
return "$(call_args[1])[(($(call_args[3]))-1)*$(call_args[1])._size[0]+(($(call_args[2]))-1)]"
893+
elseif length(call_args) == 4
894+
# A[i,j,k] → column-major: (k-1)*m*n + (j-1)*m + (i-1)
895+
return "$(call_args[1])[(($(call_args[4]))-1)*$(call_args[1])._size[0]*$(call_args[1])._size[1]+(($(call_args[3]))-1)*$(call_args[1])._size[0]+(($(call_args[2]))-1)]"
891896
end
892897
return "[]"
893898
end
@@ -1031,6 +1036,18 @@ function compile_call(ctx::JSCompilationContext, expr::Expr)
10311036
if fn_name == "sort" && length(call_args) >= 1
10321037
return "$(call_args[1]).slice().sort()"
10331038
end
1039+
if fn_name == "setindex!" && length(call_args) >= 3
1040+
if length(call_args) == 4
1041+
# A[i,j] = val → column-major
1042+
return "($(call_args[1])[(($(call_args[4]))-1)*$(call_args[1])._size[0]+(($(call_args[3]))-1)] = $(call_args[2]))"
1043+
else
1044+
return "($(call_args[1])[($(call_args[3]))-1] = $(call_args[2]))"
1045+
end
1046+
end
1047+
if fn_name == "getindex" && length(call_args) == 3
1048+
# A[i,j] → column-major
1049+
return "$(call_args[1])[(($(call_args[3]))-1)*$(call_args[1])._size[0]+(($(call_args[2]))-1)]"
1050+
end
10341051
if fn_name == "copy" && length(call_args) >= 1
10351052
return "$(call_args[1]).slice()"
10361053
end
@@ -1058,15 +1075,39 @@ function compile_call(ctx::JSCompilationContext, expr::Expr)
10581075
return "$(call_args[2]).reduce($(call_args[1]))"
10591076
end
10601077

1061-
# ─── Construction (by name) ───
1062-
if fn_name == "zeros" && length(call_args) >= 1
1063-
return "new Array($(call_args[1])).fill(0)"
1078+
# ─── Construction (by name, 1D and ND) ───
1079+
if fn_name == "zeros"
1080+
if length(call_args) == 1
1081+
return "new Array($(call_args[1])).fill(0)"
1082+
elseif length(call_args) >= 2
1083+
require_runtime!(ctx, :jl_ndarray)
1084+
return "jl_ndarray(0,[$(join(call_args, ","))])"
1085+
end
10641086
end
1065-
if fn_name == "ones" && length(call_args) >= 1
1066-
return "new Array($(call_args[1])).fill(1)"
1087+
if fn_name == "ones"
1088+
if length(call_args) == 1
1089+
return "new Array($(call_args[1])).fill(1)"
1090+
elseif length(call_args) >= 2
1091+
require_runtime!(ctx, :jl_ndarray)
1092+
return "jl_ndarray(1,[$(join(call_args, ","))])"
1093+
end
10671094
end
1068-
if fn_name == "fill" && length(call_args) >= 2
1069-
return "new Array($(call_args[2])).fill($(call_args[1]))"
1095+
if fn_name == "fill"
1096+
if length(call_args) == 2
1097+
return "new Array($(call_args[2])).fill($(call_args[1]))"
1098+
elseif length(call_args) >= 3
1099+
require_runtime!(ctx, :jl_ndarray)
1100+
return "jl_ndarray($(call_args[1]),[$(join(call_args[2:end], ","))])"
1101+
end
1102+
end
1103+
1104+
# ─── Size (by name) ───
1105+
if fn_name == "size"
1106+
if length(call_args) == 1
1107+
return "($(call_args[1])._size||[$(call_args[1]).length]).slice()"
1108+
elseif length(call_args) == 2
1109+
return "($(call_args[1])._size?$(call_args[1])._size[($(call_args[2]))-1]:$(call_args[1]).length)"
1110+
end
10701111
end
10711112

10721113
# ─── Parsing (by name) ───
@@ -1265,10 +1306,24 @@ function compile_call(ctx::JSCompilationContext, expr::Expr)
12651306
return "[]"
12661307
end
12671308
end
1268-
# Array indexing
1309+
# Array indexing (1D and ND)
12691310
call_args_gr = [compile_value(ctx, a) for a in args[2:end]]
12701311
if length(call_args_gr) == 2
12711312
return "$(call_args_gr[1])[($(call_args_gr[2])) - 1]"
1313+
elseif length(call_args_gr) == 3
1314+
# A[i,j] → column-major
1315+
return "$(call_args_gr[1])[(($(call_args_gr[3]))-1)*$(call_args_gr[1])._size[0]+(($(call_args_gr[2]))-1)]"
1316+
end
1317+
end
1318+
1319+
# Base.setindex! — array assignment (1D and ND)
1320+
if bname === :setindex! && callee.mod === Base
1321+
call_args_gr = [compile_value(ctx, a) for a in args[2:end]]
1322+
if length(call_args_gr) == 4
1323+
# A[i,j] = val → column-major
1324+
return "($(call_args_gr[1])[(($(call_args_gr[4]))-1)*$(call_args_gr[1])._size[0]+(($(call_args_gr[3]))-1)] = $(call_args_gr[2]))"
1325+
elseif length(call_args_gr) == 3
1326+
return "($(call_args_gr[1])[($(call_args_gr[3]))-1] = $(call_args_gr[2]))"
12721327
end
12731328
end
12741329

@@ -1677,8 +1732,15 @@ function compile_invoke(ctx::JSCompilationContext, expr::Expr)
16771732
elseif func_name == "setindex!"
16781733
arr_val = compile_value(ctx, expr.args[3])
16791734
val_val = compile_value(ctx, expr.args[4])
1680-
idx_val = compile_value(ctx, expr.args[5])
1681-
return "($(arr_val)[($(idx_val))-1] = $(val_val))"
1735+
if length(expr.args) == 6
1736+
# A[i,j] = val → column-major
1737+
i_val = compile_value(ctx, expr.args[5])
1738+
j_val = compile_value(ctx, expr.args[6])
1739+
return "($(arr_val)[(($(j_val))-1)*$(arr_val)._size[0]+(($(i_val))-1)] = $(val_val))"
1740+
else
1741+
idx_val = compile_value(ctx, expr.args[5])
1742+
return "($(arr_val)[($(idx_val))-1] = $(val_val))"
1743+
end
16821744
elseif func_name == "deleteat!"
16831745
arr_val = compile_value(ctx, expr.args[3])
16841746
idx_val = compile_value(ctx, expr.args[4])
@@ -1825,6 +1887,15 @@ function compile_invoke(ctx::JSCompilationContext, expr::Expr)
18251887
return "$(call_args[1]).length === 0"
18261888
end
18271889

1890+
# size(A) → shape tuple, size(A, d) → dimension size
1891+
if func_name == "size"
1892+
if length(call_args) == 1
1893+
return "($(call_args[1])._size||[$(call_args[1]).length]).slice()"
1894+
elseif length(call_args) == 2
1895+
return "($(call_args[1])._size?$(call_args[1])._size[($(call_args[2]))-1]:$(call_args[1]).length)"
1896+
end
1897+
end
1898+
18281899
# convert(T, x) → just x (type conversions are compile-time in JS)
18291900
if func_name == "convert" && length(call_args) >= 2
18301901
return call_args[2]

src/compiler/runtime.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,15 @@ function jl_objectid(x) {
218218
if (id === undefined) { id = ++_jl_oid_ctr; _jl_oid_map.set(x, id); }
219219
return id;
220220
}""",
221+
222+
:jl_ndarray => """
223+
function jl_ndarray(fill_val, dims) {
224+
var n = 1;
225+
for (var i = 0; i < dims.length; i++) n *= dims[i];
226+
var a = new Array(n).fill(fill_val);
227+
a._size = dims.slice();
228+
return a;
229+
}""",
221230
)
222231

223232
# Dependency-ordered list of symbols for deterministic output
@@ -230,6 +239,7 @@ const RUNTIME_ORDER = [
230239
:jl_checked_add, :jl_checked_sub, :jl_checked_mul,
231240
:jl_println, :jl_print,
232241
:jl_objectid,
242+
:jl_ndarray,
233243
]
234244

235245
"""

test/runtests.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3515,4 +3515,43 @@ process.stdout.write(JSON.stringify(r));
35153515
fn = () -> startswith("hello", "he")
35163516
@test compile_unopt_and_run(fn, "") == "true"
35173517
end
3518+
3519+
# ─── ND Array tests (e2e via Node.js) ───
3520+
@testset "ND: zeros(3,4)" begin
3521+
@test compile_unopt_and_run(() -> zeros(3, 4), "") == "[0,0,0,0,0,0,0,0,0,0,0,0]"
3522+
end
3523+
@testset "ND: ones(2,3)" begin
3524+
@test compile_unopt_and_run(() -> ones(2, 3), "") == "[1,1,1,1,1,1]"
3525+
end
3526+
@testset "ND: length(zeros(3,4))" begin
3527+
@test compile_unopt_and_run(() -> length(zeros(3, 4)), "") == "12"
3528+
end
3529+
@testset "ND: size(A,1)" begin
3530+
@test compile_unopt_and_run(() -> size(zeros(3, 4), 1), "") == "3"
3531+
end
3532+
@testset "ND: size(A,2)" begin
3533+
@test compile_unopt_and_run(() -> size(zeros(3, 4), 2), "") == "4"
3534+
end
3535+
@testset "ND: A[i,j] set+get" begin
3536+
@test compile_unopt_and_run(() -> begin A=zeros(2,3); A[1,2]=42.0; A[1,2] end, "") == "42"
3537+
end
3538+
@testset "ND: column-major indexing" begin
3539+
@test compile_unopt_and_run(() -> begin A=zeros(2,3); A[2,1]=10.0; A[1,2]=20.0; A[2,2]=30.0; A end, "") == "[0,10,20,30,0,0]"
3540+
end
3541+
@testset "ND: unrolled matmul" begin
3542+
fn = () -> begin
3543+
A=zeros(2,2); A[1,1]=1.0; A[1,2]=2.0; A[2,1]=3.0; A[2,2]=4.0
3544+
B=zeros(2,2); B[1,1]=5.0; B[1,2]=6.0; B[2,1]=7.0; B[2,2]=8.0
3545+
C=zeros(2,2)
3546+
C[1,1] = A[1,1]*B[1,1] + A[1,2]*B[2,1]
3547+
C[2,1] = A[2,1]*B[1,1] + A[2,2]*B[2,1]
3548+
C[1,2] = A[1,1]*B[1,2] + A[1,2]*B[2,2]
3549+
C[2,2] = A[2,1]*B[1,2] + A[2,2]*B[2,2]
3550+
return C
3551+
end
3552+
@test compile_unopt_and_run(fn, "") == "[19,43,22,50]"
3553+
end
3554+
# Note: nested for loops (for i; for j; for k) have a known JST bug
3555+
# where inner loop `break` exits the outer while loop. Single for loops work.
3556+
# TODO: fix nested for loop compilation in JST
35183557
end

0 commit comments

Comments
 (0)