diff --git a/src/wrap_cl_part_2.cpp b/src/wrap_cl_part_2.cpp index 7d8c71fe0..e87c672d2 100644 --- a/src/wrap_cl_part_2.cpp +++ b/src/wrap_cl_part_2.cpp @@ -605,7 +605,10 @@ void pyopencl_expose_part_2(py::module_ &m) { knl.set_arg_buf_pack(i, typechar, arg); }, indices_chars_and_args); }) - .DEF_SIMPLE_METHOD(set_arg) + .def("set_arg", &cls::set_arg, + py::arg("arg_index"), + py::arg("arg").none(true) + ) #if PYOPENCL_CL_VERSION >= 0x1020 .DEF_SIMPLE_METHOD(get_arg_info) #endif diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 5bd31c0a4..be2e9de47 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -1560,6 +1560,21 @@ def test_buffer_release(ctx_factory: cl.CtxFactory): b.release() +def test_set_arg_none(ctx_factory: cl.CtxFactory): + # https://github.com/inducer/pyopencl/issues/897 + ctx = ctx_factory() + prg = cl.Program(ctx, """ + __kernel void sum( + __global const float *a_g, __global const float *b_g, __global float *res_g) + { + int gid = get_global_id(0); + res_g[gid] = a_g[gid] + b_g[gid]; + } + """).build() + + prg.sum.set_arg(0, None) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: