Skip to content
23 changes: 21 additions & 2 deletions quadrants/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,11 @@ void TaskCodeGenLLVM::visit(AssertStmt *stmt) {
auto arguments = create_entry_block_alloca(argument_buffer_size);

std::vector<llvm::Value *> args;
args.emplace_back(get_runtime());
// On CPU, use the context-aware variant that returns non-zero on failure
// so we can emit an early return and avoid the subsequent out-of-bounds
// memory access. On GPU, asm("exit;") kills the thread directly.
bool use_ctx_variant = arch_is_cpu(current_arch());
args.emplace_back(use_ctx_variant ? get_context() : get_runtime());
args.emplace_back(builder->CreateIsNotNull(llvm_val[stmt->cond]));
args.emplace_back(builder->CreateGlobalStringPtr(stmt->text));

Expand Down Expand Up @@ -1362,7 +1366,22 @@ void TaskCodeGenLLVM::visit(AssertStmt *stmt) {
builder->CreateGEP(argument_buffer_size, arguments,
{tlctx->get_constant(0), tlctx->get_constant(0)}));

llvm_val[stmt] = call("quadrants_assert_format", std::move(args));
llvm_val[stmt] = call(use_ctx_variant ? "quadrants_assert_format_ctx"
: "quadrants_assert_format",
std::move(args));

if (use_ctx_variant) {
auto *assert_abort =
llvm::BasicBlock::Create(*llvm_context, "assert_abort", func);
auto *assert_cont =
llvm::BasicBlock::Create(*llvm_context, "assert_cont", func);
auto *failed =
builder->CreateICmpNE(llvm_val[stmt], tlctx->get_constant(0));
builder->CreateCondBr(failed, assert_abort, assert_cont);
builder->SetInsertPoint(assert_abort);
builder->CreateRetVoid();
builder->SetInsertPoint(assert_cont);
}
}

void TaskCodeGenLLVM::visit(SNodeOpStmt *stmt) {
Expand Down
6 changes: 6 additions & 0 deletions quadrants/program/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ struct RuntimeContext {
// LLVMRuntime is shared among functions. So we moved the pointer to
// RuntimeContext which each function have one.
uint64_t *result_buffer;

// Set to 1 by quadrants_assert_format_ctx when a runtime assertion (e.g.
// out-of-bounds check) fails on CPU. The codegen emits an early return
// after each assert call when this is set, and the task runner breaks out
// of its loop.
int32_t cpu_assert_failed{0};
};

#if defined(QD_RUNTIME_HOST)
Expand Down
3 changes: 3 additions & 0 deletions quadrants/runtime/cpu/kernel_launcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ namespace cpu {
void KernelLauncher::launch_offloaded_tasks(
LaunchContextBuilder &ctx,
const std::vector<TaskFunc> &task_funcs) {
ctx.get_context().cpu_assert_failed = 0;
for (auto task : task_funcs) {
task(&ctx.get_context());
if (ctx.get_context().cpu_assert_failed)
break;
}
}

Expand Down
35 changes: 34 additions & 1 deletion quadrants/runtime/llvm/runtime_module/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ STRUCT_FIELD_ARRAY(PhysicalCoordinates, val);

STRUCT_FIELD(RuntimeContext, runtime);
STRUCT_FIELD(RuntimeContext, result_buffer)
STRUCT_FIELD(RuntimeContext, cpu_assert_failed)

#include "quadrants/runtime/llvm/runtime_module/atomic.h"

Expand Down Expand Up @@ -886,6 +887,26 @@ void quadrants_assert_format(LLVMRuntime *runtime,
#endif
}

// Context-aware variant called by bounds-check assertions in JIT'd code.
// Returns 1 when the assertion failed (so the codegen can emit an early
// return), 0 otherwise. This replaces a previous setjmp/longjmp approach
// that crashed on Windows because JIT'd frames lack SEH unwind tables.
i32 quadrants_assert_format_ctx(RuntimeContext *context,
u1 test,
const char *format,
int num_arguments,
uint64 *arguments) {
quadrants_assert_format(context->runtime, test, format, num_arguments,
arguments);
#if !ARCH_cuda && !ARCH_amdgpu
if (enable_assert && test == 0) {
context->cpu_assert_failed = 1;
return 1;
}
#endif
return 0;
}

void quadrants_assert_runtime(LLVMRuntime *runtime, u1 test, const char *msg) {
quadrants_assert_format(runtime, test, msg, 0, nullptr);
}
Expand Down Expand Up @@ -1510,6 +1531,8 @@ void cpu_struct_for_block_helper(void *ctx_, int thread_id, int i) {

RuntimeContext this_thread_context = *ctx->context;
this_thread_context.cpu_thread_id = thread_id;
this_thread_context.cpu_assert_failed = 0;

if (lower < upper) {
(*ctx->task)(&this_thread_context, tls_buffer,
&ctx->list->get<Element>(element_id), lower, upper);
Expand Down Expand Up @@ -1591,20 +1614,27 @@ void cpu_parallel_range_for_task(void *range_context,

RuntimeContext this_thread_context = *ctx.context;
this_thread_context.cpu_thread_id = thread_id;
this_thread_context.cpu_assert_failed = 0;

if (ctx.step == 1) {
int block_start = ctx.begin + task_id * ctx.block_size;
int block_end = std::min(block_start + ctx.block_size, ctx.end);
for (int i = block_start; i < block_end; i++) {
ctx.body(&this_thread_context, tls_ptr, i);
if (this_thread_context.cpu_assert_failed)
break;
}
} else if (ctx.step == -1) {
int block_start = ctx.end - task_id * ctx.block_size;
int block_end = std::max(ctx.begin, block_start * ctx.block_size);
for (int i = block_start - 1; i >= block_end; i--) {
ctx.body(&this_thread_context, tls_ptr, i);
if (this_thread_context.cpu_assert_failed)
break;
}
}
if (ctx.epilogue)

if (!this_thread_context.cpu_assert_failed && ctx.epilogue)
ctx.epilogue(ctx.context, tls_ptr);
}

Expand Down Expand Up @@ -1689,6 +1719,7 @@ void cpu_parallel_mesh_for_task(void *range_context,

RuntimeContext this_thread_context = *ctx.context;
this_thread_context.cpu_thread_id = thread_id;
this_thread_context.cpu_assert_failed = 0;

int block_start = task_id * ctx.block_size;
int block_end = std::min(block_start + ctx.block_size, ctx.num_patches);
Expand All @@ -1697,6 +1728,8 @@ void cpu_parallel_mesh_for_task(void *range_context,
if (ctx.prologue)
ctx.prologue(ctx.context, tls_ptr, idx);
ctx.body(&this_thread_context, tls_ptr, idx);
if (this_thread_context.cpu_assert_failed)
break;
if (ctx.epilogue)
ctx.epilogue(ctx.context, tls_ptr, idx);
}
Expand Down
87 changes: 87 additions & 0 deletions tests/python/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,90 @@ def func():
x[3, 7] = 2

func()


@test_utils.test(
arch=[qd.cpu],
require=qd.extension.assertion,
debug=True,
check_out_of_bound=True,
gdb_trigger=False,
)
def test_ndarray_oob_cpu_raises_not_segfaults():
"""Out-of-bounds ndarray access in a parallel kernel on CPU should raise
QuadrantsAssertionError instead of segfaulting."""
arr = qd.ndarray(dtype=qd.f32, shape=(4,))

@qd.kernel
def write_oob(a: qd.types.ndarray(dtype=qd.f32, ndim=1)):
for i in range(10):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think this test is valid, because this si a paallel loop? we should make ti serial I think?

a[i] = 1.0

with pytest.raises(AssertionError, match=r"Out of bound access"):
write_oob(arr)


@test_utils.test(
arch=[qd.cpu],
require=qd.extension.assertion,
debug=True,
check_out_of_bound=True,
gdb_trigger=False,
)
def test_ndarray_oob_cpu_small_array():
"""Reproduces the pattern from the temperature-sensor segfault: a kernel
accesses a very small (shape-1) array with an index that goes out of
bounds. Before the longjmp fix this would SIGSEGV on CPU in debug mode."""
small = qd.ndarray(dtype=qd.f32, shape=(1,))
small.fill(42.0)

@qd.kernel
def read_oob(a: qd.types.ndarray(dtype=qd.f32, ndim=1)) -> qd.f32:
return a[5]

with pytest.raises(AssertionError, match=r"Out of bound access"):
read_oob(small)


@test_utils.test(
arch=[qd.cpu],
require=qd.extension.assertion,
debug=True,
check_out_of_bound=True,
gdb_trigger=False,
)
def test_ndarray_oob_cpu_2d():
"""2D ndarray out-of-bounds on CPU should produce a clear error."""
arr = qd.ndarray(dtype=qd.f32, shape=(3, 4))

@qd.kernel
def write_oob_2d(a: qd.types.ndarray(dtype=qd.f32, ndim=2)):
for i in range(1):
a[10, 0] = 1.0

with pytest.raises(AssertionError, match=r"Out of bound access"):
write_oob_2d(arr)


@test_utils.test(
arch=[qd.cpu],
require=qd.extension.assertion,
debug=True,
check_out_of_bound=True,
gdb_trigger=False,
)
def test_ndarray_inbounds_cpu_still_works():
"""Verify that the setjmp/longjmp mechanism does not break normal
in-bounds ndarray access."""
n = 8
arr = qd.ndarray(dtype=qd.f32, shape=(n,))

@qd.kernel
def fill(a: qd.types.ndarray(dtype=qd.f32, ndim=1)):
for i in range(n):
a[i] = qd.cast(i * 10, qd.f32)

fill(arr)
result = arr.to_numpy()
for i in range(n):
assert result[i] == pytest.approx(i * 10)
Loading