Skip to content

Commit 637cc49

Browse files
Dale-Blackclaude
andcommitted
Add broadcasting support: sin.(x), x .* freq, chained broadcasts
Broadcasting compiles to JS .map() chains: - sin.(x) -> x.map(_b => Math.sin(_b)) - x .* freq -> x.map(_b => _b * freq) - sin.(x .* f) -> x.map(_b => _b * f).map(_b => Math.sin(_b)) Handles unary, binary (arr-scalar, scalar-arr, arr-arr), nested, and all math/arithmetic functions. Broadcasted descriptors are consumed at compile time, not emitted as JS values. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 890d1cd commit 637cc49

1 file changed

Lines changed: 168 additions & 0 deletions

File tree

src/compiler/codegen.jl

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ function compile_function(ctx::JSCompilationContext)
5757
if stmt isa Expr && stmt.head === :(=)
5858
continue
5959
end
60+
# Skip broadcasted descriptors (consumed by materialize, not real values)
61+
if i <= length(ctx.code_info.ssavaluetypes)
62+
stype = ctx.code_info.ssavaluetypes[i]
63+
if stype isa DataType && stype <: Base.Broadcast.Broadcasted
64+
continue
65+
end
66+
end
6067
if stmt isa Expr && stmt.head in (:call, :invoke, :new)
6168
# Check if this value is used anywhere
6269
for (j, other) in enumerate(code)
@@ -535,6 +542,154 @@ end
535542
"""
536543
Compile a :call expression (intrinsics, builtins, generic calls).
537544
"""
545+
# Compile Base.materialize(broadcasted_ssa) to JS .map() chains.
546+
# sin.(x) -> x.map(_b => Math.sin(_b)), x.*f -> x.map(_b => _b*f), etc.
547+
function _compile_broadcast_materialize(ctx::JSCompilationContext, bc_arg)
548+
# Resolve the broadcasted SSA
549+
bc_stmt = nothing
550+
if bc_arg isa Core.SSAValue
551+
bc_stmt = ctx.code_info.code[bc_arg.id]
552+
# Handle slot assignment wrapping
553+
if bc_stmt isa Expr && bc_stmt.head === :(=)
554+
bc_stmt = bc_stmt.args[2]
555+
end
556+
end
557+
558+
if bc_stmt === nothing || !(bc_stmt isa Expr && bc_stmt.head === :call)
559+
# Can't resolve — fallback
560+
return "$(compile_value(ctx, bc_arg)).slice()"
561+
end
562+
563+
# Parse broadcasted(fn, args...)
564+
bc_callee = bc_stmt.args[1]
565+
if !(bc_callee isa GlobalRef && bc_callee.name === :broadcasted)
566+
return "$(compile_value(ctx, bc_arg)).slice()"
567+
end
568+
569+
bc_fn_arg = bc_stmt.args[2] # The function being broadcast (sin, *, +, etc.)
570+
bc_data_args = bc_stmt.args[3:end] # The data arguments
571+
572+
# Resolve the broadcast function
573+
fn_name = nothing
574+
if bc_fn_arg isa Core.SSAValue
575+
fn_type = ctx.code_info.ssavaluetypes[bc_fn_arg.id]
576+
if fn_type isa Core.Const
577+
fn_val = fn_type.val
578+
if fn_val === sin; fn_name = "Math.sin"
579+
elseif fn_val === cos; fn_name = "Math.cos"
580+
elseif fn_val === sqrt; fn_name = "Math.sqrt"
581+
elseif fn_val === abs; fn_name = "Math.abs"
582+
elseif fn_val === exp; fn_name = "Math.exp"
583+
elseif fn_val === log; fn_name = "Math.log"
584+
elseif fn_val === (+); fn_name = "+"
585+
elseif fn_val === (-); fn_name = "-"
586+
elseif fn_val === (*); fn_name = "*"
587+
elseif fn_val === (/); fn_name = "/"
588+
elseif fn_val === (^); fn_name = "**"
589+
else fn_name = string(nameof(fn_val))
590+
end
591+
end
592+
elseif bc_fn_arg isa GlobalRef
593+
fn_name = string(bc_fn_arg.name)
594+
end
595+
596+
if fn_name === nothing
597+
return "$(compile_value(ctx, bc_arg)).slice()"
598+
end
599+
600+
# Find which argument is the array (Vector) and which is scalar
601+
# For unary: broadcasted(sin, x) → x.map(_b => Math.sin(_b))
602+
# For binary: broadcasted(*, x, freq) → x.map(_b => _b * freq)
603+
if length(bc_data_args) == 1
604+
# Unary broadcast: fn.(arr)
605+
inner = bc_data_args[1]
606+
# Check if inner is itself a broadcasted (nested: sin.(x .* f))
607+
inner_is_broadcast = false
608+
if inner isa Core.SSAValue
609+
inner_stmt = ctx.code_info.code[inner.id]
610+
if inner_stmt isa Expr && inner_stmt.head === :(=)
611+
inner_stmt = inner_stmt.args[2]
612+
end
613+
if inner_stmt isa Expr && inner_stmt.head === :call
614+
ic = inner_stmt.args[1]
615+
if ic isa GlobalRef && ic.name === :broadcasted
616+
inner_is_broadcast = true
617+
end
618+
end
619+
end
620+
621+
if inner_is_broadcast
622+
# Nested broadcast: fn.(inner_broadcast)
623+
# Compile inner as a .map() first, then apply outer
624+
inner_js = _compile_broadcast_materialize(ctx, inner)
625+
if fn_name in ("Math.sin", "Math.cos", "Math.sqrt", "Math.abs", "Math.exp", "Math.log")
626+
return "$(inner_js).map(function(_b) { return $(fn_name)(_b); })"
627+
else
628+
return "$(inner_js).map(function(_b) { return $(fn_name)(_b); })"
629+
end
630+
else
631+
arr_js = compile_value(ctx, inner)
632+
if fn_name in ("Math.sin", "Math.cos", "Math.sqrt", "Math.abs", "Math.exp", "Math.log")
633+
return "$(arr_js).map(function(_b) { return $(fn_name)(_b); })"
634+
elseif fn_name == "-"
635+
return "$(arr_js).map(function(_b) { return -_b; })"
636+
else
637+
return "$(arr_js).map(function(_b) { return $(fn_name)(_b); })"
638+
end
639+
end
640+
elseif length(bc_data_args) == 2
641+
# Binary broadcast: arr .op scalar or arr .op arr
642+
left = bc_data_args[1]
643+
right = bc_data_args[2]
644+
645+
# Determine which is array and which is scalar
646+
left_type = if left isa Core.SSAValue
647+
ctx.code_info.ssavaluetypes[left.id]
648+
elseif left isa Core.SlotNumber && left.id <= length(ctx.code_info.slottypes)
649+
ctx.code_info.slottypes[left.id]
650+
elseif left isa Core.Argument && left.n <= length(ctx.arg_types) + 1
651+
left.n == 1 ? nothing : ctx.arg_types[left.n - 1]
652+
else
653+
nothing
654+
end
655+
656+
right_type = if right isa Core.SSAValue
657+
ctx.code_info.ssavaluetypes[right.id]
658+
elseif right isa Core.SlotNumber && right.id <= length(ctx.code_info.slottypes)
659+
ctx.code_info.slottypes[right.id]
660+
elseif right isa Core.Argument && right.n <= length(ctx.arg_types) + 1
661+
right.n == 1 ? nothing : ctx.arg_types[right.n - 1]
662+
else
663+
nothing
664+
end
665+
666+
left_is_array = left_type isa DataType && left_type <: AbstractArray
667+
right_is_array = right_type isa DataType && right_type <: AbstractArray
668+
669+
left_js = compile_value(ctx, left)
670+
right_js = compile_value(ctx, right)
671+
672+
op = fn_name # +, -, *, /, **
673+
674+
if left_is_array && !right_is_array
675+
# arr .op scalar → arr.map(_b => _b op scalar)
676+
return "$(left_js).map(function(_b) { return (_b $(op) $(right_js)); })"
677+
elseif !left_is_array && right_is_array
678+
# scalar .op arr → arr.map(_b => scalar op _b)
679+
return "$(right_js).map(function(_b) { return ($(left_js) $(op) _b); })"
680+
elseif left_is_array && right_is_array
681+
# arr .op arr → arr.map((_b, _i) => _b op other[_i])
682+
return "$(left_js).map(function(_b, _i) { return (_b $(op) $(right_js)[_i]); })"
683+
else
684+
# Both scalar? Shouldn't happen with broadcasting, but handle gracefully
685+
return "($(left_js) $(op) $(right_js))"
686+
end
687+
end
688+
689+
# Fallback
690+
return "$(compile_value(ctx, bc_data_args[1])).slice()"
691+
end
692+
538693
function compile_call(ctx::JSCompilationContext, expr::Expr)
539694
args = expr.args
540695
callee = args[1]
@@ -767,6 +922,19 @@ function compile_call(ctx::JSCompilationContext, expr::Expr)
767922
end
768923
end
769924

925+
# Base.materialize — execute broadcasting: sin.(x) → x.map(v => Math.sin(v))
926+
if bname === :materialize && callee.mod === Base
927+
if length(args) >= 2
928+
return _compile_broadcast_materialize(ctx, args[2])
929+
end
930+
end
931+
932+
# Base.broadcasted — lazy broadcast descriptor (compiled when materialize is called)
933+
# Returns empty string as these are consumed by materialize, not standalone
934+
if bname === :broadcasted && callee.mod === Base
935+
return ""
936+
end
937+
770938
# Base.iterate — for-loop iteration
771939
if bname === :iterate && callee.mod === Base
772940
call_args_it = [compile_value(ctx, a) for a in args[2:end]]

0 commit comments

Comments
 (0)