Skip to content

Commit c9c3489

Browse files
committed
Revert "upgrade 'unevaluated array as argument' warning to error (inducer#305)"
This reverts commit 4aeaed4.
1 parent 0bdcdf6 commit c9c3489

5 files changed

Lines changed: 40 additions & 43 deletions

File tree

arraycontext/impl/jax/fake_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _empty_like(array):
8080

8181
def zeros_like(self, ary):
8282
def _zeros_like(array):
83-
return self._array_context.np.zeros(array.shape, array.dtype)
83+
return self._array_context.zeros(array.shape, array.dtype)
8484

8585
return self._array_context._rec_map_container(
8686
_zeros_like, ary, default_scalar=0)

arraycontext/impl/pytato/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
535535

536536
pt_prg = pt.generate_loopy(transformed_dag,
537537
options=opts,
538+
cl_device=self.queue.device,
538539
function_name=function_name,
539540
target=self.get_target()
540541
).bind_to_context(self.context)

arraycontext/impl/pytato/compile.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
529529
return pytato_program, name_in_program_to_tags, name_in_program_to_axes
530530

531531

532-
def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg,
533-
fn_name="<unknown>"):
532+
def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
534533
input_kwargs_for_loopy = {}
535534

536535
for arg_id, arg in arg_id_to_arg.items():
@@ -551,20 +550,32 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg,
551550
# got a frozen array => do nothing
552551
pass
553552
elif isinstance(arg, pt.Array):
554-
# got an array expression => abort
555-
raise ValueError(
556-
f"Argument '{arg_id}' to the '{fn_name}' compiled function is a"
557-
" pytato array expression. Evaluating it just-in-time"
558-
" potentially causes a significant overhead on each call to the"
559-
" function and is therefore unsupported. "
560-
)
553+
# got an array expression => evaluate it
554+
from warnings import warn
555+
warn(f"Argument array '{arg_id}' to a compiled function is "
556+
"unevaluated. Evaluating just-in-time, at "
557+
"considerable expense. This is deprecated and will stop "
558+
"working in 2023. To avoid this warning, force evaluation "
559+
"of all arguments via freeze/thaw.",
560+
DeprecationWarning, stacklevel=4)
561+
562+
arg = actx.freeze(arg)
561563
else:
562564
raise NotImplementedError(type(arg))
563565

564566
input_kwargs_for_loopy[input_id_to_name_in_program[arg_id]] = arg
565567

566568
return input_kwargs_for_loopy
567569

570+
571+
def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
572+
from warnings import warn
573+
warn("_args_to_cl_buffer has been renamed to"
574+
" _args_to_device_buffers. This will be"
575+
" an error in 2023.", DeprecationWarning, stacklevel=2)
576+
return _args_to_device_buffers(actx, input_id_to_name_in_program,
577+
arg_id_to_arg)
578+
568579
# }}}
569580

570581

@@ -620,7 +631,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
620631
type of the callable.
621632
"""
622633
actx: PytatoPyOpenCLArrayContext
623-
pytato_program: pt.target.loopy.BoundPyOpenCLExecutable
634+
pytato_program: pt.target.BoundProgram
624635
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
625636
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
626637
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
@@ -631,10 +642,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
631642
from .utils import get_cl_axes_from_pt_axes
632643
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
633644

634-
fn_name = self.pytato_program.program.entrypoint
635-
636645
input_kwargs_for_loopy = _args_to_device_buffers(
637-
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
646+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
638647

639648
evt, out_dict = self.pytato_program(queue=self.actx.queue,
640649
allocator=self.actx.allocator,
@@ -665,7 +674,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
665674
Name of the output array in the program.
666675
"""
667676
actx: PytatoPyOpenCLArrayContext
668-
pytato_program: pt.target.loopy.BoundPyOpenCLExecutable
677+
pytato_program: pt.target.BoundProgram
669678
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
670679
output_tags: frozenset[Tag]
671680
output_axes: tuple[pt.Axis, ...]
@@ -675,10 +684,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
675684
from .utils import get_cl_axes_from_pt_axes
676685
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
677686

678-
fn_name = self.pytato_program.program.entrypoint
679-
680687
input_kwargs_for_loopy = _args_to_device_buffers(
681-
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
688+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
682689

683690
evt, out_dict = self.pytato_program(queue=self.actx.queue,
684691
allocator=self.actx.allocator,
@@ -716,18 +723,16 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
716723
type of the callable.
717724
"""
718725
actx: PytatoJAXArrayContext
719-
pytato_program: pt.target.python.BoundJAXPythonProgram
726+
pytato_program: pt.target.BoundProgram
720727
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
721728
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
722729
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
723730
name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
724731
output_template: ArrayContainer
725732

726733
def __call__(self, arg_id_to_arg) -> ArrayContainer:
727-
fn_name = self.pytato_program.entrypoint
728-
729734
input_kwargs_for_loopy = _args_to_device_buffers(
730-
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
735+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
731736

732737
out_dict = self.pytato_program(**input_kwargs_for_loopy)
733738

@@ -749,17 +754,15 @@ class CompiledJAXFunctionReturningArray(CompiledFunction):
749754
Name of the output array in the program.
750755
"""
751756
actx: PytatoJAXArrayContext
752-
pytato_program: pt.target.python.BoundJAXPythonProgram
757+
pytato_program: pt.target.BoundProgram
753758
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
754759
output_tags: frozenset[Tag]
755760
output_axes: tuple[pt.Axis, ...]
756761
output_name: str
757762

758763
def __call__(self, arg_id_to_arg) -> ArrayContainer:
759-
fn_name = self.pytato_program.entrypoint
760-
761764
input_kwargs_for_loopy = _args_to_device_buffers(
762-
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
765+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
763766

764767
_evt, out_dict = self.pytato_program(**input_kwargs_for_loopy)
765768

test/test_arraycontext.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,7 @@ def test_actx_compile_kwargs(actx_factory):
11491149
def test_actx_compile_with_tuple_output_keys(actx_factory):
11501150
# arraycontext.git<=3c9aee68 would fail due to a bug in output
11511151
# key stringification logic.
1152+
from arraycontext import from_numpy, to_numpy
11521153
actx = actx_factory()
11531154
rng = np.random.default_rng()
11541155

@@ -1162,11 +1163,11 @@ def my_rhs(scale, vel):
11621163
v_x = rng.uniform(size=10)
11631164
v_y = rng.uniform(size=10)
11641165

1165-
vel = actx.from_numpy(Velocity2D(v_x, v_y, actx))
1166+
vel = from_numpy(Velocity2D(v_x, v_y, actx), actx)
11661167

11671168
scaled_speed = compiled_rhs(3.14, vel=vel)
11681169

1169-
result = actx.to_numpy(scaled_speed)[0, 0]
1170+
result = to_numpy(scaled_speed, actx)[0, 0]
11701171
np.testing.assert_allclose(result.u, -3.14*v_y)
11711172
np.testing.assert_allclose(result.v, 3.14*v_x)
11721173

@@ -1292,8 +1293,6 @@ class ArrayContainerWithNumpy:
12921293
u: np.ndarray
12931294
v: DOFArray
12941295

1295-
__array_ufunc__ = None
1296-
12971296

12981297
def test_array_container_with_numpy(actx_factory):
12991298
actx = actx_factory()
@@ -1412,16 +1411,14 @@ def test_compile_anonymous_function(actx_factory):
14121411

14131412
# See https://github.com/inducer/grudge/issues/287
14141413
actx = actx_factory()
1415-
1416-
ones = actx.thaw(actx.freeze(
1417-
actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1
1418-
))
1419-
14201414
f = actx.compile(lambda x: 2*x+40)
1421-
np.testing.assert_allclose(actx.to_numpy(f(ones)), 42)
1422-
1415+
np.testing.assert_allclose(
1416+
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
1417+
42)
14231418
f = actx.compile(partial(lambda x: 2*x+40))
1424-
np.testing.assert_allclose(actx.to_numpy(f(ones)), 42)
1419+
np.testing.assert_allclose(
1420+
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
1421+
42)
14251422

14261423

14271424
@pytest.mark.parametrize(

test/testlib.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def array_context(self):
160160

161161
@with_container_arithmetic(
162162
bcasts_across_obj_array=False,
163-
container_types_bcast_across=(DOFArray, np.ndarray),
163+
bcast_container_types=(DOFArray, np.ndarray),
164164
matmul=True,
165165
rel_comparison=True,
166166
_cls_has_array_context_attr=True,
@@ -173,8 +173,6 @@ class MyContainerDOFBcast:
173173
momentum: np.ndarray
174174
enthalpy: DOFArray | np.ndarray
175175

176-
__array_ufunc__ = None
177-
178176
@property
179177
def array_context(self):
180178
if isinstance(self.mass, np.ndarray):
@@ -211,8 +209,6 @@ class Velocity2D:
211209
v: ArrayContainer
212210
array_context: ArrayContext
213211

214-
__array_ufunc__ = None
215-
216212

217213
@with_array_context.register(Velocity2D)
218214
# https://github.com/python/mypy/issues/13040

0 commit comments

Comments
 (0)