Skip to content

Commit 8387306

Browse files
_get_f_placeholder_args: set ForceValueArgTag (#304)
* _get_f_placeholder_args: set ForceValueArgTag * Update requirements.txt * skip scalar arg handling in _args_to_device_buffers * Revert changes to requirements.txt * add a simple test * Fix test --------- Co-authored-by: Andreas Klöckner <inform@tiker.net>
1 parent 029026c commit 8387306

2 files changed

Lines changed: 32 additions & 4 deletions

File tree

arraycontext/impl/pytato/compile.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
218218
:attr:`BaseLazilyCompilingFunctionCaller.f`.
219219
"""
220220
if np.isscalar(arg):
221+
from pytato.tags import ForceValueArgTag
221222
name = arg_id_to_name[kw,]
222-
return pt.make_placeholder(name, (), np.dtype(type(arg)))
223+
return pt.make_placeholder(name, (), np.dtype(type(arg)),
224+
tags=frozenset({ForceValueArgTag()}))
223225
elif isinstance(arg, pt.Array):
224226
name = arg_id_to_name[kw,]
225227
# Transform the DAG to give metadata inference a chance to do its job
@@ -533,9 +535,8 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
533535
for arg_id, arg in arg_id_to_arg.items():
534536
if np.isscalar(arg):
535537
if isinstance(actx, PytatoPyOpenCLArrayContext):
536-
import pyopencl.array as cla
537-
arg = cla.to_device(actx.queue, np.array(arg),
538-
allocator=actx.allocator)
538+
# Scalar kernel args are passed as lp.ValueArgs
539+
pass
539540
elif isinstance(actx, PytatoJAXArrayContext):
540541
import jax
541542
arg = jax.device_put(arg)

test/test_pytato_arraycontext.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,33 @@ def test_transfer(actx_factory):
247247
# }}}
248248

249249

250+
def test_pass_args_compiled_func(actx_factory):
251+
import numpy as np
252+
253+
import loopy as lp
254+
import pyopencl as cl
255+
import pyopencl.array
256+
import pytato as pt
257+
258+
def twice(x, y, a):
259+
return 2 * x * y * a
260+
261+
actx = _PytatoPyOpenCLArrayContextForTests(actx_factory().queue)
262+
263+
dev_scalar = pt.make_data_wrapper(cl.array.to_device(actx.queue, np.float64(23)))
264+
265+
f = actx.compile(twice)
266+
267+
assert actx.to_numpy(f(99.0, np.float64(2.0), dev_scalar)) == 2*23*99*2
268+
269+
compiled_func, = f.program_cache.values()
270+
ep = compiled_func.pytato_program.program.t_unit.default_entrypoint
271+
272+
assert isinstance(ep.arg_dict["_actx_in_0"], lp.ValueArg)
273+
assert isinstance(ep.arg_dict["_actx_in_1"], lp.ValueArg)
274+
assert isinstance(ep.arg_dict["_actx_in_2"], lp.ArrayArg)
275+
276+
250277
if __name__ == "__main__":
251278
import sys
252279
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)