Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 203 additions & 17 deletions src/array/stencil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ function load_neighbor_region(arr, region_code::NTuple{N,Int}, neigh_dist) where
return move(task_processor(), collect(@view arr[start_idx:stop_idx]))
end

# In-place variant: load region directly into a pre-allocated destination buffer.
function load_neighbor_region_into!(dest, arr, region_code::NTuple{N,Int}, neigh_dist) where N
validate_neigh_dist(neigh_dist, size(arr))
start_idx = CartesianIndex(ntuple(N) do i
region_code[i] == -1 ? lastindex(arr, i) - get_neigh_dist(neigh_dist, i) + 1 : firstindex(arr, i)
end)
stop_idx = CartesianIndex(ntuple(N) do i
region_code[i] == +1 ? firstindex(arr, i) + get_neigh_dist(neigh_dist, i) - 1 : lastindex(arr, i)
end)
copyto!(dest, @view arr[start_idx:stop_idx])
end

is_past_boundary(size, idx) = any(ntuple(i -> idx[i] < 1 || idx[i] > size[i], length(size)))

#############################################################################
Expand Down Expand Up @@ -123,6 +135,9 @@ boundary_transition(::Wrap, idx, size) =
load_boundary_region(::Wrap, arr, region_code, neigh_dist, boundary_dims) =
load_neighbor_region(arr, region_code, neigh_dist)

load_boundary_region_into!(dest, ::Wrap, arr, region_code, neigh_dist, boundary_dims) =
load_neighbor_region_into!(dest, arr, region_code, neigh_dist)

function boundary_source_index(::Wrap, arr, rc, nd, idx_d, d)
if rc == -1
return lastindex(arr, d) - nd + idx_d
Expand Down Expand Up @@ -157,6 +172,9 @@ function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_d
return move(task_processor(), fill(pad.padval, region_size))
end

load_boundary_region_into!(dest, pad::Pad, arr, region_code, neigh_dist, boundary_dims) =
fill!(dest, pad.padval)

# Use edge as source index (value will be overridden by apply_boundary_value)
boundary_source_index(::Pad, arr, rc, nd, idx_d, d) =
rc == -1 ? firstindex(arr, d) : (rc == +1 ? lastindex(arr, d) : idx_d)
Expand Down Expand Up @@ -221,6 +239,10 @@ function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_di
return move(task_processor(), result)
end

function load_boundary_region_into!(dest, ::Clamp, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N
Kernel(load_boundary_region_kernel)(Clamp(), dest, arr, region_code, neigh_dist, boundary_dims; ndrange=length(dest))
end

function boundary_source_index(::Clamp, arr, rc, nd, idx_d, d)
if rc == -1
return firstindex(arr, d)
Expand Down Expand Up @@ -332,6 +354,18 @@ function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region
return move(task_processor(), result)
end

function load_boundary_region_into!(dest, ::LinearExtrapolate, arr::AbstractArray{T}, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {T<:Real,N}
extrap_dim = 0
for d in 1:N
if boundary_dims[d] && region_code[d] != 0
extrap_dim = d
break
end
end
nd = get_neigh_dist(neigh_dist, extrap_dim)
Kernel(load_boundary_region_kernel)(LinearExtrapolate(), dest, arr, region_code, neigh_dist, boundary_dims, Val(extrap_dim), Val(nd); ndrange=length(dest))
end

# Use edge as source index (value will be computed by apply_boundary_value)
boundary_source_index(::LinearExtrapolate, arr, rc, nd, idx_d, d) =
rc == -1 ? firstindex(arr, d) : (rc == +1 ? lastindex(arr, d) : idx_d)
Expand Down Expand Up @@ -434,6 +468,41 @@ function load_boundary_region(::Reflect{Symm}, arr, region_code::NTuple{N,Int},
return region
end

function load_boundary_region_into!(dest, ::Reflect{Symm}, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {N, Symm}
flipped_code = ntuple(N) do i
(region_code[i] != 0 && boundary_dims[i]) ? -region_code[i] : region_code[i]
end
skip = Symm ? 0 : 1
start_idx = CartesianIndex(ntuple(N) do i
needs_skip = boundary_dims[i] && region_code[i] != 0
actual_skip = needs_skip ? skip : 0
if flipped_code[i] == -1
lastindex(arr, i) - get_neigh_dist(neigh_dist, i) + 1 - actual_skip
elseif flipped_code[i] == +1
firstindex(arr, i) + actual_skip
else
firstindex(arr, i)
end
end)
stop_idx = CartesianIndex(ntuple(N) do i
needs_skip = boundary_dims[i] && region_code[i] != 0
actual_skip = needs_skip ? skip : 0
if flipped_code[i] == +1
firstindex(arr, i) + get_neigh_dist(neigh_dist, i) - 1 + actual_skip
elseif flipped_code[i] == -1
lastindex(arr, i) - actual_skip
else
lastindex(arr, i)
end
end)
copyto!(dest, @view arr[start_idx:stop_idx])
for i in 1:N
GPUArraysCore.@allowscalar if region_code[i] != 0 && boundary_dims[i]
reverse!(dest, dims=i)
end
end
end

function boundary_source_index(::Reflect{Symm}, arr, rc, nd, idx_d, d) where Symm
skip = Symm ? 0 : 1
if rc == -1
Expand Down Expand Up @@ -564,6 +633,10 @@ function load_boundary_region(boundary::Tuple, arr, region_code::NTuple{N,Int},
return move(task_processor(), result)
end

function load_boundary_region_into!(dest, boundary::Tuple, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N
Kernel(load_boundary_region_kernel)(boundary, dest, arr, region_code, neigh_dist, boundary_dims; ndrange=length(dest))
end

#############################################################################
# Chunk Selection and Halo Building
#############################################################################
Expand Down Expand Up @@ -615,6 +688,115 @@ function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
return accesses
end

# Returns (region_metadata, neighbor_chunk_dtasks) without spawning intermediate load tasks.
# region_metadata: Vector of (region_code, is_boundary, boundary_dims).
# neighbor_chunk_dtasks: Vector of raw chunk DTasks (resolved to arrays when build_halo_consolidated runs).
function select_neighborhood_info(chunks, idx, neigh_dist, boundary)
validate_neigh_dist(neigh_dist)
N = ndims(chunks)
chunk_dist = 1
region_metadata = Tuple[]
neighbor_chunks = Any[]

for i in 0:(3^N - 1)
region_code = ntuple(N) do d
((i ÷ 3^(d-1)) % 3) - 1
end
all(==(0), region_code) && continue

chunk_offset = CartesianIndex(ntuple(N) do d
region_code[d] * chunk_dist
end)
new_idx = idx + chunk_offset

if is_past_boundary(size(chunks), new_idx)
boundary_dims = ntuple(N) do d
new_idx[d] < 1 || new_idx[d] > size(chunks)[d]
end
if boundary_has_transition(boundary)
new_idx = boundary_transition(boundary, new_idx, size(chunks))
else
new_idx = idx
end
push!(region_metadata, (region_code, true, boundary_dims))
else
push!(region_metadata, (region_code, false, ntuple(_ -> false, N)))
end
push!(neighbor_chunks, chunks[new_idx])
end

@assert length(region_metadata) == 3^N - 1
return region_metadata, neighbor_chunks
end

# Per-thread cache: WeakKeyDict{DArray, Dict{(chunk_idx, halo_width), HaloArray}}.
# WeakKeyDict is used for the outer level so that the cache does not hold a strong reference
# to the source DArray — allowing its GC finalizer to fire when user code drops its last
# reference (see below). Using chunk_idx as part of the inner key ensures that within one
# DArray, every chunk has its own dedicated buffer — so if a single worker thread processes
# multiple same-shaped chunks in the same iteration sequentially, each gets a distinct
# HaloArray and there is no aliasing with a concurrently running inner-stencil task.
# Filling a cached buffer in-place is safe because spawn_datadeps blocks until all inner
# tasks complete before the next iteration's build_halo_consolidated calls run.
const HALO_ARRAY_CACHE = TaskLocalValue{WeakKeyDict{Any,Dict{Any,Any}}}(()->WeakKeyDict{Any,Dict{Any,Any}}())

# Consolidated halo builder: loads all neighbor regions directly into a HaloArray.
# `read_darray` and `chunk_idx` are used solely for cache lookup — they are not DTask
# arguments, so Dagger does not create extra data dependencies from them.
# First call per (DArray, chunk_idx, halo_width) allocates and caches; subsequent calls
# fill the cached HaloArray in-place — zero allocations on the hot path.
function build_halo_consolidated(read_darray, chunk_idx, neigh_dist, boundary, center, region_metadata, neighbor_chunks...)
N = ndims(center)
expected_halos = length(region_metadata)
@assert length(neighbor_chunks) == expected_halos
validate_neigh_dist(neigh_dist, size(center))
halo_width = ntuple(i -> get_neigh_dist(neigh_dist, i), N)

outer_cache = HALO_ARRAY_CACHE[]

# Create the inner cache on first encounter of this DArray on this thread, and register
# a finalizer that captures it. When the DArray becomes unreachable and is collected,
# the finalizer fires and unsafe_free!s every cached HaloArray for this (DArray, thread)
# pair. Because WeakKeyDict holds only a weak reference to read_darray, the DArray can
# actually be collected (a plain IdDict would keep it alive forever).
if !haskey(outer_cache, read_darray)
inner_cache = Dict{Any,Any}()
outer_cache[read_darray] = inner_cache
finalizer(read_darray) do _
for halo in values(inner_cache)
unsafe_free!(halo)
end
end
end
inner_cache = outer_cache[read_darray]
cache_key = (chunk_idx, halo_width)

if haskey(inner_cache, cache_key)
halo = inner_cache[cache_key]
copyto!(halo.center, center)
for i in 1:expected_halos
region_code, is_boundary, boundary_dims = region_metadata[i]
chunk = neighbor_chunks[i]
if is_boundary
load_boundary_region_into!(halo.halos[i], boundary, chunk, region_code, neigh_dist, boundary_dims)
else
load_neighbor_region_into!(halo.halos[i], chunk, region_code, neigh_dist)
end
end
return halo
else
halos = ntuple(expected_halos) do i
region_code, is_boundary, boundary_dims = region_metadata[i]
chunk = neighbor_chunks[i]
is_boundary ? load_boundary_region(boundary, chunk, region_code, neigh_dist, boundary_dims) :
load_neighbor_region(chunk, region_code, neigh_dist)
end
halo = HaloArray(copy(center), halos, halo_width)
inner_cache[cache_key] = halo
return halo
end
end

function build_halo(neigh_dist, boundary, center, all_halos...)
N = ndims(center)
expected_halos = 3^N - 1
Expand Down Expand Up @@ -791,8 +973,11 @@ macro stencil(orig_ex)
end
end

# 2. Stencil operations (inside spawn_datadeps)
datadeps_body = Expr(:block)
# 2. Stencil operations: one spawn_datadeps region per expression.
# Because spawn_datadeps blocks until all its tasks complete, each expression's
# region fully finishes before the next expression's halo tasks are spawned.
# This means HaloArray allocations can always live outside spawn_datadeps,
# avoiding Datadeps aliasing issues unconditionally.
for (;inner_ex, accessed_vars, write_var, write_idx, read_ex, read_vars, neighborhoods, is_allocation, source_var) in inners
# Generate a variable for chunk access
@gensym chunk_idx
Expand Down Expand Up @@ -821,21 +1006,24 @@ macro stencil(orig_ex)
end
inner_fn = Expr(:->, Expr(:tuple, Expr(:parameters, inner_write_var, actual_read_vars...)), new_inner_ex)

# 2a. Pre-spawn all halos for this expression
# This ensures all readers capture the "old" state before any writers start.
# 2a. Pre-spawn all halos for this expression outside spawn_datadeps.
# The preceding spawn_datadeps (if any) has already completed, so the
# source arrays reflect any writes from earlier expressions.
# Pass DTasks directly — no In/Read wrappers needed outside datadeps.
@gensym halo_tasks_map
push!(datadeps_body.args, :($halo_tasks_map = Dict{Symbol, Any}()))
push!(final_ex.args, :($halo_tasks_map = Dict{Symbol, Any}()))
for read_var in read_vars
if read_var in keys(neighborhoods)
neigh_dist, boundary = neighborhoods[read_var]
@gensym halo_tasks
push!(datadeps_body.args, :($halo_tasks = Array{$DTask}(undef, size($chunks($read_var)))))
push!(datadeps_body.args, quote
@gensym halo_tasks region_meta neighbor_cks
push!(final_ex.args, :($halo_tasks = Array{$DTask}(undef, size($chunks($read_var)))))
push!(final_ex.args, quote
for $chunk_idx in $CartesianIndices($chunks($read_var))
$halo_tasks[$chunk_idx] = Dagger.@spawn name="stencil_build_halo" $build_halo($neigh_dist, $boundary, map($Read, $select_neighborhood_chunks($chunks($read_var), $chunk_idx, $neigh_dist, $boundary))...)
($region_meta, $neighbor_cks) = $select_neighborhood_info($chunks($read_var), $chunk_idx, $neigh_dist, $boundary)
$halo_tasks[$chunk_idx] = Dagger.@spawn name="stencil_build_halo" $build_halo_consolidated($read_var, $chunk_idx, $neigh_dist, $boundary, $chunks($read_var)[$chunk_idx], $region_meta, $neighbor_cks...)
end
end)
push!(datadeps_body.args, :($halo_tasks_map[$(QuoteNode(read_var))] = $halo_tasks))
push!(final_ex.args, :($halo_tasks_map[$(QuoteNode(read_var))] = $halo_tasks))
end
end

Expand All @@ -857,18 +1045,16 @@ macro stencil(orig_ex)
end
spawn_ex = :(Dagger.@spawn name="stencil_inner_fn" $inner_fn(;$(deps_ex...)))

# 2c. Generate loop to spawn stencil tasks
push!(datadeps_body.args, quote
# 2c. Each expression gets its own spawn_datadeps region. Because
# spawn_datadeps blocks on completion, the next expression's halo
# pre-spawns will always see fully up-to-date array data.
push!(final_ex.args, :(Dagger.spawn_datadeps() do
for $chunk_idx in $CartesianIndices($chunks($write_var))
$spawn_ex
end
end)
end))
end

push!(final_ex.args, :(Dagger.spawn_datadeps() do
$datadeps_body
end))

# 3. Return last allocated var if applicable
if !isempty(inners) && inners[end].is_allocation
push!(final_ex.args, inners[end].write_var)
Expand Down
Loading